CNNCreator.ftl 1.88 KB
Newer Older
1
2
3
import mxnet as mx
import logging
import os
Nicola Gatto's avatar
Nicola Gatto committed
4
from CNNNet_${tc.fullArchitectureName} import Net
5
6

class ${tc.fileNameWithoutEnding}:
nilsfreyer's avatar
nilsfreyer committed
7
8
    _model_dir_ = "model/${tc.componentName}/"
    _model_prefix_ = "model"
9
10
    _input_shapes_ = [<#list tc.architecture.inputs as input>(${tc.join(input.definition.type.dimensions, ",")})</#list>]

Nicola Gatto's avatar
Nicola Gatto committed
11
12
13
    def __init__(self):
        self.weight_initializer = mx.init.Normal()
        self.net = None
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39

    def load(self, context):
        lastEpoch = 0
        param_file = None

        try:
            os.remove(self._model_dir_ + self._model_prefix_ + "_newest-0000.params")
        except OSError:
            pass
        try:
            os.remove(self._model_dir_ + self._model_prefix_ + "_newest-symbol.json")
        except OSError:
            pass

        if os.path.isdir(self._model_dir_):
            for file in os.listdir(self._model_dir_):
                if ".params" in file and self._model_prefix_ in file:
                    epochStr = file.replace(".params","").replace(self._model_prefix_ + "-","")
                    epoch = int(epochStr)
                    if epoch > lastEpoch:
                        lastEpoch = epoch
                        param_file = file
        if param_file is None:
            return 0
        else:
            logging.info("Loading checkpoint: " + param_file)
Nicola Gatto's avatar
Nicola Gatto committed
40
            self.net.load_parameters(self._model_dir_ + param_file)
41
42
43
44
            return lastEpoch


    def construct(self, context, data_mean=None, data_std=None):
Nicola Gatto's avatar
Nicola Gatto committed
45
46
47
48
        self.net = Net(data_mean=data_mean, data_std=data_std)
        self.net.collect_params().initialize(self.weight_initializer, ctx=context)
        self.net.hybridize()
        self.net(mx.nd.zeros((1,)+self._input_shapes_[0], ctx=context))
Nicola Gatto's avatar
Nicola Gatto committed
49
50
51
52
53

        if not os.path.exists(self._model_dir_):
            os.makedirs(self._model_dir_)

        self.net.export(self._model_dir_ + self._model_prefix_, epoch=0)