import mxnet as mx
import logging
import os
import errno
import shutil
import h5py
import sys
import numpy as np
import time
from mxnet import gluon, autograd, nd
from CNNNet_CifarClassifierNetwork import Net

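# MyConstant is registered with MXNet's initializer registry so that constant
# tensors (e.g. the normalization mean and std below) can be restored by name
# when a serialized network is loaded.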
@mx.init.register
class MyConstant(mx.init.Initializer):
    def __init__(self, value):
        super(MyConstant, self).__init__(value=value)
        self.value = value

    def _init_weight(self, _, arr):
        arr[:] = mx.nd.array(self.value)

class CNNCreator_CifarClassifierNetwork:

    _data_dir_ = "data/CifarClassifierNetwork/"
    _model_dir_ = "model/CifarClassifierNetwork/"
    _model_prefix_ = "model"
    _input_names_ = ['data']
    _input_shapes_ = [(3,32,32)]
    _output_names_ = ['softmax_label']

    def __init__(self):
        self.weight_initializer = mx.init.Normal()
        self.net = None

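    # Scans the model directory for the checkpoint with the highest epoch
    # number, loads its parameters into self.net, and returns that epoch so
    # training can resume from it (0 if no checkpoint exists).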
    def load(self, context):
        lastEpoch = 0
        param_file = None

        try:
            os.remove(self._model_dir_ + self._model_prefix_ + "_newest-0000.params")
        except OSError:
            pass
        try:
            os.remove(self._model_dir_ + self._model_prefix_ + "_newest-symbol.json")
        except OSError:
            pass

        if os.path.isdir(self._model_dir_):
            for file in os.listdir(self._model_dir_):
                if ".params" in file and self._model_prefix_ in file:
                    epochStr = file.replace(".params","").replace(self._model_prefix_ + "-","")
                    try:
                        epoch = int(epochStr)
                    except ValueError:
                        # skip files such as "<prefix>_newest-0000.params" that carry no epoch number
                        continue
                    if epoch > lastEpoch:
                        lastEpoch = epoch
                        param_file = file
        if param_file is None:
            return 0
        else:
            logging.info("Loading checkpoint: " + param_file)
            self.net.load_parameters(self._model_dir_ + param_file)
            return lastEpoch


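    # Builds NDArrayIter instances for the train set and the optional test set
    # and returns the per-element mean and standard deviation of the training
    # inputs, which construct() uses for input normalization.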
    def load_data(self, batch_size):
        train_h5, test_h5 = self.load_h5_files()

        data_mean = train_h5[self._input_names_[0]][:].mean(axis=0)
        data_std = train_h5[self._input_names_[0]][:].std(axis=0) + 1e-5

        train_iter = mx.io.NDArrayIter(train_h5[self._input_names_[0]],
                                       train_h5[self._output_names_[0]],
                                       batch_size=batch_size,
                                       data_name=self._input_names_[0],
                                       label_name=self._output_names_[0])
        test_iter = None
        if test_h5 is not None:
            test_iter = mx.io.NDArrayIter(test_h5[self._input_names_[0]],
                                          test_h5[self._output_names_[0]],
                                          batch_size=batch_size,
                                          data_name=self._input_names_[0],
                                          label_name=self._output_names_[0])
        return train_iter, test_iter, data_mean, data_std

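    # Opens train.h5 (required) and test.h5 (optional) from the data directory
    # and verifies that both expected datasets are present in each file.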
    def load_h5_files(self):
        train_h5 = None
        test_h5 = None
        train_path = self._data_dir_ + "train.h5"
        test_path = self._data_dir_ + "test.h5"
        if os.path.isfile(train_path):
            train_h5 = h5py.File(train_path, 'r')
            if not (self._input_names_[0] in train_h5 and self._output_names_[0] in train_h5):
                logging.error("The HDF5 file '" + os.path.abspath(train_path) + "' has to contain the datasets: "
                              + "'" + self._input_names_[0] + "', '" + self._output_names_[0] + "'")
                sys.exit(1)
            if os.path.isfile(test_path):
                test_h5 = h5py.File(test_path, 'r')
                if not (self._input_names_[0] in test_h5 and self._output_names_[0] in test_h5):
                    logging.error("The HDF5 file '" + os.path.abspath(test_path) + "' has to contain the datasets: "
                                  + "'" + self._input_names_[0] + "', '" + self._output_names_[0] + "'")
                    sys.exit(1)
            else:
                logging.warning("Couldn't load test set. File '" + os.path.abspath(test_path) + "' does not exist.")
            return train_h5, test_h5
        else:
            logging.error("Data loading failure. File '" + os.path.abspath(train_path) + "' does not exist.")
            sys.exit(1)


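    # Runs the training loop: resolves the compute context, translates the
    # optimizer parameter names to their MXNet equivalents, optionally resumes
    # from the newest checkpoint, and evaluates eval_metric on the train and
    # test sets after every epoch.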
    def train(self, batch_size=64,
              num_epoch=10,
              eval_metric='acc',
              optimizer='adam',
              optimizer_params=(('learning_rate', 0.001),),
              load_checkpoint=True,
              context='gpu',
              checkpoint_period=5,
              normalize=True):
        if context == 'gpu':
            mx_context = mx.gpu()
        elif context == 'cpu':
            mx_context = mx.cpu()
        else:
            logging.error("Context argument is '" + context + "'. Only 'cpu' and 'gpu' are valid arguments.")
            sys.exit(1)

        # optimizer_params may be passed as a tuple of pairs (see the default value);
        # normalize it to a dict before rewriting individual entries
        optimizer_params = dict(optimizer_params)

        if 'weight_decay' in optimizer_params:
            optimizer_params['wd'] = optimizer_params['weight_decay']
            del optimizer_params['weight_decay']
        if 'learning_rate_decay' in optimizer_params:
            min_learning_rate = 1e-08
            if 'learning_rate_minimum' in optimizer_params:
                min_learning_rate = optimizer_params['learning_rate_minimum']
                del optimizer_params['learning_rate_minimum']
            optimizer_params['lr_scheduler'] = mx.lr_scheduler.FactorScheduler(
                optimizer_params['step_size'],
                factor=optimizer_params['learning_rate_decay'],
                stop_factor_lr=min_learning_rate)
            del optimizer_params['step_size']
            del optimizer_params['learning_rate_decay']


        train_iter, test_iter, data_mean, data_std = self.load_data(batch_size)
        if self.net is None:
            if normalize:
                self.construct(context=mx_context, data_mean=nd.array(data_mean), data_std=nd.array(data_std))
            else:
                self.construct(context=mx_context)

        begin_epoch = 0
        if load_checkpoint:
            begin_epoch = self.load(mx_context)
        else:
            if os.path.isdir(self._model_dir_):
                shutil.rmtree(self._model_dir_)

        try:
            os.makedirs(self._model_dir_)
        except OSError:
            if not os.path.isdir(self._model_dir_):
                raise

        trainer = mx.gluon.Trainer(self.net.collect_params(), optimizer, optimizer_params)

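        # The loss is chosen from the activation of the network's final layer;
        # last_layer is expected to be provided by the generated Net class
        # imported from CNNNet_CifarClassifierNetwork.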
        if self.net.last_layer == 'softmax':
            loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss()
        elif self.net.last_layer == 'sigmoid':
            loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
        elif self.net.last_layer == 'linear':
            loss_function = mx.gluon.loss.L2Loss()
        else: # TODO: Change default?
            loss_function = mx.gluon.loss.L2Loss()
            logging.warning("Invalid last_layer, defaulting to L2 loss")

        speed_period = 50
        tic = None

        for epoch in range(begin_epoch, begin_epoch + num_epoch):
            train_iter.reset()
            for batch_i, batch in enumerate(train_iter):
                data = batch.data[0].as_in_context(mx_context)
                label = batch.label[0].as_in_context(mx_context)
                with autograd.record():
                    output = self.net(data)
                    loss = loss_function(output, label)

                loss.backward()
                trainer.step(batch_size)

                if tic is None:
                    tic = time.time()
                else:
                    if batch_i % speed_period == 0:
                        try:
                            speed = speed_period * batch_size / (time.time() - tic)
                        except ZeroDivisionError:
                            speed = float("inf")

                        logging.info("Epoch[%d] Batch[%d] Speed: %.2f samples/sec" % (epoch, batch_i, speed))

                        tic = time.time()

            tic = None

            train_iter.reset()
            metric = mx.metric.create(eval_metric)
            for batch_i, batch in enumerate(train_iter):
                data = batch.data[0].as_in_context(mx_context)
                label = batch.label[0].as_in_context(mx_context)
                output = self.net(data)
                predictions = mx.nd.argmax(output, axis=1)
                metric.update(preds=predictions, labels=label)
            train_metric_score = metric.get()[1]

            test_metric_score = float('nan')
            if test_iter is not None:
                test_iter.reset()
                metric = mx.metric.create(eval_metric)
                for batch_i, batch in enumerate(test_iter):
                    data = batch.data[0].as_in_context(mx_context)
                    label = batch.label[0].as_in_context(mx_context)
                    output = self.net(data)
                    predictions = mx.nd.argmax(output, axis=1)
                    metric.update(preds=predictions, labels=label)
                test_metric_score = metric.get()[1]

            logging.info("Epoch[%d] Train: %f, Test: %f" % (epoch, train_metric_score, test_metric_score))

            if (epoch - begin_epoch) % checkpoint_period == 0:
                self.net.save_parameters(self._model_dir_ + self._model_prefix_ + '-' + str(epoch).zfill(4) + '.params')

        self.net.save_parameters(self._model_dir_ + self._model_prefix_ + '-'
                                 + str(num_epoch + begin_epoch).zfill(4) + '.params')
        self.net.export(self._model_dir_ + self._model_prefix_ + '_newest', epoch=0)


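    # Instantiates the Gluon network, initializes its parameters, hybridizes
    # it, and runs one dummy forward pass so that shapes are inferred before
    # the symbol is exported.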
    def construct(self, context, data_mean=None, data_std=None):
        self.net = Net(data_mean=data_mean, data_std=data_std)
        self.net.collect_params().initialize(self.weight_initializer, ctx=context)
        self.net.hybridize()
        self.net(mx.nd.zeros((1,)+self._input_shapes_[0], ctx=context))

        if not os.path.exists(self._model_dir_):
            os.makedirs(self._model_dir_)

        self.net.export(self._model_dir_ + self._model_prefix_, epoch=0)
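

# Minimal usage sketch (not part of the generated class): assumes train.h5 and
# test.h5 with 'data'/'softmax_label' datasets exist under
# data/CifarClassifierNetwork/. Set context='gpu' if a GPU is available.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    creator = CNNCreator_CifarClassifierNetwork()
    creator.train(batch_size=64,
                  num_epoch=10,
                  eval_metric='acc',
                  optimizer='adam',
                  optimizer_params={'learning_rate': 0.001},
                  load_checkpoint=True,
                  context='cpu',
                  checkpoint_period=5,
                  normalize=True)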