<#-- CNNCreator.ftl: FreeMarker template generating the CNNCreator Python class (MXNet/Gluon backend) -->
<#-- (c) https://github.com/MontiCore/monticore -->
Jean Meurice's avatar
Jean Meurice committed
2
3
<#-- So that the license is in the generated file: -->
# (c) https://github.com/MontiCore/monticore
4
5
6
import mxnet as mx
import logging
import os
Julian Treiber's avatar
Julian Treiber committed
7
import shutil
8
import warnings
9
import inspect
10
11
import sys

12
<#if tc.architecture.customPyFilesPath??>
13
14
sys.path.insert(1, '${tc.architecture.customPyFilesPath}')
from custom_layers import *
15
</#if>
16

17
<#list tc.architecture.networkInstructions as networkInstruction>
lr119628's avatar
lr119628 committed
18
<#if tc.containsAdaNet()>
lr119628's avatar
lr119628 committed
19
from CNNNet_${tc.fullArchitectureName} import Net_${networkInstruction?index},DataClass_${networkInstruction?index}
lr119628's avatar
lr119628 committed
20
21
<#else>
from CNNNet_${tc.fullArchitectureName} import Net_${networkInstruction?index}
lr119628's avatar
lr119628 committed
22
</#if>
lr119628's avatar
lr119628 committed
23
24
</#list>

25
<#-- Creator class: constructs the generated networks, loads/saves checkpoints and pretrained weights. -->
class ${tc.fileNameWithoutEnding}:
    # Directory and file prefix shared by every checkpoint of this component.
    _model_dir_ = "model/${tc.componentName}/"
    _model_prefix_ = "model"

    def __init__(self):
        # Default parameter initializer; can be replaced via setWeightInitializer()
        # before construct() is called.
        self.weight_initializer = mx.init.Normal()
        # Maps network index -> Gluon network instance (populated by construct()).
        self.networks = {}
        <#if tc.containsAdaNet()>
        # AdaNet only: maps network index -> the data class of the candidate net.
        self.dataClass = {}
        </#if>
<#if (tc.weightsPath)??>
        # Directory holding pretrained weights, as configured in the model.
        self._weights_dir_ = "${tc.weightsPath}/"
<#else>
        # No pretrained-weights directory configured for this architecture.
        self._weights_dir_ = None
</#if>
40

Julian Treiber's avatar
Julian Treiber committed
41
    def load(self, context):
        """Load the newest checkpoint of every network from self._model_dir_.

        Stale "*_newest*" export files are deleted first, then the directory is
        scanned for the highest-numbered "<prefix>_<i>-<epoch>.params" file per
        network. For networks with episodic sub-nets the newest replay-memory
        dumps are restored as well.

        Returns the smallest "next epoch" over all networks (0 if any network
        had no checkpoint), i.e. the epoch training should resume from.
        """
        def _remove_silently(path):
            # Best-effort delete; missing files are expected and ignored.
            try:
                os.remove(path)
            except OSError:
                pass

        earliestLastEpoch = None

        for i, network in self.networks.items():
            lastEpoch = 0
            param_file = None
            base = self._model_dir_ + self._model_prefix_ + "_" + str(i)

            if hasattr(network, 'episodic_sub_nets'):
                # Track, per episodic sub-net, the newest replay-memory dump.
                num_episodic_sub_nets = len(network.episodic_sub_nets)
                lastMemEpoch = [0]*num_episodic_sub_nets
                mem_files = [None]*num_episodic_sub_nets

            # Remove the "newest" aliases so they are not picked up by the
            # epoch scan below (their names carry no real epoch number).
            _remove_silently(base + "_newest-0000.params")
            _remove_silently(base + "_newest-symbol.json")

            if hasattr(network, 'episodic_sub_nets'):
                _remove_silently(base + '_newest_episodic_sub_net_' + str(0) + "-0000.params")
                _remove_silently(base + '_newest_episodic_sub_net_' + str(0) + "-symbol.json")

                for j in range(len(network.episodic_sub_nets)):
                    _remove_silently(base + '_newest_episodic_sub_net_' + str(j+1) + "-0000.params")
                    _remove_silently(base + '_newest_episodic_sub_net_' + str(j+1) + "-symbol.json")
                    _remove_silently(base + '_newest_episodic_query_net_' + str(j+1) + "-0000.params")
                    _remove_silently(base + '_newest_episodic_query_net_' + str(j+1) + "-symbol.json")
                    _remove_silently(base + '_newest_loss' + "-0000.params")
                    _remove_silently(base + '_newest_loss' + "-symbol.json")
                    _remove_silently(base + "_newest_episodic_memory_sub_net_" + str(j + 1) + "-0000")

            if os.path.isdir(self._model_dir_):
                for file in os.listdir(self._model_dir_):
                    if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file:
                        # "<prefix>_<i>-<epoch>.params" -> keep the highest epoch.
                        epochStr = file.replace(".params", "").replace(self._model_prefix_ + "_" + str(i) + "-", "")
                        epoch = int(epochStr)
                        if epoch >= lastEpoch:
                            lastEpoch = epoch
                            param_file = file
                    elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file:
                        # "<prefix>_<i>_episodic_memory_sub_net_<j>-<epoch>" -> newest per sub-net j.
                        relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-")
                        memSubNet = int(relMemPathInfo[0])
                        memEpochStr = relMemPathInfo[1]
                        memEpoch = int(memEpochStr)
                        if memEpoch >= lastMemEpoch[memSubNet-1]:
                            lastMemEpoch[memSubNet-1] = memEpoch
                            mem_files[memSubNet-1] = file

            if param_file is None:
                # No checkpoint for this network -> training must start at epoch 0.
                earliestLastEpoch = 0
            else:
                logging.info("Loading checkpoint: " + param_file)
                network.load_parameters(self._model_dir_ + param_file)

                if hasattr(network, 'episodic_sub_nets'):
                    for j, sub_net in enumerate(network.episodic_sub_nets):
                        if mem_files[j] is not None:
                            logging.info("Loading Replay Memory: " + mem_files[j])
                            # The memory layer is the sub-net attribute whose name starts with "memory".
                            mem_layer = [param for param in inspect.getmembers(sub_net, lambda x: not(inspect.isroutine(x))) if param[0].startswith("memory")][0][1]
                            mem_layer.load_memory(self._model_dir_ + mem_files[j])

                if earliestLastEpoch is None or lastEpoch + 1 < earliestLastEpoch:
                    earliestLastEpoch = lastEpoch + 1

        return earliestLastEpoch
