import mxnet as mx
import logging
import numpy as np
import time
import os
import shutil
from mxnet import gluon, autograd, nd

class CrossEntropyLoss(gluon.loss.Loss):
    def __init__(self, axis=-1, sparse_label=True, weight=None, batch_axis=0, **kwargs):
        super(CrossEntropyLoss, self).__init__(weight, batch_axis, **kwargs)
        self._axis = axis
        self._sparse_label = sparse_label

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        pred = F.log(pred)
        if self._sparse_label:
            loss = -F.pick(pred, label, axis=self._axis, keepdims=True)
        else:
            label = gluon.loss._reshape_like(F, label, pred)
            loss = -F.sum(pred * label, axis=self._axis, keepdims=True)
        loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
        return F.mean(loss, axis=self._batch_axis, exclude=True)
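
# Usage sketch (illustrative, not part of the generated trainer): this loss
# expects `pred` to already contain probabilities (e.g. softmax output),
# since hybrid_forward applies F.log directly:
#   loss_fn = CrossEntropyLoss(sparse_label=True)
#   loss_fn(nd.array([[0.7, 0.3]]), nd.array([0]))  # == -log(0.7) ~ 0.357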

class LogCoshLoss(gluon.loss.Loss):
    def __init__(self, weight=None, batch_axis=0, **kwargs):
        super(LogCoshLoss, self).__init__(weight, batch_axis, **kwargs)

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        loss = F.log(F.cosh(pred - label))
        loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
        return F.mean(loss, axis=self._batch_axis, exclude=True)
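
# Note (illustrative): log(cosh(x)) ~ x**2 / 2 for small x and ~ |x| - log(2)
# for large |x|, so LogCoshLoss behaves like L2 near zero error but grows only
# linearly for outliers:
#   loss_fn = LogCoshLoss()
#   loss_fn(nd.array([1.5]), nd.array([1.0]))  # == log(cosh(0.5)) ~ 0.12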

class CNNSupervisedTrainer_Alexnet:
    def applyBeamSearch(self, input, length, width, maxLength, currProb, netIndex, bestOutput):
        # Recursively expands the `width` most probable candidates per step and
        # keeps the most probable complete sequence of `maxLength` in bestOutput.
        bestProb = 0.0
        while length < maxLength:
            length += 1

            batchIndex = 0
            for batchEntry in input:
                top_k_indices = mx.nd.topk(batchEntry, axis=0, k=width)
                top_k_values = mx.nd.topk(batchEntry, ret_typ='value', axis=0, k=width)
                for index in range(top_k_indices.size):
                    if length == 1:
                        result = self.applyBeamSearch(self._networks[netIndex](mx.nd.array(top_k_indices[index])), length, width, maxLength,
                            currProb * top_k_values[index], netIndex, self._networks[netIndex](mx.nd.array(top_k_indices[index])))
                    else:
                        result = self.applyBeamSearch(self._networks[netIndex](mx.nd.array(top_k_indices[index])), length, width, maxLength,
                            currProb * top_k_values[index], netIndex, bestOutput)

                    if length == maxLength:
                        if currProb > bestProb:
                            bestProb = currProb
                            bestOutput[batchIndex] = result[batchIndex]

                batchIndex += 1

        return bestOutput
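
    # Illustrative call (hypothetical arguments): starting from an initial
    # output batch `out` of network 0, a beam of width 3 over sequences of
    # length up to 10 could be explored as
    #   best = trainer.applyBeamSearch(out, 0, 3, 10, 1.0, 0, out)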


    def __init__(self, data_loader, net_constructor):
        self._data_loader = data_loader
        self._net_creator = net_constructor
        self._networks = {}

    def train(self, batch_size=64,
              num_epoch=10,
              eval_metric='acc',
              loss='softmax_cross_entropy',
              loss_params={},
              optimizer='adam',
              optimizer_params=(('learning_rate', 0.001),),
              load_checkpoint=True,
              context='gpu',
              checkpoint_period=5,
              normalize=True):
        if context == 'gpu':
            mx_context = mx.gpu()
        elif context == 'cpu':
            mx_context = mx.cpu()
        else:
            logging.error("Context argument is '" + context + "'. Only 'cpu' and 'gpu' are valid arguments.")
            # fail fast: mx_context would otherwise be undefined below
            raise ValueError("Invalid context argument: " + context)

        if 'weight_decay' in optimizer_params:
            optimizer_params['wd'] = optimizer_params['weight_decay']
            del optimizer_params['weight_decay']
        if 'learning_rate_decay' in optimizer_params:
            min_learning_rate = 1e-08
            if 'learning_rate_minimum' in optimizer_params:
                min_learning_rate = optimizer_params['learning_rate_minimum']
                del optimizer_params['learning_rate_minimum']
            optimizer_params['lr_scheduler'] = mx.lr_scheduler.FactorScheduler(
                                                   optimizer_params['step_size'],
                                                   factor=optimizer_params['learning_rate_decay'],
                                                   stop_factor_lr=min_learning_rate)
            del optimizer_params['step_size']
            del optimizer_params['learning_rate_decay']
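
        # Illustrative mapping (hypothetical values): optimizer_params such as
        #   {'learning_rate': 0.001, 'weight_decay': 0.01,
        #    'learning_rate_decay': 0.9, 'step_size': 1000}
        # are rewritten above to the MXNet names, i.e. {'learning_rate': 0.001,
        # 'wd': 0.01, 'lr_scheduler': mx.lr_scheduler.FactorScheduler(
        #     1000, factor=0.9, stop_factor_lr=1e-08)}.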


        train_iter, test_iter, data_mean, data_std = self._data_loader.load_data(batch_size)

        if normalize:
            self._net_creator.construct(context=mx_context, data_mean=data_mean, data_std=data_std)
        else:
            self._net_creator.construct(context=mx_context)

        begin_epoch = 0
        if load_checkpoint:
            begin_epoch = self._net_creator.load(mx_context)
        else:
            if os.path.isdir(self._net_creator._model_dir_):
                shutil.rmtree(self._net_creator._model_dir_)

        self._networks = self._net_creator.networks

        try:
            os.makedirs(self._net_creator._model_dir_)
        except OSError:
            if not os.path.isdir(self._net_creator._model_dir_):
                raise

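        # One gluon.Trainer is created per network, so every network's
        # parameters are updated with the same optimizer and optimizer_params.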
        trainers = [mx.gluon.Trainer(network.collect_params(), optimizer, optimizer_params) for network in self._networks.values()]

        margin = loss_params['margin'] if 'margin' in loss_params else 1.0
        sparseLabel = loss_params['sparse_label'] if 'sparse_label' in loss_params else True
        if loss == 'softmax_cross_entropy':
            fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
            loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(from_logits=fromLogits, sparse_label=sparseLabel)
        elif loss == 'sigmoid_binary_cross_entropy':
            loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
        elif loss == 'cross_entropy':
            loss_function = CrossEntropyLoss(sparse_label=sparseLabel)
        elif loss == 'l2':
            loss_function = mx.gluon.loss.L2Loss()
        elif loss == 'l1':
            loss_function = mx.gluon.loss.L1Loss()
        elif loss == 'huber':
            rho = loss_params['rho'] if 'rho' in loss_params else 1
            loss_function = mx.gluon.loss.HuberLoss(rho=rho)
        elif loss == 'hinge':
            loss_function = mx.gluon.loss.HingeLoss(margin=margin)
        elif loss == 'squared_hinge':
            loss_function = mx.gluon.loss.SquaredHingeLoss(margin=margin)
        elif loss == 'logistic':
            labelFormat = loss_params['label_format'] if 'label_format' in loss_params else 'signed'
            loss_function = mx.gluon.loss.LogisticLoss(label_format=labelFormat)
        elif loss == 'kullback_leibler':
            fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else True
            loss_function = mx.gluon.loss.KLDivLoss(from_logits=fromLogits)
        elif loss == 'log_cosh':
            loss_function = LogCoshLoss()
        else:
            logging.error("Invalid loss parameter.")
            # fail fast: loss_function would otherwise be undefined below
            raise ValueError("Invalid loss parameter: " + loss)
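
        # Usage sketch (hypothetical call): train(batch_size=64, loss='huber',
        # loss_params={'rho': 0.5}) selects mx.gluon.loss.HuberLoss(rho=0.5)
        # via the chain above; keys in loss_params that a branch does not read
        # are silently ignored.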

        speed_period = 50
        tic = None

        for epoch in range(begin_epoch, begin_epoch + num_epoch):
            train_iter.reset()
            for batch_i, batch in enumerate(train_iter):
                data_ = batch.data[0].as_in_context(mx_context)
                predictions_label = batch.label[0].as_in_context(mx_context)

                with autograd.record():
                    predictions_ = mx.nd.zeros((batch_size, 10,), ctx=mx_context)

                    lossList = []
                    predictions_ = self._networks[0](data_)
                    lossList.append(loss_function(predictions_, predictions_label))

                    loss = 0
                    for element in lossList:
                        loss = loss + element

                loss.backward()

                for trainer in trainers:
                    trainer.step(batch_size)
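
                # Note: gluon.Trainer.step(batch_size) rescales the accumulated
                # gradient by 1/batch_size before applying the optimizer update.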

                if tic is None:
                    tic = time.time()
                else:
                    if batch_i % speed_period == 0:
                        try:
                            speed = speed_period * batch_size / (time.time() - tic)
                        except ZeroDivisionError:
                            speed = float("inf")

                        logging.info("Epoch[%d] Batch[%d] Speed: %.2f samples/sec" % (epoch, batch_i, speed))

                        tic = time.time()

            tic = None

            train_iter.reset()
            metric = mx.metric.create(eval_metric)
            for batch_i, batch in enumerate(train_iter):
                data_ = batch.data[0].as_in_context(mx_context)

                labels = [
                    batch.label[0].as_in_context(mx_context)
                ]

                outputs = []

                if True:
                    predictions_ = mx.nd.zeros((batch_size, 10,), ctx=mx_context)

                    predictions_ = self._networks[0](data_)
                    outputs.append(predictions_)

                predictions = []
                for output_name in outputs:
                    if mx.nd.shape_array(output_name).size > 1:
                        predictions.append(mx.nd.argmax(output_name, axis=1))
                    else:
                        # argmax has already been applied
                        predictions.append(output_name)

                metric.update(preds=predictions, labels=labels)
            train_metric_score = metric.get()[1]

            test_iter.reset()
            metric = mx.metric.create(eval_metric)
            for batch_i, batch in enumerate(test_iter):
                data_ = batch.data[0].as_in_context(mx_context)

                labels = [
                    batch.label[0].as_in_context(mx_context)
                ]

                outputs = []

                if True:
                    predictions_ = mx.nd.zeros((batch_size, 10,), ctx=mx_context)

                    predictions_ = self._networks[0](data_)
                    outputs.append(predictions_)

                predictions = []
                for output_name in outputs:
                    if mx.nd.shape_array(output_name).size > 1:
                        predictions.append(mx.nd.argmax(output_name, axis=1))
                    else:
                        # argmax has already been applied
                        predictions.append(output_name)

                metric.update(preds=predictions, labels=labels)
            test_metric_score = metric.get()[1]

            logging.info("Epoch[%d] Train: %f, Test: %f" % (epoch, train_metric_score, test_metric_score))

            if (epoch - begin_epoch) % checkpoint_period == 0:
                for i, network in self._networks.items():
                    network.save_parameters(self.parameter_path(i) + '-' + str(epoch).zfill(4) + '.params')

        for i, network in self._networks.items():
            network.save_parameters(self.parameter_path(i) + '-' + str(num_epoch + begin_epoch).zfill(4) + '.params')
            network.export(self.parameter_path(i) + '_newest', epoch=0)

    def parameter_path(self, index):
        return self._net_creator._model_dir_ + self._net_creator._model_prefix_ + '_' + str(index)
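
    # Illustrative example (hypothetical values): with _model_dir_ = 'model/'
    # and _model_prefix_ = 'model', parameter_path(0) returns 'model/model_0';
    # checkpoints are then written as e.g. 'model/model_0-0004.params', and the
    # final export produces 'model/model_0_newest-symbol.json' plus
    # 'model/model_0_newest-0000.params'.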