import mxnet as mx
import logging
import numpy as np
import time
import os
import shutil
from mxnet import gluon, autograd, nd

class CrossEntropyLoss(gluon.loss.Loss):
    def __init__(self, axis=-1, sparse_label=True, weight=None, batch_axis=0, **kwargs):
        super(CrossEntropyLoss, self).__init__(weight, batch_axis, **kwargs)
        self._axis = axis
        self._sparse_label = sparse_label

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        pred = F.log(pred)
        if self._sparse_label:
            loss = -F.pick(pred, label, axis=self._axis, keepdims=True)
        else:
            label = gluon.loss._reshape_like(F, label, pred)
            loss = -F.sum(pred * label, axis=self._axis, keepdims=True)
        loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
        return F.mean(loss, axis=self._batch_axis, exclude=True)
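
# A minimal usage sketch for CrossEntropyLoss (hypothetical values; note that
# hybrid_forward above expects `pred` to already contain probabilities, since
# it applies F.log itself):
#
#   pred = mx.nd.array([[0.7, 0.2, 0.1]])
#   label = mx.nd.array([0])
#   print(CrossEntropyLoss()(pred, label))  # -log(0.7) ~= 0.357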

class LogCoshLoss(gluon.loss.Loss):
    def __init__(self, weight=None, batch_axis=0, **kwargs):
        super(LogCoshLoss, self).__init__(weight, batch_axis, **kwargs)

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        loss = F.log(F.cosh(pred - label))
        loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
        return F.mean(loss, axis=self._batch_axis, exclude=True)
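
# A minimal usage sketch for LogCoshLoss (hypothetical values): log(cosh(x))
# behaves like x^2/2 for small residuals and like |x| - log(2) for large ones,
# making it a smooth, outlier-robust regression loss:
#
#   pred = mx.nd.array([0.0, 3.0])
#   label = mx.nd.array([0.5, 0.0])
#   print(LogCoshLoss()(pred, label))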

class CNNSupervisedTrainer_Alexnet:
    def applyBeamSearch(self, input, depth, width, maxDepth, currProb, netIndex, bestOutput):
        # Recursively expand the `width` most probable candidates per step
        # until maxDepth is reached, keeping the most probable output.
        bestProb = 0.0
        while depth < maxDepth:
            depth += 1
            batchIndex = 0
            for batchEntry in input:
                top_k_indices = mx.nd.topk(batchEntry, axis=0, k=width)
                top_k_values = mx.nd.topk(batchEntry, ret_typ='value', axis=0, k=width)
                for index in range(top_k_indices.size):
                    if depth == 1:
                        result = self.applyBeamSearch(self._networks[netIndex](mx.nd.array(top_k_indices[index])), depth, width, maxDepth,
                            currProb * top_k_values[index], netIndex, self._networks[netIndex](mx.nd.array(top_k_indices[index])))
                    else:
                        result = self.applyBeamSearch(self._networks[netIndex](mx.nd.array(top_k_indices[index])), depth, width, maxDepth,
                            currProb * top_k_values[index], netIndex, bestOutput)

                    if depth == maxDepth:
                        if currProb > bestProb:
                            bestProb = currProb
                            bestOutput[batchIndex] = result[batchIndex]

                batchIndex += 1
        return bestOutput
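
    # A minimal sketch of how applyBeamSearch might be invoked from inside this
    # class (hypothetical shapes; assumes network `netIndex` maps an index
    # vector to a probability distribution over classes):
    #
    #   initial = self._networks[0](first_input)   # shape: (batch, classes)
    #   best = self.applyBeamSearch(initial, 0, 3, 5, 1.0, 0, initial)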


    def __init__(self, data_loader, net_constructor):
        self._data_loader = data_loader
        self._net_creator = net_constructor
        self._networks = {}

    def train(self, batch_size=64,
              num_epoch=10,
              eval_metric='acc',
              loss='softmax_cross_entropy',
              loss_params={},
              optimizer='adam',
              optimizer_params=(('learning_rate', 0.001),),
              load_checkpoint=True,
              context='gpu',
              checkpoint_period=5,
              normalize=True):
        if context == 'gpu':
            mx_context = mx.gpu()
        elif context == 'cpu':
            mx_context = mx.cpu()
        else:
            logging.error("Context argument is '" + context + "'. Only 'cpu' and 'gpu are valid arguments'.")

        if 'weight_decay' in optimizer_params:
            optimizer_params['wd'] = optimizer_params['weight_decay']
            del optimizer_params['weight_decay']
        if 'learning_rate_decay' in optimizer_params:
            min_learning_rate = 1e-08
            if 'learning_rate_minimum' in optimizer_params:
                min_learning_rate = optimizer_params['learning_rate_minimum']
                del optimizer_params['learning_rate_minimum']
            optimizer_params['lr_scheduler'] = mx.lr_scheduler.FactorScheduler(
                                                   optimizer_params['step_size'],
                                                   factor=optimizer_params['learning_rate_decay'],
                                                   stop_factor_lr=min_learning_rate)
            del optimizer_params['step_size']
            del optimizer_params['learning_rate_decay']
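        # Example (hypothetical values): optimizer_params = {'learning_rate': 0.001,
        # 'learning_rate_decay': 0.9, 'step_size': 1000} yields a FactorScheduler
        # that multiplies the learning rate by 0.9 every 1000 updates, never
        # dropping below min_learning_rate.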


        train_iter, test_iter, data_mean, data_std = self._data_loader.load_data(batch_size)

        if normalize:
            self._net_creator.construct(context=mx_context, data_mean=data_mean, data_std=data_std)
        else:
            self._net_creator.construct(context=mx_context)

        begin_epoch = 0
        if load_checkpoint:
            begin_epoch = self._net_creator.load(mx_context)
        else:
            if os.path.isdir(self._net_creator._model_dir_):
                shutil.rmtree(self._net_creator._model_dir_)

        self._networks = self._net_creator.networks

        try:
            os.makedirs(self._net_creator._model_dir_)
        except OSError:
            if not os.path.isdir(self._net_creator._model_dir_):
                raise

        trainers = [mx.gluon.Trainer(network.collect_params(), optimizer, optimizer_params) for network in self._networks.values()]

        margin = loss_params['margin'] if 'margin' in loss_params else 1.0
        sparseLabel = loss_params['sparse_label'] if 'sparse_label' in loss_params else True
        if loss == 'softmax_cross_entropy':
            fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
            loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(from_logits=fromLogits, sparse_label=sparseLabel)
        elif loss == 'sigmoid_binary_cross_entropy':
            loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
        elif loss == 'cross_entropy':
            loss_function = CrossEntropyLoss(sparse_label=sparseLabel)
        elif loss == 'l2':
            loss_function = mx.gluon.loss.L2Loss()
        elif loss == 'l1':
            loss_function = mx.gluon.loss.L1Loss()
        elif loss == 'huber':
            rho = loss_params['rho'] if 'rho' in loss_params else 1
            loss_function = mx.gluon.loss.HuberLoss(rho=rho)
        elif loss == 'hinge':
            loss_function = mx.gluon.loss.HingeLoss(margin=margin)
        elif loss == 'squared_hinge':
            loss_function = mx.gluon.loss.SquaredHingeLoss(margin=margin)
        elif loss == 'logistic':
            labelFormat = loss_params['label_format'] if 'label_format' in loss_params else 'signed'
            loss_function = mx.gluon.loss.LogisticLoss(label_format=labelFormat)
        elif loss == 'kullback_leibler':
            fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else True
            loss_function = mx.gluon.loss.KLDivLoss(from_logits=fromLogits)
        elif loss == 'log_cosh':
            loss_function = LogCoshLoss()
        else:
            logging.error("Invalid loss parameter.")

        speed_period = 50
        tic = None

        for epoch in range(begin_epoch, begin_epoch + num_epoch):
            train_iter.reset()
            for batch_i, batch in enumerate(train_iter):
                data_ = batch.data[0].as_in_context(mx_context)
                predictions_label = batch.label[0].as_in_context(mx_context)

                with autograd.record():
                    lossList = []
                    predictions_ = self._networks[0](data_)
                    lossList.append(loss_function(predictions_, predictions_label))

                    loss = 0
                    for element in lossList:
                        loss = loss + element

                loss.backward()

                for trainer in trainers:
                    trainer.step(batch_size)
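                # Note: Trainer.step(batch_size) rescales the summed gradients
                # by 1/batch_size, so the update uses the batch-averaged gradient.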

                if tic is None:
                    tic = time.time()
                else:
                    if batch_i % speed_period == 0:
                        try:
                            speed = speed_period * batch_size / (time.time() - tic)
                        except ZeroDivisionError:
                            speed = float("inf")

                        logging.info("Epoch[%d] Batch[%d] Speed: %.2f samples/sec" % (epoch, batch_i, speed))

                        tic = time.time()

            tic = None

            train_iter.reset()
            metric = mx.metric.create(eval_metric)
            for batch_i, batch in enumerate(train_iter):
                data_ = batch.data[0].as_in_context(mx_context)

                labels = [
                    batch.label[0].as_in_context(mx_context)
                ]

                outputs = []

                predictions_ = self._networks[0](data_)
                outputs.append(predictions_)

                predictions = []
                for output_name in outputs:
                    if mx.nd.shape_array(output_name).size > 1:
                        predictions.append(mx.nd.argmax(output_name, axis=1))
                    else:
                        # ArgMax already applied
                        predictions.append(output_name)

                metric.update(preds=predictions, labels=labels)
            train_metric_score = metric.get()[1]

            test_iter.reset()
            metric = mx.metric.create(eval_metric)
            for batch_i, batch in enumerate(test_iter):
                data_ = batch.data[0].as_in_context(mx_context)

                labels = [
                    batch.label[0].as_in_context(mx_context)
                ]

                outputs = []

                predictions_ = self._networks[0](data_)
                outputs.append(predictions_)

                predictions = []
                for output_name in outputs:
                    if mx.nd.shape_array(output_name).size > 1:
                        predictions.append(mx.nd.argmax(output_name, axis=1))
                    else:
                        # ArgMax already applied
                        predictions.append(output_name)

                metric.update(preds=predictions, labels=labels)
            test_metric_score = metric.get()[1]

            logging.info("Epoch[%d] Train: %f, Test: %f" % (epoch, train_metric_score, test_metric_score))

            if (epoch - begin_epoch) % checkpoint_period == 0:
                for i, network in self._networks.items():
                    network.save_parameters(self.parameter_path(i) + '-' + str(epoch).zfill(4) + '.params')

        for i, network in self._networks.items():
            network.save_parameters(self.parameter_path(i) + '-' + str(num_epoch + begin_epoch).zfill(4) + '.params')
            network.export(self.parameter_path(i) + '_newest', epoch=0)

    def parameter_path(self, index):
        return self._net_creator._model_dir_ + self._net_creator._model_prefix_ + '_' + str(index)
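
# A hedged usage sketch (module and class names for the data loader and network
# creator are assumptions, mirroring the naming scheme of this generated file):
#
#   from CNNDataLoader_Alexnet import CNNDataLoader_Alexnet
#   from CNNCreator_Alexnet import CNNCreator_Alexnet
#
#   trainer = CNNSupervisedTrainer_Alexnet(CNNDataLoader_Alexnet(),
#                                          CNNCreator_Alexnet())
#   trainer.train(batch_size=64, num_epoch=10, context='cpu')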