134

Julian Treiber's avatar
Julian Treiber committed
135
136
137
138
139
140
141
    def load_pretrained_weights(self, context):
        """Wipe self._model_dir_ and load pretrained weights from self._weights_dir_.

        For each network the newest "<prefix>_<i>-<epoch>.params" file in the
        weights directory is loaded (allow_missing/ignore_extra, so partially
        matching pretrained nets are accepted). Episodic replay-memory dumps
        found next to the weights are restored as well. No-op if no weights
        directory is configured.
        """
        if os.path.isdir(self._model_dir_):
            shutil.rmtree(self._model_dir_)
        if self._weights_dir_ is not None:
            for i, network in self.networks.items():
                param_file = None

                if hasattr(network, 'episodic_sub_nets'):
                    num_episodic_sub_nets = len(network.episodic_sub_nets)
                    lastMemEpoch = [0] * num_episodic_sub_nets
                    mem_files = [None] * num_episodic_sub_nets

                if os.path.isdir(self._weights_dir_):
                    lastEpoch = 0

                    for file in os.listdir(self._weights_dir_):

                        if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file:
                            epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","")
                            epoch = int(epochStr)
                            if epoch >= lastEpoch:
                                lastEpoch = epoch
                                param_file = file
                        elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file:
                            # BUGFIX: str.replace() was called without the replacement
                            # argument (TypeError); mirror the working variant in load().
                            relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-")
                            memSubNet = int(relMemPathInfo[0])
                            memEpochStr = relMemPathInfo[1]
                            memEpoch = int(memEpochStr)
                            if memEpoch >= lastMemEpoch[memSubNet-1]:
                                lastMemEpoch[memSubNet-1] = memEpoch
                                mem_files[memSubNet-1] = file

                    # BUGFIX: param_file may still be None (no matching file found);
                    # the old code crashed concatenating None into the log message.
                    if param_file is None:
                        logging.warning("No pretrained weights found for network " + str(i) + " in: " + self._weights_dir_)
                    else:
                        logging.info("Loading pretrained weights: " + self._weights_dir_ + param_file)
                        network.load_parameters(self._weights_dir_ + param_file, allow_missing=True, ignore_extra=True)

                    if hasattr(network, 'episodic_sub_nets'):
                        # BUGFIX: the old assert compared an int to a list and could
                        # never hold; the intent is that every memory dump matches
                        # the epoch of the loaded parameter file.
                        assert all(memEpoch == lastEpoch for memEpoch in lastMemEpoch)
                        for j, sub_net in enumerate(network.episodic_sub_nets):
                            if mem_files[j] is not None:
                                logging.info("Loading pretrained Replay Memory: " + mem_files[j])
                                mem_layer = \
                                [param for param in inspect.getmembers(sub_net, lambda x: not (inspect.isroutine(x))) if
                                 param[0].startswith("memory")][0][1]
                                # BUGFIX: the memory file was found in _weights_dir_, and
                                # _model_dir_ was deleted above — load from _weights_dir_.
                                mem_layer.load_memory(self._weights_dir_ + mem_files[j])
                else:
                    logging.info("No pretrained weights available at: " + self._weights_dir_)
