# CNNCreator_Alexnet.py
import mxnet as mx
import logging
import os
4

5
from CNNNet_Alexnet import Net_0
6
7
8

class CNNCreator_Alexnet:
    """Create the Alexnet Gluon network and restore its saved parameters.

    Networks are kept in ``self.networks`` keyed by an integer index; files
    belonging to network ``i`` are named ``<_model_prefix_>_<i>...`` inside
    ``_model_dir_``.
    """

    # Directory for exported models and checkpoints. Paths are built by plain
    # string concatenation, so the trailing slash is required.
    _model_dir_ = "model/Alexnet/"
    # Common file-name prefix of every file written for this model.
    _model_prefix_ = "model"

    def __init__(self):
        """Set up the weight initializer and an empty network registry."""
        self.weight_initializer = mx.init.Normal()
        self.networks = {}

    def load(self, context):
        """Load the newest parameter checkpoint of every registered network.

        For each network the stale ``*_newest-0000.params`` and
        ``*_newest-symbol.json`` files are removed (best effort), then the
        checkpoint ``<prefix>_<i>-<epoch>.params`` with the highest epoch
        greater than zero is loaded into the network.

        Args:
            context: Accepted for interface compatibility; not used here.

        Returns:
            The smallest "last epoch" over all networks, ``0`` if any network
            has no checkpoint, or ``None`` if no networks are registered.
        """
        earliest_last_epoch = None

        for i, network in self.networks.items():
            path_stem = self._model_dir_ + self._model_prefix_ + "_" + str(i)

            # Best-effort cleanup: a missing file is the normal case.
            for stale_suffix in ("_newest-0000.params", "_newest-symbol.json"):
                try:
                    os.remove(path_stem + stale_suffix)
                except OSError:
                    pass

            last_epoch, param_file = self._find_latest_checkpoint(i)

            if param_file is None:
                # One network without a checkpoint means starting at epoch 0.
                earliest_last_epoch = 0
            else:
                logging.info("Loading checkpoint: %s", param_file)
                network.load_parameters(self._model_dir_ + param_file)

                if earliest_last_epoch is None or last_epoch < earliest_last_epoch:
                    earliest_last_epoch = last_epoch

        return earliest_last_epoch

    def _find_latest_checkpoint(self, index):
        """Return ``(epoch, file_name)`` of the newest checkpoint of a network.

        Only files named exactly ``<prefix>_<index>-<epoch>.params`` are
        considered, so checkpoints of network ``10`` are never mistaken for
        those of network ``1`` and unrelated ``.params`` files are skipped
        instead of raising. Returns ``(0, None)`` when the model directory is
        missing or holds no checkpoint with an epoch greater than zero.
        """
        name_head = self._model_prefix_ + "_" + str(index) + "-"
        name_tail = ".params"
        last_epoch = 0
        param_file = None

        if not os.path.isdir(self._model_dir_):
            return last_epoch, param_file

        for file_name in os.listdir(self._model_dir_):
            if not (file_name.startswith(name_head) and file_name.endswith(name_tail)):
                continue
            try:
                epoch = int(file_name[len(name_head):-len(name_tail)])
            except ValueError:
                # Not an epoch-numbered checkpoint of this network.
                continue
            if epoch > last_epoch:
                last_epoch = epoch
                param_file = file_name

        return last_epoch, param_file

    def construct(self, context, data_mean=None, data_std=None):
        """Build network 0, initialise it on ``context`` and export it.

        Args:
            context: MXNet context used for parameter initialisation and for
                the warm-up input.
            data_mean: Forwarded unchanged to ``Net_0``.
            data_std: Forwarded unchanged to ``Net_0``.
        """
        self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std)
        self.networks[0].collect_params().initialize(self.weight_initializer, ctx=context)
        self.networks[0].hybridize()
        # One forward pass on a zero batch of shape (1, 3, 224, 224) before
        # exporting (presumably required so the hybridized graph exists).
        self.networks[0](mx.nd.zeros((1, 3, 224, 224), ctx=context))

        # Create the directory without a check-then-create race; only re-raise
        # when the directory is really unavailable afterwards.
        try:
            os.makedirs(self._model_dir_)
        except OSError:
            if not os.path.isdir(self._model_dir_):
                raise

        for i, network in self.networks.items():
            network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)

    def getInputs(self):
        """Describe the network inputs.

        Returns:
            A dict mapping ``"data_"`` to ``(int, 0.0, 255.0, (3, 224, 224))``,
            i.e. the domain triple followed by the input dimensions.
        """
        inputs = {}
        input_dimensions = (3, 224, 224)
        input_domains = (int, 0.0, 255.0)
        inputs["data_"] = input_domains + (input_dimensions,)
        return inputs

    def getOutputs(self):
        """Describe the network outputs.

        Returns:
            A dict mapping ``"predictions_"`` to
            ``(float, 0.0, 1.0, (10, 1, 1))``, i.e. the domain triple followed
            by the output dimensions.
        """
        outputs = {}
        output_dimensions = (10, 1, 1)
        output_domains = (float, 0.0, 1.0)
        outputs["predictions_"] = output_domains + (output_dimensions,)
        return outputs