CNNCreator_VGG16.py 9.31 KB
Newer Older
1
2
3
import mxnet as mx
import logging
import os
Julian Treiber's avatar
Julian Treiber committed
4
import shutil
5
import warnings
6
import inspect
danielkisov's avatar
danielkisov committed
7
import sys
8

9
from CNNNet_VGG16 import Net_0
10
11
12

class CNNCreator_VGG16:
    """Create, checkpoint and restore the VGG16 Gluon network (``Net_0``).

    Training checkpoints are looked up in ``_model_dir_`` and are expected to
    be named ``<_model_prefix_>_<network index>-<epoch>.params``.  Networks
    that expose ``episodic_sub_nets`` additionally have replay-memory files
    named ``<_model_prefix_>_<network index>_episodic_memory_sub_net_<k>-<epoch>``.
    Paths are built by plain string concatenation, so directory attributes
    must end with a path separator.
    """

    _model_dir_ = "model/VGG16/"
    _model_prefix_ = "model"

    def __init__(self):
        self.weight_initializer = mx.init.Normal()
        # Network index -> Gluon network; filled in by construct().
        self.networks = {}
        # Directory with pretrained checkpoints; None disables
        # load_pretrained_weights().
        self._weights_dir_ = None

    def load(self, context):
        """Restore every network from its newest checkpoint in ``_model_dir_``.

        Stale ``*_newest*`` export files are deleted first.  For networks with
        ``episodic_sub_nets`` the newest replay memories are restored as well.

        :param context: accepted for interface compatibility; not used here.
        :return: ``0`` if any network has no checkpoint, otherwise one past the
            smallest newest-checkpoint epoch across all networks; ``None``
            when there are no networks.
        """
        earliest_last_epoch = None

        for i, network in self.networks.items():
            self._remove_newest_files(i, network)
            last_epoch, param_file, mem_files, _ = self._find_latest_checkpoint(
                self._model_dir_, i, network)

            if param_file is None:
                # No checkpoint for this network: report 0 as resume epoch.
                earliest_last_epoch = 0
                continue

            logging.info("Loading checkpoint: " + param_file)
            network.load_parameters(self._model_dir_ + param_file)
            if mem_files is not None:
                self._load_memories(network, mem_files, self._model_dir_,
                                    "Loading Replay Memory: ")

            if earliest_last_epoch is None or last_epoch + 1 < earliest_last_epoch:
                earliest_last_epoch = last_epoch + 1

        return earliest_last_epoch

    def load_pretrained_weights(self, context):
        """Initialise the networks from the newest checkpoints in ``_weights_dir_``.

        ``_model_dir_`` is deleted first.  Nothing else happens when
        ``_weights_dir_`` is None.  Parameters are loaded with
        ``allow_missing=True, ignore_extra=True``.  Networks without a
        matching checkpoint (or a missing weights directory) are skipped with
        an info log.

        :param context: accepted for interface compatibility; not used here.
        :raises ValueError: if a replay-memory file found for an episodic
            sub-net has a different epoch than the loaded weights.
        """
        if os.path.isdir(self._model_dir_):
            shutil.rmtree(self._model_dir_)
        if self._weights_dir_ is None:
            return

        for i, network in self.networks.items():
            last_epoch, param_file, mem_files, mem_epochs = self._find_latest_checkpoint(
                self._weights_dir_, i, network)

            if param_file is None:
                logging.info("No pretrained weights available at: " + self._weights_dir_)
                continue

            logging.info("Loading pretrained weights: " + self._weights_dir_ + param_file)
            network.load_parameters(self._weights_dir_ + param_file,
                                    allow_missing=True, ignore_extra=True)

            if mem_files is not None:
                # Replay memories must come from the same epoch as the weights.
                mismatched = [k + 1 for k, (mem_file, mem_epoch)
                              in enumerate(zip(mem_files, mem_epochs))
                              if mem_file is not None and mem_epoch != last_epoch]
                if mismatched:
                    raise ValueError(
                        "Replay memory epochs of episodic sub nets %s do not match "
                        "pretrained weights epoch %d" % (mismatched, last_epoch))
                # The memory files were found in _weights_dir_; _model_dir_
                # has just been removed above.
                self._load_memories(network, mem_files, self._weights_dir_,
                                    "Loading pretrained Replay Memory: ")

    @staticmethod
    def _remove_if_exists(path):
        """Delete ``path``, ignoring OSError (e.g. when it does not exist)."""
        try:
            os.remove(path)
        except OSError:
            pass

    def _remove_newest_files(self, i, network):
        """Delete the stale ``*_newest*`` export files of network ``i``."""
        base = self._model_dir_ + self._model_prefix_ + "_" + str(i)
        self._remove_if_exists(base + "_newest-0000.params")
        self._remove_if_exists(base + "_newest-symbol.json")

        if not hasattr(network, 'episodic_sub_nets'):
            return
        num_sub_nets = len(network.episodic_sub_nets)
        # Sub-net exports are numbered 0..n; query nets and memories 1..n.
        for j in range(num_sub_nets + 1):
            self._remove_if_exists(base + "_newest_episodic_sub_net_" + str(j) + "-0000.params")
            self._remove_if_exists(base + "_newest_episodic_sub_net_" + str(j) + "-symbol.json")
        for j in range(1, num_sub_nets + 1):
            self._remove_if_exists(base + "_newest_episodic_query_net_" + str(j) + "-0000.params")
            self._remove_if_exists(base + "_newest_episodic_query_net_" + str(j) + "-symbol.json")
            self._remove_if_exists(base + "_newest_episodic_memory_sub_net_" + str(j) + "-0000")
        self._remove_if_exists(base + "_newest_loss-0000.params")
        self._remove_if_exists(base + "_newest_loss-symbol.json")

    def _find_latest_checkpoint(self, directory, i, network):
        """Find the newest checkpoint files of network ``i`` in ``directory``.

        :return: ``(last_epoch, param_file, mem_files, mem_epochs)``.
            ``param_file`` is the name of the newest
            ``<prefix>_<i>-<epoch>.params`` file (None if there is none,
            including when ``directory`` does not exist).  ``last_epoch`` is
            its epoch (0 if none).  For networks with ``episodic_sub_nets``,
            ``mem_files[k]`` / ``mem_epochs[k]`` are the newest replay-memory
            file name (or None) and its epoch (0 if none) for sub-net
            ``k + 1``.  Both are None for other networks.
        """
        base = self._model_prefix_ + "_" + str(i)
        param_head = base + "-"
        param_tail = ".params"
        mem_head = base + "_episodic_memory_sub_net_"

        last_epoch = 0
        param_file = None
        mem_files = mem_epochs = None
        if hasattr(network, 'episodic_sub_nets'):
            num_sub_nets = len(network.episodic_sub_nets)
            mem_files = [None] * num_sub_nets
            mem_epochs = [0] * num_sub_nets

        if not os.path.isdir(directory):
            return last_epoch, param_file, mem_files, mem_epochs

        for file_name in os.listdir(directory):
            if file_name.startswith(param_head) and file_name.endswith(param_tail):
                # Exact "<base>-<epoch>.params" match: a substring test would
                # also hit e.g. "<base>_episodic_sub_net_1-0003.params" or
                # "model_10-..." for network 1 and then fail in int().
                try:
                    epoch = int(file_name[len(param_head):-len(param_tail)])
                except ValueError:
                    continue
                if epoch >= last_epoch:
                    last_epoch = epoch
                    param_file = file_name
            elif mem_files is not None and file_name.startswith(mem_head):
                parts = file_name[len(mem_head):].split("-")
                try:
                    sub_net = int(parts[0])
                    epoch = int(parts[1])
                except (ValueError, IndexError):
                    continue
                if not 1 <= sub_net <= len(mem_files):
                    continue
                if epoch >= mem_epochs[sub_net - 1]:
                    mem_epochs[sub_net - 1] = epoch
                    mem_files[sub_net - 1] = file_name

        return last_epoch, param_file, mem_files, mem_epochs

    def _load_memories(self, network, mem_files, directory, message):
        """Load each found replay-memory file of ``directory`` into its sub-net."""
        for sub_net, mem_file in zip(network.episodic_sub_nets, mem_files):
            if mem_file is None:
                continue
            logging.info(message + mem_file)
            self._find_memory_layer(sub_net).load_memory(directory + mem_file)

    @staticmethod
    def _find_memory_layer(sub_net):
        """Return the first (by name) non-routine attribute of ``sub_net`` named ``memory*``."""
        for name, value in inspect.getmembers(sub_net, lambda x: not inspect.isroutine(x)):
            if name.startswith("memory"):
                return value
        raise IndexError("episodic sub net has no attribute starting with 'memory'")

    def construct(self, context, data_mean=None, data_std=None):
        """Create ``Net_0``, initialise and hybridize it, and export it at epoch 0.

        :param context: MXNet context(s) passed to ``Net_0`` and to parameter
            initialisation; ``context[0]`` is used for the dummy forward pass.
        :param data_mean: forwarded to ``Net_0``.
        :param data_std: forwarded to ``Net_0``.
        """
        net = Net_0(data_mean=data_mean, data_std=data_std, mx_context=context, prefix="")
        self.networks[0] = net

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            net.collect_params().initialize(self.weight_initializer, force_reinit=False, ctx=context)
        net.hybridize()
        # A hybridized Gluon block must run forward once before export();
        # feed one zero sample with the shape declared by getInputs().
        net(mx.nd.zeros((1, 3, 224, 224), ctx=context[0]))

        if not os.path.exists(self._model_dir_):
            os.makedirs(self._model_dir_)

        for i, network in self.networks.items():
            network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)

    def setWeightInitializer(self, initializer):
        """Replace the initializer used by construct()."""
        self.weight_initializer = initializer

    def getInputs(self):
        """Return ``{input name: (domain type, min, max, dimensions)}``."""
        inputs = {}
        input_dimensions = (3, 224, 224,)
        input_domains = (int, 0.0, 255.0,)
        inputs["data_"] = input_domains + (input_dimensions,)
        return inputs

    def getOutputs(self):
        """Return ``{output name: (domain type, min, max, dimensions)}``."""
        outputs = {}
        output_dimensions = (1000, 1, 1,)
        output_domains = (float, 0.0, 1.0,)
        outputs["predictions_"] = output_domains + (output_dimensions,)
        return outputs

    def validate_parameters(self):
        """Validation hook; this network currently defines no parameter checks."""
        pass