<#-- (c) https://github.com/MontiCore/monticore -->
<#-- So that the license is in the generated file: -->
# (c) https://github.com/MontiCore/monticore

import mxnet as mx
import logging
import os
import shutil
import warnings
import inspect
import sys

<#if tc.architecture.customPyFilesPath??>
# Make user-supplied custom layers importable for the generated network code.
sys.path.insert(1, '${tc.architecture.customPyFilesPath}')
from custom_layers import *
</#if>
<#-- Import the generated network class (and, for AdaNet, its data class) per network instruction. -->
<#list tc.architecture.networkInstructions as networkInstruction>
<#if tc.containsAdaNet()>
from CNNNet_${tc.fullArchitectureName} import Net_${networkInstruction?index},DataClass_${networkInstruction?index}
<#else>
from CNNNet_${tc.fullArchitectureName} import Net_${networkInstruction?index}
</#if>
</#list>
class ${tc.fileNameWithoutEnding}:
    """Creates, (re)loads and exports the generated Gluon network(s)."""

    _model_dir_ = "model/${tc.componentName}/"
    _model_prefix_ = "model"

    def __init__(self):
        # Default initializer; can be overridden via setWeightInitializer().
        self.weight_initializer = mx.init.Normal()
        # Maps network-instruction index -> Gluon network instance.
        self.networks = {}
<#if (tc.weightsPath)??>
        self._weights_dir_ = "${tc.weightsPath}/"
<#else>
        self._weights_dir_ = None
</#if>

    def load(self, context):
        """Load the newest checkpoint of every network from the model directory.

        Returns the smallest "next epoch" over all networks (0 if any network
        has no checkpoint yet), so training resumes from a consistent epoch.
        """
        earliestLastEpoch = None

        for i, network in self.networks.items():
            lastEpoch = 0
            param_file = None
            if hasattr(network, 'episodic_sub_nets'):
                num_episodic_sub_nets = len(network.episodic_sub_nets)
                lastMemEpoch = [0]*num_episodic_sub_nets
                mem_files = [None]*num_episodic_sub_nets

            # Remove stale "newest" exports; they are re-created on save.
            try:
                os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest-0000.params")
            except OSError:
                pass
            try:
                os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest-symbol.json")
            except OSError:
                pass

            if hasattr(network, 'episodic_sub_nets'):
                try:
                    os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-0000.params")
                except OSError:
                    pass
                try:
                    os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-symbol.json")
                except OSError:
                    pass

                for j in range(len(network.episodic_sub_nets)):
                    try:
                        os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-0000.params")
                    except OSError:
                        pass
                    try:
                        os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-symbol.json")
                    except OSError:
                        pass
                    try:
                        os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-0000.params")
                    except OSError:
                        pass
                    try:
                        os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-symbol.json")
                    except OSError:
                        pass
                    try:
                        os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-0000.params")
                    except OSError:
                        pass
                    try:
                        os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-symbol.json")
                    except OSError:
                        pass
                    try:
                        os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest_episodic_memory_sub_net_" + str(j + 1) + "-0000")
                    except OSError:
                        pass

            # Scan the model dir for the highest-numbered checkpoint of network i.
            if os.path.isdir(self._model_dir_):
                for file in os.listdir(self._model_dir_):
                    if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file:
                        epochStr = file.replace(".params", "").replace(self._model_prefix_ + "_" + str(i) + "-", "")
                        epoch = int(epochStr)
                        if epoch >= lastEpoch:
                            lastEpoch = epoch
                            param_file = file
                    elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file:
                        # File name format: <prefix>_<i>_episodic_memory_sub_net_<subnet>-<epoch>...
                        relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-")
                        memSubNet = int(relMemPathInfo[0])
                        memEpochStr = relMemPathInfo[1]
                        memEpoch = int(memEpochStr)
                        if memEpoch >= lastMemEpoch[memSubNet-1]:
                            lastMemEpoch[memSubNet-1] = memEpoch
                            mem_files[memSubNet-1] = file

            if param_file is None:
                earliestLastEpoch = 0
            else:
                logging.info("Loading checkpoint: " + param_file)
                network.load_parameters(self._model_dir_ + param_file)

                if hasattr(network, 'episodic_sub_nets'):
                    for j, sub_net in enumerate(network.episodic_sub_nets):
                        if mem_files[j] is not None:
                            logging.info("Loading Replay Memory: " + mem_files[j])
                            # The memory layer is found by attribute-name prefix on the sub net.
                            mem_layer = [param for param in inspect.getmembers(sub_net, lambda x: not(inspect.isroutine(x))) if param[0].startswith("memory")][0][1]
                            mem_layer.load_memory(self._model_dir_ + mem_files[j])

                if earliestLastEpoch is None or lastEpoch + 1 < earliestLastEpoch:
                    earliestLastEpoch = lastEpoch + 1

        return earliestLastEpoch

    def load_pretrained_weights(self, context):
        """Initialize the networks from pretrained weights in self._weights_dir_.

        Deletes any previously exported model directory first, then loads the
        newest pretrained checkpoint (and replay memories, if any) per network.
        """
        if os.path.isdir(self._model_dir_):
            shutil.rmtree(self._model_dir_)
        if self._weights_dir_ is not None:
            for i, network in self.networks.items():
                param_file = None
                if hasattr(network, 'episodic_sub_nets'):
                    num_episodic_sub_nets = len(network.episodic_sub_nets)
                    lastMemEpoch = [0] * num_episodic_sub_nets
                    mem_files = [None] * num_episodic_sub_nets

                if os.path.isdir(self._weights_dir_):
                    lastEpoch = 0

                    for file in os.listdir(self._weights_dir_):
                        if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file:
                            epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","")
                            epoch = int(epochStr)
                            if epoch >= lastEpoch:
                                lastEpoch = epoch
                                param_file = file
                        elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file:
                            # BUGFIX: str.replace was called without its replacement
                            # argument (TypeError); mirror the logic used in load().
                            relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-")
                            memSubNet = int(relMemPathInfo[0])
                            memEpochStr = relMemPathInfo[1]
                            memEpoch = int(memEpochStr)
                            if memEpoch >= lastMemEpoch[memSubNet-1]:
                                lastMemEpoch[memSubNet-1] = memEpoch
                                mem_files[memSubNet-1] = file

                    if param_file is None:
                        # BUGFIX: concatenating None into the path crashed; report and skip.
                        logging.info("No pretrained weights available at: " + self._weights_dir_)
                    else:
                        logging.info("Loading pretrained weights: " + self._weights_dir_ + param_file)
                        network.load_parameters(self._weights_dir_ + param_file, allow_missing=True, ignore_extra=True)

                        if hasattr(network, 'episodic_sub_nets'):
                            # BUGFIX: lastMemEpoch is a list, so `lastEpoch == lastMemEpoch`
                            # was always False; check every memory epoch individually.
                            assert all(e == lastEpoch for e in lastMemEpoch)
                            for j, sub_net in enumerate(network.episodic_sub_nets):
                                if mem_files[j] is not None:
                                    logging.info("Loading pretrained Replay Memory: " + mem_files[j])
                                    mem_layer = \
                                    [param for param in inspect.getmembers(sub_net, lambda x: not (inspect.isroutine(x))) if
                                     param[0].startswith("memory")][0][1]
                                    # BUGFIX: the memory files were found in the weights dir
                                    # (the model dir was deleted above), so load from there.
                                    mem_layer.load_memory(self._weights_dir_ + mem_files[j])
                else:
                    logging.info("No pretrained weights available at: " + str(self._weights_dir_))

    def construct(self, context, data_mean=None, data_std=None):
        """Build every network, initialize its parameters, run one dummy forward
        pass to shape-infer, and export the symbols to the model directory."""
