CNNCreator.ftl 11.1 KB
Newer Older
Bernhard Rumpe's avatar
BR-sy    
Bernhard Rumpe committed
1
<#-- (c) https://github.com/MontiCore/monticore -->
2
3
4
import mxnet as mx
import logging
import os
Julian Treiber's avatar
Julian Treiber committed
5
import shutil
6
import warnings
7
import inspect
8

9
10
<#list tc.architecture.networkInstructions as networkInstruction>
from CNNNet_${tc.fullArchitectureName} import Net_${networkInstruction?index}
11
</#list>
12
13

class ${tc.fileNameWithoutEnding}:
    """Creates, checkpoints and restores the generated Gluon networks.

    ``self.networks`` maps the index of each network instruction of the
    architecture to its generated ``Net_<index>`` instance.
    """

    # Checkpoint/export files are named <_model_dir_><_model_prefix_>_<index>...
    _model_dir_ = "model/${tc.componentName}/"
    _model_prefix_ = "model"

    def __init__(self):
        """Set the default weight initializer and an empty network table."""
        # Used by construct(); can be replaced via setWeightInitializer().
        self.weight_initializer = mx.init.Normal()
        # Filled by construct(): network index -> Net_<index> instance.
        self.networks = {}
        # Directory holding pretrained weights, or None if none is configured.
        # The trailing "/" is needed because file paths are built by string
        # concatenation (self._weights_dir_ + file_name).
<#if (tc.weightsPath)??>
        self._weights_dir_ = "${tc.weightsPath}/"
<#else>
        self._weights_dir_ = None
