CNNCreator_Alexnet.py 4.04 KB
Newer Older
1
2
3
import mxnet as mx
import logging
import os
Julian Treiber's avatar
Julian Treiber committed
4
import shutil
5
import warnings
6

7
from CNNNet_Alexnet import Net_0
8
9
10

class CNNCreator_Alexnet:
    """Build, checkpoint and restore the Gluon network(s) of the Alexnet model.

    Networks live in ``self.networks`` keyed by an integer index ``i``.
    Checkpoints for network ``i`` are looked up as files named
    ``<_model_prefix_>_<i>-<epoch>.params`` in ``_model_dir_`` (training
    checkpoints) or ``_weights_dir_`` (pretrained weights).
    """

    _model_dir_ = "model/Alexnet/"
    _model_prefix_ = "model"

    def __init__(self):
        self.weight_initializer = mx.init.Normal()
        # Network index -> Gluon network; populated by construct().
        self.networks = {}
        # Optional directory with pretrained checkpoints (see load_pretrained_weights).
        self._weights_dir_ = None

    def _find_latest_params(self, directory, i):
        """Return ``(epoch, file_name)`` of the newest checkpoint of network ``i``.

        Only file names of the exact form ``<prefix>_<i>-<epoch>.params`` with an
        integer epoch greater than 0 are considered. Matching on the exact prefix
        keeps network 1 from picking up ``model_10-...`` files, and skips names
        such as ``model_<i>_newest-0000.params``.

        :param directory: directory to scan.
        :param i: network index.
        :return: ``(0, None)`` if ``directory`` does not exist or holds no match.
        """
        latest_epoch = 0
        latest_file = None
        if not os.path.isdir(directory):
            return latest_epoch, latest_file

        head = self._model_prefix_ + "_" + str(i) + "-"
        tail = ".params"
        for file_name in os.listdir(directory):
            if not (file_name.startswith(head) and file_name.endswith(tail)):
                continue
            try:
                epoch = int(file_name[len(head):-len(tail)])
            except ValueError:
                # Matches the prefix/suffix but carries no numeric epoch; ignore it.
                continue
            # Epoch 0 never wins (construct() exports with epoch=0); only
            # strictly later checkpoints are selected.
            if epoch > latest_epoch:
                latest_epoch = epoch
                latest_file = file_name
        return latest_epoch, latest_file

    def load(self, context):
        """Load each network's newest checkpoint from ``_model_dir_``.

        Leftover ``<prefix>_<i>_newest-0000.params`` / ``-symbol.json`` files are
        deleted first; a missing file is ignored.

        :param context: unused; kept for interface compatibility.
        :return: the smallest last epoch over all networks (0 if any network has
            no checkpoint), or ``None`` if ``self.networks`` is empty.
        """
        earliest_last_epoch = None

        for i, network in self.networks.items():
            newest_prefix = self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest"
            for suffix in ("-0000.params", "-symbol.json"):
                try:
                    os.remove(newest_prefix + suffix)
                except OSError:
                    # Usually the file simply does not exist; removal is best effort.
                    pass

            last_epoch, param_file = self._find_latest_params(self._model_dir_, i)
            if param_file is not None:
                logging.info("Loading checkpoint: %s", param_file)
                network.load_parameters(self._model_dir_ + param_file)

            # A network without a checkpoint contributes epoch 0 to the minimum.
            if earliest_last_epoch is None or last_epoch < earliest_last_epoch:
                earliest_last_epoch = last_epoch

        return earliest_last_epoch

    def load_pretrained_weights(self, context):
        """Load the newest pretrained checkpoint from ``_weights_dir_`` into each network.

        ``_model_dir_`` is removed first (presumably so that stale training
        checkpoints are not resumed later — confirm). Parameters missing from, or
        extra in, the pretrained file are tolerated via ``allow_missing`` and
        ``ignore_extra``. Networks without a matching checkpoint are skipped with
        an info log instead of raising.

        :param context: unused; kept for interface compatibility.
        """
        if os.path.isdir(self._model_dir_):
            shutil.rmtree(self._model_dir_)
        if self._weights_dir_ is None:
            return

        if not os.path.isdir(self._weights_dir_):
            # There is no file name to report here, only the directory.
            logging.info("No pretrained weights available at: %s", self._weights_dir_)
            return

        for i, network in self.networks.items():
            _, param_file = self._find_latest_params(self._weights_dir_, i)
            if param_file is None:
                logging.info("No pretrained weights for network %s available at: %s",
                             i, self._weights_dir_)
                continue
            # os.path.join so _weights_dir_ works with or without a trailing slash.
            param_path = os.path.join(self._weights_dir_, param_file)
            logging.info("Loading pretrained weights: %s", param_path)
            network.load_parameters(param_path, allow_missing=True, ignore_extra=True)

    def construct(self, context, data_mean=None, data_std=None):
        """Create, initialize, hybridize and export network 0.

        :param context: MXNet context(s) used for initialization; ``context[0]``
            is used for the dummy forward pass, so a sequence is expected.
        :param data_mean: forwarded to ``Net_0``.
        :param data_std: forwarded to ``Net_0``.
        """
        net = Net_0(data_mean=data_mean, data_std=data_std, mx_context=context)
        self.networks[0] = net
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            net.collect_params().initialize(self.weight_initializer, force_reinit=False, ctx=context)
        net.hybridize()
        # A hybridized Gluon block must run forward once before export().
        # The dummy batch shape matches getInputs() for "data_".
        net(mx.nd.zeros((1, 3, 224, 224,), ctx=context[0]))

        if not os.path.exists(self._model_dir_):
            os.makedirs(self._model_dir_)

        for i, network in self.networks.items():
            network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)

    def getInputs(self):
        """Return ``{input_name: (type, min, max, dimensions)}`` for the model inputs."""
        inputs = {}
        input_dimensions = (3,224,224,)
        input_domains = (int,0.0,255.0,)
        inputs["data_"] = input_domains + (input_dimensions,)
        return inputs

    def getOutputs(self):
        """Return ``{output_name: (type, min, max, dimensions)}`` for the model outputs."""
        outputs = {}
        output_dimensions = (10,1,1,)
        output_domains = (float,0.0,1.0,)
        outputs["predictions_"] = output_domains + (output_dimensions,)
        return outputs