# CNNSupervisedTrainer_mnist_mnistClassifier_net.py
import mxnet as mx
import logging
import numpy as np
import time
import os
import shutil
from mxnet import gluon, autograd, nd

9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33
class CrossEntropyLoss(gluon.loss.Loss):
    """Cross-entropy loss computed from predictions via ``-log(pred)``.

    ``hybrid_forward`` applies ``F.log`` directly to ``pred``, so ``pred`` is
    presumably expected to already hold probabilities (e.g. a softmax output)
    rather than raw scores -- confirm against the network definition.

    Parameters
    ----------
    axis : int
        Axis along which the class entries of ``pred`` are picked / summed.
    sparse_label : bool
        If true, ``label`` is used as an index into ``pred`` along ``axis``
        (``F.pick``); otherwise ``label`` is reshaped like ``pred`` and
        multiplied element-wise with the log-predictions.
    weight, batch_axis, **kwargs
        Forwarded unchanged to ``gluon.loss.Loss``.
    """

    def __init__(self, axis=-1, sparse_label=True, weight=None, batch_axis=0, **kwargs):
        super(CrossEntropyLoss, self).__init__(weight, batch_axis, **kwargs)
        self._sparse_label = sparse_label
        self._axis = axis

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        """Return the loss averaged over every axis except the batch axis."""
        log_pred = F.log(pred)
        if not self._sparse_label:
            # Dense labels: weight each log-prediction by its label entry.
            dense_label = gluon.loss._reshape_like(F, label, log_pred)
            per_sample = -F.sum(log_pred * dense_label, axis=self._axis, keepdims=True)
        else:
            # Sparse labels: select the log-prediction at the labelled index.
            per_sample = -F.pick(log_pred, label, axis=self._axis, keepdims=True)
        weighted = gluon.loss._apply_weighting(F, per_sample, self._weight, sample_weight)
        return F.mean(weighted, axis=self._batch_axis, exclude=True)

class LogCoshLoss(gluon.loss.Loss):
    """Log-cosh loss: ``log(cosh(pred - label))``.

    Parameters
    ----------
    weight, batch_axis, **kwargs
        Forwarded unchanged to ``gluon.loss.Loss``.
    """

    def __init__(self, weight=None, batch_axis=0, **kwargs):
        super(LogCoshLoss, self).__init__(weight, batch_axis, **kwargs)

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        """Return the loss averaged over every axis except the batch axis."""
        abs_error = F.abs(pred - label)
        # Evaluating cosh() directly overflows to inf for large errors.  Use
        # the algebraically identical form
        #     log(cosh(x)) = |x| + log(1 + exp(-2|x|)) - log(2)
        # whose exponential argument is never positive and so cannot overflow.
        loss = abs_error + F.log1p(F.exp(-2.0 * abs_error)) - 0.6931471805599453
        loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
        return F.mean(loss, axis=self._batch_axis, exclude=True)