</#if>

    def load(self, context):
        """Restore every network from its newest checkpoint in ``self._model_dir_``.

        For each network ``i`` in ``self.networks`` this

        1. deletes stale ``<prefix>_<i>_newest*`` exports left by a previous
           run (best effort; missing files are ignored),
        2. selects the highest-epoch ``<prefix>_<i>-<epoch>.params`` file and,
           if the network has ``episodic_sub_nets``, the highest-epoch
           ``<prefix>_<i>_episodic_memory_sub_net_<k>-<epoch>`` file per sub net,
        3. loads those parameters and replay memories.

        File names are matched on their exact prefix, so network ``1`` does not
        pick up ``model_10-*.params``, and sub-net parameter files such as
        ``model_0_episodic_sub_net_1-0003.params`` are ignored rather than
        breaking the epoch parsing.

        Args:
            context: not used by this method.

        Returns:
            The epoch to resume training from: ``0`` if any network has no
            checkpoint, otherwise one more than the smallest last-saved epoch
            over all networks. ``None`` when ``self.networks`` is empty.

        Raises:
            IndexError: if a replay-memory file exists for a sub net that has no
                attribute whose name starts with ``memory``.
        """
        def remove_if_present(path):
            # Best-effort cleanup: these exports usually do not exist.
            try:
                os.remove(path)
            except OSError:
                pass

        def to_int(text):
            # Parse a numeric file-name component; None if it is not a number.
            try:
                return int(text)
            except ValueError:
                return None

        def memory_layer(sub_net):
            # The replay memory is the first non-method attribute whose name
            # starts with "memory" (inspect.getmembers sorts by name).
            for name, value in inspect.getmembers(sub_net, lambda x: not inspect.isroutine(x)):
                if name.startswith("memory"):
                    return value
            raise IndexError("episodic sub net has no 'memory*' attribute")

        earliest_last_epoch = None

        for i, network in self.networks.items():
            base_name = self._model_prefix_ + "_" + str(i)
            base_path = self._model_dir_ + base_name
            has_episodic = hasattr(network, 'episodic_sub_nets')
            num_sub_nets = len(network.episodic_sub_nets) if has_episodic else 0

            # 1. Remove the "newest" exports of a previous run.
            stale_suffixes = ["_newest-0000.params", "_newest-symbol.json"]
            if has_episodic:
                stale_suffixes += ["_newest_loss-0000.params", "_newest_loss-symbol.json"]
                # Episodic sub nets are numbered 0..num_sub_nets.
                for k in range(num_sub_nets + 1):
                    stale_suffixes += ["_newest_episodic_sub_net_" + str(k) + "-0000.params",
                                       "_newest_episodic_sub_net_" + str(k) + "-symbol.json"]
                # Query nets and memories are numbered 1..num_sub_nets.
                for k in range(1, num_sub_nets + 1):
                    stale_suffixes += ["_newest_episodic_query_net_" + str(k) + "-0000.params",
                                       "_newest_episodic_query_net_" + str(k) + "-symbol.json",
                                       "_newest_episodic_memory_sub_net_" + str(k) + "-0000"]
            for suffix in stale_suffixes:
                remove_if_present(base_path + suffix)

            # 2. Find the newest parameter and replay-memory checkpoints.
            param_prefix = base_name + "-"
            mem_prefix = base_name + "_episodic_memory_sub_net_"
            last_epoch = 0
            param_file = None
            last_mem_epoch = [0] * num_sub_nets
            mem_files = [None] * num_sub_nets

            if os.path.isdir(self._model_dir_):
                for file_name in os.listdir(self._model_dir_):
                    if file_name.startswith(param_prefix) and file_name.endswith(".params"):
                        epoch = to_int(file_name[len(param_prefix):-len(".params")])
                        if epoch is not None and epoch >= last_epoch:
                            last_epoch = epoch
                            param_file = file_name
                    elif has_episodic and file_name.startswith(mem_prefix):
                        # Expected shape: <mem_prefix><sub net number>-<epoch>
                        parts = file_name[len(mem_prefix):].split("-")
                        sub_net_no = to_int(parts[0])
                        mem_epoch = to_int(parts[1]) if len(parts) > 1 else None
                        if sub_net_no is None or mem_epoch is None or not 1 <= sub_net_no <= num_sub_nets:
                            continue  # not a memory checkpoint of a known sub net
                        if mem_epoch >= last_mem_epoch[sub_net_no - 1]:
                            last_mem_epoch[sub_net_no - 1] = mem_epoch
                            mem_files[sub_net_no - 1] = file_name

            # 3. Load what was found.
            if param_file is None:
                # At least one network starts from scratch, so training must too.
                earliest_last_epoch = 0
                continue

            logging.info("Loading checkpoint: " + param_file)
            network.load_parameters(self._model_dir_ + param_file)

            if has_episodic:
                for sub_net, mem_file in zip(network.episodic_sub_nets, mem_files):
                    if mem_file is not None:
                        logging.info("Loading Replay Memory: " + mem_file)
                        memory_layer(sub_net).load_memory(self._model_dir_ + mem_file)

            if earliest_last_epoch is None or last_epoch + 1 < earliest_last_epoch:
                earliest_last_epoch = last_epoch + 1

        return earliest_last_epoch

    def load_pretrained_weights(self, context):
        """Initialise every network from pretrained weights in ``self._weights_dir_``.

        ``self._model_dir_`` is always deleted first. If a weights directory is
        configured and exists, each network ``i`` is loaded from the
        highest-epoch ``<prefix>_<i>-<epoch>.params`` file found there, with
        ``allow_missing=True, ignore_extra=True``. Networks exposing
        ``episodic_sub_nets`` additionally get the newest
        ``<prefix>_<i>_episodic_memory_sub_net_<k>-<epoch>`` replay memory per
        sub net. That memory is read from the same weights directory.
        Networks without a matching checkpoint are logged and skipped.

        Args:
            context: not used by this method.

        Raises:
            IndexError: if a replay-memory file exists for a sub net that has no
                attribute whose name starts with ``memory``.
        """
        def to_int(text):
            # Parse a numeric file-name component; None if it is not a number.
            try:
                return int(text)
            except ValueError:
                return None

        def memory_layer(sub_net):
            # The replay memory is the first non-method attribute whose name
            # starts with "memory" (inspect.getmembers sorts by name).
            for name, value in inspect.getmembers(sub_net, lambda x: not inspect.isroutine(x)):
                if name.startswith("memory"):
                    return value
            raise IndexError("episodic sub net has no 'memory*' attribute")

        if os.path.isdir(self._model_dir_):
            shutil.rmtree(self._model_dir_)
        if self._weights_dir_ is None:
            return
        if not os.path.isdir(self._weights_dir_):
            logging.info("No pretrained weights available at: " + self._weights_dir_)
            return

        weight_files = os.listdir(self._weights_dir_)
        for i, network in self.networks.items():
            base_name = self._model_prefix_ + "_" + str(i)
            param_prefix = base_name + "-"
            mem_prefix = base_name + "_episodic_memory_sub_net_"
            has_episodic = hasattr(network, 'episodic_sub_nets')
            num_sub_nets = len(network.episodic_sub_nets) if has_episodic else 0

            last_epoch = 0
            param_file = None
            last_mem_epoch = [0] * num_sub_nets
            mem_files = [None] * num_sub_nets

            for file_name in weight_files:
                if file_name.startswith(param_prefix) and file_name.endswith(".params"):
                    epoch = to_int(file_name[len(param_prefix):-len(".params")])
                    if epoch is not None and epoch >= last_epoch:
                        last_epoch = epoch
                        param_file = file_name
                elif has_episodic and file_name.startswith(mem_prefix):
                    # Expected shape: <mem_prefix><sub net number>-<epoch>
                    parts = file_name[len(mem_prefix):].split("-")
                    sub_net_no = to_int(parts[0])
                    mem_epoch = to_int(parts[1]) if len(parts) > 1 else None
                    if sub_net_no is None or mem_epoch is None or not 1 <= sub_net_no <= num_sub_nets:
                        continue  # not a memory checkpoint of a known sub net
                    if mem_epoch >= last_mem_epoch[sub_net_no - 1]:
                        last_mem_epoch[sub_net_no - 1] = mem_epoch
                        mem_files[sub_net_no - 1] = file_name

            if param_file is None:
                logging.info("No pretrained weights available at: " + self._weights_dir_ + param_prefix + "*.params")
                continue

            logging.info("Loading pretrained weights: " + self._weights_dir_ + param_file)
            network.load_parameters(self._weights_dir_ + param_file, allow_missing=True, ignore_extra=True)

            if not has_episodic:
                continue
            for sub_net, mem_file, mem_epoch in zip(network.episodic_sub_nets, mem_files, last_mem_epoch):
                if mem_file is None:
                    continue
                # Parameters and replay memory are expected to share an epoch.
                if mem_epoch != last_epoch:
                    logging.warning("Replay memory " + mem_file + " is from epoch " + str(mem_epoch)
                                    + " but the loaded weights are from epoch " + str(last_epoch))
                logging.info("Loading pretrained Replay Memory: " + mem_file)
                memory_layer(sub_net).load_memory(self._weights_dir_ + mem_file)

    def construct(self, context, data_mean=None, data_std=None):
        """Build, initialise, hybridise and export every generated network.

        Args:
            context: sequence of mxnet contexts. Parameters are initialised on
                ``context``; the dummy forward pass runs on ``context[0]``.
            data_mean: forwarded unchanged to every ``Net_<index>`` constructor.
            data_std: forwarded unchanged to every ``Net_<index>`` constructor.

        Each network is run once on all-zero inputs of its declared shapes so
        that the hybridised graph exists before ``export`` writes
        ``<_model_dir_><_model_prefix_>_<index>`` with epoch 0.
        """