<#list tc.architecture.networkInstructions as networkInstruction>
        <#if tc.containsAdaNet()>
        self.networks[${networkInstruction?index}] = Net_${networkInstruction?index}(prefix="",operations=None)
        self.dataClass = DataClass_${networkInstruction?index}()
        <#else>
        self.networks[${networkInstruction?index}] = Net_${networkInstruction?index}(data_mean=data_mean, data_std=data_std, mx_context=context, prefix="")
        </#if>
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self.networks[${networkInstruction?index}].collect_params().initialize(self.weight_initializer, force_reinit=False, ctx=context)
        self.networks[${networkInstruction?index}].hybridize()
        self.networks[${networkInstruction?index}](<#list tc.getStreamInputDimensions(networkInstruction.body) as dimensions><#if tc.cutDimensions(dimensions)[tc.cutDimensions(dimensions)?size-1] == "1" && tc.cutDimensions(dimensions)?size != 1>mx.nd.zeros((${tc.join(tc.cutDimensions(dimensions), ",")},), ctx=context[0])<#else>mx.nd.zeros((1, ${tc.join(tc.cutDimensions(dimensions), ",")},), ctx=context[0])</#if><#sep>, </#list>)
<#if networkInstruction.body.episodicSubNetworks?has_content>
        self.networks[0].episodicsubnet0_(<#list tc.getStreamInputDimensions(networkInstruction.body) as dimensions><#if tc.cutDimensions(dimensions)[tc.cutDimensions(dimensions)?size-1] == "1" && tc.cutDimensions(dimensions)?size != 1>mx.nd.zeros((${tc.join(tc.cutDimensions(dimensions), ",")},), ctx=context[0])<#else>mx.nd.zeros((1, ${tc.join(tc.cutDimensions(dimensions), ",")},), ctx=context[0])</#if><#sep>, </#list>)
</#if>
</#list>

        if not os.path.exists(self._model_dir_):
            os.makedirs(self._model_dir_)

        for i, network in self.networks.items():
            network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)

    def setWeightInitializer(self, initializer):
        """Override the initializer used by construct()."""
        self.weight_initializer = initializer

    def getInputs(self):
        """Return {input name: domain tuple + (dimension tuple,)} per stream input."""
        inputs = {}
<#list tc.architecture.streams as stream>
<#assign dimensions = (tc.getStreamInputs(stream, false))>
<#assign domains = (tc.getStreamInputDomains(stream))>
<#list tc.getStreamInputVariableNames(stream, false) as name>
        input_dimensions = (${tc.join(dimensions[name], ",")},)
        input_domains = (${tc.join(domains[name], ",")},)
        inputs["${name}"] = input_domains + (input_dimensions,)
</#list>
</#list>
        return inputs

    def getOutputs(self):
        """Return {output name: domain tuple + (dimension tuple,)} per stream output."""
        outputs = {}
<#list tc.architecture.streams as stream>
<#assign dimensions = (tc.getStreamOutputs(stream, false))>
<#assign domains = (tc.getStreamOutputDomains(stream))>
<#list tc.getStreamOutputVariableNames(stream, false) as name>
        output_dimensions = (${tc.join(dimensions[name], ",")},)
        output_domains = (${tc.join(domains[name], ",")},)
        outputs["${name}"] = output_domains + (output_dimensions,)
</#list>
</#list>
        return outputs

    def validate_parameters(self):
        """Run generated per-network parameter validation (no-op if none)."""
<#list tc.architecture.networkInstructions as networkInstruction>
${tc.include(networkInstruction.body, "PARAMETER_VALIDATION")}
</#list>        pass