180

181
    def construct(self, context, data_mean=None, data_std=None):
        """Instantiate, initialize, hybridize and export all generated networks.

        A dummy forward pass with zero tensors is run per network so that the
        hybridized symbol graph is built before export. The initial (epoch 0)
        symbols/params are exported to self._model_dir_.
        """
<#list tc.architecture.networkInstructions as networkInstruction>
        <#if tc.containsAdaNet()>
        self.networks[${networkInstruction?index}] = Net_${networkInstruction?index}()
        self.dataClass[${networkInstruction?index}] = DataClass_${networkInstruction?index}
        <#else>
        self.networks[${networkInstruction?index}] = Net_${networkInstruction?index}(data_mean=data_mean, data_std=data_std, mx_context=context, prefix="")
        </#if>
        # Suppress MXNet initialization warnings (e.g. already-initialized params).
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self.networks[${networkInstruction?index}].collect_params().initialize(self.weight_initializer, force_reinit=False, ctx=context)
        self.networks[${networkInstruction?index}].hybridize()
        # Dummy forward pass to trigger deferred shape inference / graph construction.
        self.networks[${networkInstruction?index}](<#list tc.getStreamInputDimensions(networkInstruction.body) as dimensions><#if tc.cutDimensions(dimensions)[tc.cutDimensions(dimensions)?size-1] == "1" && tc.cutDimensions(dimensions)?size != 1>mx.nd.zeros((${tc.join(tc.cutDimensions(dimensions), ",")},), ctx=context[0])<#else>mx.nd.zeros((1, ${tc.join(tc.cutDimensions(dimensions), ",")},), ctx=context[0])</#if><#sep>, </#list>)
<#if networkInstruction.body.episodicSubNetworks?has_content>
        # Episodic networks: also run the first episodic sub-net once.
        self.networks[0].episodicsubnet0_(<#list tc.getStreamInputDimensions(networkInstruction.body) as dimensions><#if tc.cutDimensions(dimensions)[tc.cutDimensions(dimensions)?size-1] == "1" && tc.cutDimensions(dimensions)?size != 1>mx.nd.zeros((${tc.join(tc.cutDimensions(dimensions), ",")},), ctx=context[0])<#else>mx.nd.zeros((1, ${tc.join(tc.cutDimensions(dimensions), ",")},), ctx=context[0])</#if><#sep>, </#list>)
</#if>
</#list>

        if not os.path.exists(self._model_dir_):
            os.makedirs(self._model_dir_)

        for i, network in self.networks.items():
            network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
Julian Dierkes's avatar
Julian Dierkes committed
204

205
206
207
    def setWeightInitializer(self, initializer):
        """Set the MXNet initializer used by construct() for all network parameters."""
        self.weight_initializer = initializer

Julian Dierkes's avatar
Julian Dierkes committed
208
209
210
211
212
213
    def getInputs(self):
        """Return {input name: (domain..., (dimensions,))} for all stream inputs.

        Each value is the input's domain tuple with the dimension tuple
        appended as its last element.
        """
        inputs = {}
<#list tc.architecture.streams as stream>
<#assign dimensions = (tc.getStreamInputs(stream, false))>
<#assign domains = (tc.getStreamInputDomains(stream))>
<#list tc.getStreamInputVariableNames(stream, false) as name>
        input_dimensions = (${tc.join(dimensions[name], ",")},)
        input_domains = (${tc.join(domains[name], ",")},)
        inputs["${name}"] = input_domains + (input_dimensions,)
</#list>
</#list>
        return inputs

    def getOutputs(self):
        """Return {output name: (domain..., (dimensions,))} for all stream outputs.

        Mirrors getInputs(): each value is the output's domain tuple with the
        dimension tuple appended as its last element.
        """
        outputs = {}
<#list tc.architecture.streams as stream>
<#assign dimensions = (tc.getStreamOutputs(stream, false))>
<#assign domains = (tc.getStreamOutputDomains(stream))>
<#list tc.getStreamOutputVariableNames(stream, false) as name>
        output_dimensions = (${tc.join(dimensions[name], ",")},)
        output_domains = (${tc.join(domains[name], ",")},)
        outputs["${name}"] = output_domains + (output_dimensions,)
</#list>
</#list>
        return outputs
233
234
235
236

    def validate_parameters(self):
        """Run the generated per-network parameter validation code (no-op if none)."""
<#list tc.architecture.networkInstructions as networkInstruction>
${tc.include(networkInstruction.body, "PARAMETER_VALIDATION")}
</#list>        pass