CNNCreator_Alexnet.py 3.89 KB
Newer Older
1
2
3
import mxnet as mx
import logging
import os
Julian Treiber's avatar
Julian Treiber committed
4
import shutil
5

6
from CNNNet_Alexnet import Net_0
7
8
9

class CNNCreator_Alexnet:
    """Create, checkpoint and restore the network(s) of the Alexnet model.

    Checkpoint parameter files live in ``_model_dir_`` and are named
    ``<_model_prefix_>_<network index>-<epoch>.params``.
    """

    _model_dir_ = "model/Alexnet/"
    _model_prefix_ = "model"

    def __init__(self):
        self.weight_initializer = mx.init.Normal()
        # Network index -> network object; populated by construct().
        self.networks = {}
        # Optional directory with pretrained weights; None disables
        # load_pretrained_weights().
        self._weights_dir_ = None

    def _latest_checkpoint(self, directory, index):
        """Return ``(epoch, filename)`` of the newest checkpoint of network ``index``.

        Only files named exactly ``<_model_prefix_>_<index>-<digits>.params``
        are considered. This is stricter than a substring test, so network 1
        does not pick up ``model_10-0009.params``. Files such as
        ``model_0_newest-0000.params`` or ``*.params.bak`` are skipped instead
        of crashing the epoch parse. Epoch 0 is included.

        Returns None when ``directory`` does not exist or holds no matching file.
        """
        if not os.path.isdir(directory):
            return None
        stem = self._model_prefix_ + "_" + str(index) + "-"
        suffix = ".params"
        best = None
        # Sorted so ties between equal epochs resolve deterministically.
        for filename in sorted(os.listdir(directory)):
            if not (filename.startswith(stem) and filename.endswith(suffix)):
                continue
            epoch_str = filename[len(stem):-len(suffix)]
            # isdigit() rejects signs, spaces and underscores that int() would
            # accept; the try guards the few non-ASCII digits int() refuses.
            if not epoch_str.isdigit():
                continue
            try:
                epoch = int(epoch_str)
            except ValueError:
                continue
            if best is None or epoch > best[0]:
                best = (epoch, filename)
        return best

    def load(self, context):
        """Restore every network from its newest checkpoint in ``_model_dir_``.

        For each network, any stale ``<prefix>_<i>_newest-0000.params`` /
        ``<prefix>_<i>_newest-symbol.json`` files are deleted first. Then the
        highest-epoch checkpoint is loaded. A checkpoint at epoch 0 counts as
        "no checkpoint" and is not loaded.

        :param context: unused; kept for interface compatibility.
        :returns: the smallest last epoch over all networks. This is 0 if any
            network has no checkpoint newer than epoch 0, and None if
            ``self.networks`` is empty.
        """
        earliest_last_epoch = None

        for i, network in self.networks.items():
            newest_prefix = os.path.join(
                self._model_dir_, self._model_prefix_ + "_" + str(i) + "_newest")
            for stale in (newest_prefix + "-0000.params",
                          newest_prefix + "-symbol.json"):
                try:
                    os.remove(stale)
                except OSError:
                    pass  # best effort: the file usually does not exist

            checkpoint = self._latest_checkpoint(self._model_dir_, i)
            if checkpoint is None or checkpoint[0] == 0:
                last_epoch = 0
            else:
                last_epoch, param_file = checkpoint
                logging.info("Loading checkpoint: " + param_file)
                network.load_parameters(os.path.join(self._model_dir_, param_file))

            if earliest_last_epoch is None or last_epoch < earliest_last_epoch:
                earliest_last_epoch = last_epoch

        return earliest_last_epoch

    def load_pretrained_weights(self, context):
        """Load the newest pretrained weights from ``_weights_dir_`` into each network.

        ``_model_dir_`` is deleted first if it exists (presumably so training
        restarts from the pretrained weights rather than old checkpoints).
        Nothing else happens when ``_weights_dir_`` is None. Networks without
        a matching file in ``_weights_dir_`` are skipped with a log message.
        Loading uses ``allow_missing=True, ignore_extra=True``, so the file
        need not match the network's parameters exactly.

        :param context: unused; kept for interface compatibility.
        """
        if os.path.isdir(self._model_dir_):
            shutil.rmtree(self._model_dir_)
        if self._weights_dir_ is None:
            return

        for i, network in self.networks.items():
            if not os.path.isdir(self._weights_dir_):
                logging.info("No pretrained weights available at: " + self._weights_dir_)
                continue
            checkpoint = self._latest_checkpoint(self._weights_dir_, i)
            if checkpoint is None:
                logging.info("No pretrained weights for network " + str(i)
                             + " found in: " + self._weights_dir_)
                continue
            # os.path.join so a _weights_dir_ without trailing slash still works.
            param_path = os.path.join(self._weights_dir_, checkpoint[1])
            logging.info("Loading pretrained weights: " + param_path)
            network.load_parameters(param_path, allow_missing=True, ignore_extra=True)

    def construct(self, context, data_mean=None, data_std=None):
        """Build, initialize and export the network(s).

        Creates ``Net_0``, initializes its parameters with
        ``self.weight_initializer`` on ``context``, hybridizes it and runs one
        forward pass. It then exports every network to
        ``_model_dir_/<prefix>_<i>`` with epoch 0, creating the directory if
        needed.

        :param context: device context used for initialization and the dummy pass.
        :param data_mean: forwarded to ``Net_0``.
        :param data_std: forwarded to ``Net_0``.
        """
        self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std)
        self.networks[0].collect_params().initialize(self.weight_initializer, ctx=context)
        self.networks[0].hybridize()
        # A hybridized block must run forward at least once before export();
        # the dummy batch of 1 matches the (3, 224, 224) input in getInputs().
        self.networks[0](mx.nd.zeros((1, 3, 224, 224,), ctx=context))

        if not os.path.exists(self._model_dir_):
            os.makedirs(self._model_dir_)

        for i, network in self.networks.items():
            network.export(
                os.path.join(self._model_dir_, self._model_prefix_ + "_" + str(i)),
                epoch=0)

    def getInputs(self):
        """Return ``{input name: domain tuple + (dimension tuple,)}`` for the model inputs."""
        inputs = {}
        input_dimensions = (3, 224, 224,)
        input_domains = (int, 0.0, 255.0,)
        inputs["data_"] = input_domains + (input_dimensions,)
        return inputs

    def getOutputs(self):
        """Return ``{output name: domain tuple + (dimension tuple,)}`` for the model outputs."""
        outputs = {}
        output_dimensions = (10, 1, 1,)
        output_domains = (float, 0.0, 1.0,)
        outputs["predictions_"] = output_domains + (output_dimensions,)
        return outputs