34
class CNNSupervisedTrainer_mnist_mnistClassifier_net:
    """Supervised training loop around the networks built by a net creator.

    The trainer only relies on the following members of its collaborators:

    * ``data_loader.load_data(batch_size)`` returning
      ``(train_iter, test_iter, data_mean, data_std)``;
    * ``net_constructor.construct(...)``, ``net_constructor.load(context)``,
      ``net_constructor.networks`` (a mapping; the network stored under key
      ``0`` is the one that is trained and evaluated), and the attributes
      ``_model_dir_`` / ``_model_prefix_`` used to build checkpoint paths.
    """

    def __init__(self, data_loader, net_constructor):
        self._data_loader = data_loader
        self._net_creator = net_constructor
        self._networks = {}

    def train(self, batch_size=64,
              num_epoch=10,
              eval_metric='acc',
              loss='softmax_cross_entropy',
              loss_params=None,
              optimizer='adam',
              optimizer_params=(('learning_rate', 0.001),),
              load_checkpoint=True,
              context='gpu',
              checkpoint_period=5,
              normalize=True):
        """Train the network, evaluate it after every epoch and save checkpoints.

        Parameters
        ----------
        batch_size : int
            Passed to the data loader and to ``Trainer.step``.
        num_epoch : int
            Number of epochs to run, counted from the loaded checkpoint epoch.
        eval_metric : str
            Metric name handed to ``mx.metric.create``.
        loss : str
            Name of the loss function; see ``_loss_factories`` for valid names.
        loss_params : dict or None
            Optional loss settings (``margin``, ``sparse_label``,
            ``from_logits``, ``rho``, ``label_format``).
        optimizer : str
            Optimizer name handed to ``mx.gluon.Trainer``.
        optimizer_params : dict or iterable of (key, value) pairs
            Optimizer settings. ``weight_decay`` is renamed to ``wd``;
            ``learning_rate_decay`` (together with ``step_size`` and the
            optional ``learning_rate_minimum``) is turned into a
            ``FactorScheduler``. The argument itself is never modified.
        load_checkpoint : bool
            If true, resume via ``net_constructor.load``; otherwise an
            existing model directory is deleted before training.
        context : str
            ``'gpu'`` or ``'cpu'``.
        checkpoint_period : int
            Parameters are saved every ``checkpoint_period`` epochs.
        normalize : bool
            If true, the data mean/std from the loader are passed to
            ``net_constructor.construct``.

        Raises
        ------
        ValueError
            If ``context`` or ``loss`` is not one of the supported names.
        """
        # Validate the plain-string arguments first so that a typo fails
        # immediately instead of after data loading and network construction.
        if context not in ('gpu', 'cpu'):
            message = "Context argument is '%s'. Only 'cpu' and 'gpu' are valid arguments." % (context,)
            logging.error(message)
            raise ValueError(message)

        loss_factories = self._loss_factories({} if loss_params is None else loss_params)
        if loss not in loss_factories:
            logging.error("Invalid loss parameter.")
            raise ValueError("Invalid loss parameter: %r" % (loss,))

        mx_context = mx.gpu() if context == 'gpu' else mx.cpu()
        optimizer_params = self._prepare_optimizer_params(optimizer_params)

        train_iter, test_iter, data_mean, data_std = self._data_loader.load_data(batch_size)

        if normalize:
            self._net_creator.construct(context=mx_context, data_mean=data_mean, data_std=data_std)
        else:
            self._net_creator.construct(context=mx_context)

        begin_epoch = 0
        if load_checkpoint:
            begin_epoch = self._net_creator.load(mx_context)
        elif os.path.isdir(self._net_creator._model_dir_):
            # Starting from scratch: discard checkpoints of earlier runs.
            shutil.rmtree(self._net_creator._model_dir_)

        self._networks = self._net_creator.networks

        try:
            os.makedirs(self._net_creator._model_dir_)
        except OSError:
            # An already existing directory is fine; anything else is not.
            if not os.path.isdir(self._net_creator._model_dir_):
                raise

        trainers = [mx.gluon.Trainer(network.collect_params(), optimizer, optimizer_params)
                    for network in self._networks.values()]

        loss_function = loss_factories[loss]()

        speed_period = 50
        tic = None

        for epoch in range(begin_epoch, begin_epoch + num_epoch):
            train_iter.reset()
            for batch_i, batch in enumerate(train_iter):
                image_ = batch.data[0].as_in_context(mx_context)
                predictions_label = batch.label[0].as_in_context(mx_context)

                with autograd.record():
                    predictions_ = self._networks[0](image_)
                    batch_loss = loss_function(predictions_, predictions_label)

                batch_loss.backward()

                for trainer in trainers:
                    trainer.step(batch_size)

                if tic is None:
                    tic = time.time()
                elif batch_i % speed_period == 0:
                    try:
                        speed = speed_period * batch_size / (time.time() - tic)
                    except ZeroDivisionError:
                        speed = float("inf")

                    logging.info("Epoch[%d] Batch[%d] Speed: %.2f samples/sec", epoch, batch_i, speed)

                    tic = time.time()

            tic = None

            train_metric_score = self._evaluate(train_iter, eval_metric, mx_context)
            test_metric_score = self._evaluate(test_iter, eval_metric, mx_context)

            logging.info("Epoch[%d] Train: %f, Test: %f", epoch, train_metric_score, test_metric_score)

            if (epoch - begin_epoch) % checkpoint_period == 0:
                for i, network in self._networks.items():
                    network.save_parameters(self.parameter_path(i) + '-' + str(epoch).zfill(4) + '.params')

        for i, network in self._networks.items():
            network.save_parameters(self.parameter_path(i) + '-' + str(num_epoch + begin_epoch).zfill(4) + '.params')
            network.export(self.parameter_path(i) + '_newest', epoch=0)

    def _prepare_optimizer_params(self, optimizer_params):
        """Return a new dict of optimizer settings in the form MXNet expects."""
        # dict() accepts both a mapping and the tuple-of-pairs default, and the
        # copy keeps the caller's object untouched.
        params = dict(optimizer_params or ())

        if 'weight_decay' in params:
            params['wd'] = params.pop('weight_decay')

        if 'learning_rate_decay' in params:
            min_learning_rate = params.pop('learning_rate_minimum', 1e-08)
            params['lr_scheduler'] = mx.lr_scheduler.FactorScheduler(
                params.pop('step_size'),
                factor=params.pop('learning_rate_decay'),
                stop_factor_lr=min_learning_rate)

        return params

    def _loss_factories(self, loss_params):
        """Map every supported loss name to a zero-argument constructor."""
        margin = loss_params.get('margin', 1.0)
        sparse_label = loss_params.get('sparse_label', True)

        # The constructors are wrapped in lambdas so that only the selected
        # loss is ever instantiated.
        return {
            'softmax_cross_entropy': lambda: mx.gluon.loss.SoftmaxCrossEntropyLoss(
                from_logits=loss_params.get('from_logits', False),
                sparse_label=sparse_label),
            'sigmoid_binary_cross_entropy': lambda: mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(),
            'cross_entropy': lambda: CrossEntropyLoss(sparse_label=sparse_label),
            'l2': lambda: mx.gluon.loss.L2Loss(),
            'l1': lambda: mx.gluon.loss.L1Loss(),
            'huber': lambda: mx.gluon.loss.HuberLoss(rho=loss_params.get('rho', 1)),
            'hinge': lambda: mx.gluon.loss.HingeLoss(margin=margin),
            'squared_hinge': lambda: mx.gluon.loss.SquaredHingeLoss(margin=margin),
            'logistic': lambda: mx.gluon.loss.LogisticLoss(
                label_format=loss_params.get('label_format', 'signed')),
            'kullback_leibler': lambda: mx.gluon.loss.KLDivLoss(
                from_logits=loss_params.get('from_logits', True)),
            'log_cosh': lambda: LogCoshLoss(),
        }

    def _evaluate(self, data_iter, eval_metric, mx_context):
        """Reset ``data_iter``, run the network over it and return the metric value."""
        data_iter.reset()
        metric = mx.metric.create(eval_metric)
        for batch in data_iter:
            image_ = batch.data[0].as_in_context(mx_context)
            labels = [batch.label[0].as_in_context(mx_context)]

            predictions_ = self._networks[0](image_)
            predictions = [mx.nd.argmax(predictions_, axis=1)]

            metric.update(preds=predictions, labels=labels)
        return metric.get()[1]

    def parameter_path(self, index):
        """Return the checkpoint path prefix for the network stored under ``index``."""
        return self._net_creator._model_dir_ + self._net_creator._model_prefix_ + '_' + str(index)