<#list tc.architecture.networkInstructions as networkInstruction>
<#assign netIndex = networkInstruction?index>
<#-- Zero inputs for the dummy forward pass: a batch dimension of 1 is prepended unless the last cut dimension already is 1 (and there is more than one dimension). -->
<#assign dummyInputs><#list tc.getStreamInputDimensions(networkInstruction.body) as dimensions><#assign cutDims = tc.cutDimensions(dimensions)><#if cutDims[cutDims?size-1] == "1" && cutDims?size != 1>mx.nd.zeros((${tc.join(cutDims, ",")},), ctx=context[0])<#else>mx.nd.zeros((1, ${tc.join(cutDims, ",")},), ctx=context[0])</#if><#sep>, </#list></#assign>
        self.networks[${netIndex}] = Net_${netIndex}(data_mean=data_mean, data_std=data_std, mx_context=context, prefix="")
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self.networks[${netIndex}].collect_params().initialize(self.weight_initializer, force_reinit=False, ctx=context)
        self.networks[${netIndex}].hybridize()
        self.networks[${netIndex}](${dummyInputs})
<#if networkInstruction.body.episodicSubNetworks?has_content>
        self.networks[${netIndex}].episodicsubnet0_(${dummyInputs})
</#if>
</#list>

        if not os.path.exists(self._model_dir_):
            os.makedirs(self._model_dir_)

        for i, network in self.networks.items():
            network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)

    def setWeightInitializer(self, initializer):
        """Replace the initializer that construct() uses for new parameters.

        Only the stored initializer changes; networks that were already
        constructed are not re-initialised.
        """
        self.weight_initializer = initializer

    def getInputs(self):
        """Describe the input variables of all streams.

        Returns:
            dict mapping each input variable name to a tuple made of that
            variable's domain entries followed by its dimension tuple as the
            last element. The meaning of the domain entries comes from the
            generator; presumably element type and value range (confirm).
        """
        inputs = {}
<#list tc.architecture.streams as stream>
<#assign dimensions = (tc.getStreamInputs(stream, false))>
<#assign domains = (tc.getStreamInputDomains(stream))>
<#list tc.getStreamInputVariableNames(stream, false) as name>
        input_dimensions = (${tc.join(dimensions[name], ",")},)
        input_domains = (${tc.join(domains[name], ",")},)
        inputs["${name}"] = input_domains + (input_dimensions,)
</#list>
</#list>
        return inputs

    def getOutputs(self):
        """Describe the output variables of all streams.

        Returns:
            dict mapping each output variable name to a tuple made of that
            variable's domain entries followed by its dimension tuple as the
            last element. The meaning of the domain entries comes from the
            generator; presumably element type and value range (confirm).
        """
        outputs = {}
<#list tc.architecture.streams as stream>
<#assign dimensions = (tc.getStreamOutputs(stream, false))>
<#assign domains = (tc.getStreamOutputDomains(stream))>
<#list tc.getStreamOutputVariableNames(stream, false) as name>
        output_dimensions = (${tc.join(dimensions[name], ",")},)
        output_domains = (${tc.join(domains[name], ",")},)
        outputs["${name}"] = output_domains + (output_dimensions,)
</#list>
</#list>
        return outputs