# CNNSupervisedTrainer_Alexnet.py
import mxnet as mx
import logging
import numpy as np
import time
import os
import shutil
from mxnet import gluon, autograd, nd

class CrossEntropyLoss(gluon.loss.Loss):
    """Cross-entropy loss for networks whose output is already a distribution.

    ``pred`` is passed straight through a natural log (no softmax is applied
    here). With ``sparse_label=True`` the label holds class indices along
    ``axis``; otherwise it is a dense target that is reshaped like ``pred``.
    The per-sample result is averaged over every axis except ``batch_axis``.
    """

    def __init__(self, axis=-1, sparse_label=True, weight=None, batch_axis=0, **kwargs):
        super(CrossEntropyLoss, self).__init__(weight, batch_axis, **kwargs)
        self._sparse_label = sparse_label
        self._axis = axis

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        log_prob = F.log(pred)
        if not self._sparse_label:
            # Dense target: weight every log-probability by its target mass.
            dense_label = gluon.loss._reshape_like(F, label, log_prob)
            per_sample = -F.sum(log_prob * dense_label, axis=self._axis, keepdims=True)
        else:
            # Sparse target: select the log-probability of the labelled class.
            per_sample = -F.pick(log_prob, label, axis=self._axis, keepdims=True)
        weighted = gluon.loss._apply_weighting(F, per_sample, self._weight, sample_weight)
        return F.mean(weighted, axis=self._batch_axis, exclude=True)

class LogCoshLoss(gluon.loss.Loss):
    """Log-cosh regression loss ``log(cosh(pred - label))``, averaged per sample.

    The loss is evaluated in the algebraically equivalent form
    ``|d| + log1p(exp(-2 * |d|)) - log(2)`` with ``d = pred - label``. The
    naive ``log(cosh(d))`` overflows ``cosh`` to ``inf`` for large residuals
    (|d| > ~89 in float32), which turns the loss into ``inf`` and the
    gradients into NaN.
    """

    # log(2) as a plain Python float so it can be subtracted from NDArray
    # and Symbol values alike.
    _LOG_TWO = float(np.log(2.0))

    def __init__(self, weight=None, batch_axis=0, **kwargs):
        super(LogCoshLoss, self).__init__(weight, batch_axis, **kwargs)

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        abs_diff = F.abs(pred - label)
        # exp(-2|d|) lies in (0, 1], so no term can overflow; at d == 0 the
        # expression is exactly 0 and its gradient is tanh(d), as for log(cosh(d)).
        loss = abs_diff + F.log1p(F.exp(-2 * abs_diff)) - self._LOG_TWO
        loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
        return F.mean(loss, axis=self._batch_axis, exclude=True)

class CNNSupervisedTrainer_Alexnet:
    """Supervised Gluon training loop for the Alexnet network.

    Data loading is delegated to ``data_loader``. Network construction,
    checkpoint restoring and the model-directory layout are delegated to
    ``net_constructor``.
    """

    def __init__(self, data_loader, net_constructor):
        """Store the collaborators used by :meth:`train`.

        :param data_loader: object whose ``load_data(batch_size)`` returns
            ``(train_iter, test_iter, data_mean, data_std)``.
        :param net_constructor: object providing ``construct(...)``,
            ``load(context)``, ``networks``, ``_model_dir_`` and
            ``_model_prefix_``.
        """
        self._data_loader = data_loader
        self._net_creator = net_constructor
        # Filled in by train() from self._net_creator.networks.
        self._networks = {}

    def train(self, batch_size=64,
              num_epoch=10,
              eval_metric='acc',
              loss='softmax_cross_entropy',
              loss_params=None,
              optimizer='adam',
              optimizer_params=(('learning_rate', 0.001),),
              load_checkpoint=True,
              context='gpu',
              checkpoint_period=5,
              normalize=True):
        """Train network 0 and save checkpoints into the model directory.

        :param batch_size: passed to ``data_loader.load_data``; also used for
            the samples/sec speed report.
        :param num_epoch: number of epochs to run, starting at the epoch
            returned by the checkpoint loader (or 0).
        :param eval_metric: metric name passed to ``mx.metric.create``; it is
            reported on the train and test iterators after each epoch.
        :param loss: loss name (see ``_create_loss_function``).
        :param loss_params: optional mapping of loss options (``margin``,
            ``sparse_label``, ``from_logits``, ``rho``, ``label_format``).
            It is not modified.
        :param optimizer: optimizer name passed to ``gluon.Trainer``.
        :param optimizer_params: mapping or iterable of ``(key, value)``
            pairs. ``weight_decay`` is renamed to ``wd``.
            ``learning_rate_decay`` together with ``step_size`` (and the
            optional ``learning_rate_minimum``) becomes a ``FactorScheduler``.
            The caller's object is not modified.
        :param load_checkpoint: if true, resume from ``net_constructor.load``;
            otherwise delete the existing model directory first.
        :param context: ``'gpu'`` or ``'cpu'``.
        :param checkpoint_period: save parameters every this many epochs.
        :param normalize: if true, pass ``data_mean``/``data_std`` to
            ``net_constructor.construct``.
        :raises ValueError: for an unknown ``context`` or ``loss``.
        :raises KeyError: if ``learning_rate_decay`` is given without
            ``step_size``.
        """
        mx_context = self._resolve_context(context)
        optimizer_params = self._normalize_optimizer_params(optimizer_params)
        # Build the loss before touching the model directory, so an invalid
        # loss name cannot cause existing checkpoints to be deleted first.
        loss_function = self._create_loss_function(loss, loss_params)

        train_iter, test_iter, data_mean, data_std = self._data_loader.load_data(batch_size)

        if normalize:
            self._net_creator.construct(context=mx_context, data_mean=data_mean, data_std=data_std)
        else:
            self._net_creator.construct(context=mx_context)

        begin_epoch = 0
        if load_checkpoint:
            begin_epoch = self._net_creator.load(mx_context)
        elif os.path.isdir(self._net_creator._model_dir_):
            shutil.rmtree(self._net_creator._model_dir_)

        self._networks = self._net_creator.networks

        # Python-2-compatible "mkdir -p".
        try:
            os.makedirs(self._net_creator._model_dir_)
        except OSError:
            if not os.path.isdir(self._net_creator._model_dir_):
                raise

        trainers = [mx.gluon.Trainer(network.collect_params(), optimizer, optimizer_params)
                    for network in self._networks.values()]

        speed_period = 50
        tic = None

        for epoch in range(begin_epoch, begin_epoch + num_epoch):
            train_iter.reset()
            for batch_i, batch in enumerate(train_iter):
                data_ = batch.data[0].as_in_context(mx_context)
                predictions_label = batch.label[0].as_in_context(mx_context)

                with autograd.record():
                    predictions_ = self._networks[0](data_)
                    loss_value = loss_function(predictions_, predictions_label)

                loss_value.backward()

                # Normalize gradients by the samples actually present in this
                # batch, which can be fewer than batch_size for a final
                # partial batch.
                for trainer in trainers:
                    trainer.step(data_.shape[0])

                if tic is None:
                    tic = time.time()
                elif batch_i % speed_period == 0:
                    try:
                        speed = speed_period * batch_size / (time.time() - tic)
                    except ZeroDivisionError:
                        speed = float("inf")

                    logging.info("Epoch[%d] Batch[%d] Speed: %.2f samples/sec", epoch, batch_i, speed)

                    tic = time.time()

            tic = None

            train_metric_score = self._compute_metric(train_iter, eval_metric, mx_context)
            test_metric_score = self._compute_metric(test_iter, eval_metric, mx_context)

            logging.info("Epoch[%d] Train: %f, Test: %f", epoch, train_metric_score, test_metric_score)

            if (epoch - begin_epoch) % checkpoint_period == 0:
                for i, network in self._networks.items():
                    network.save_parameters(self.parameter_path(i) + '-' + str(epoch).zfill(4) + '.params')

        for i, network in self._networks.items():
            network.save_parameters(self.parameter_path(i) + '-' + str(num_epoch + begin_epoch).zfill(4) + '.params')
            network.export(self.parameter_path(i) + '_newest', epoch=0)

    @staticmethod
    def _resolve_context(context):
        """Map ``'gpu'``/``'cpu'`` to an MXNet context; raise ValueError otherwise."""
        if context == 'gpu':
            return mx.gpu()
        if context == 'cpu':
            return mx.cpu()
        message = "Context argument is '" + str(context) + "'. Only 'cpu' and 'gpu' are valid arguments."
        logging.error(message)
        raise ValueError(message)

    @staticmethod
    def _normalize_optimizer_params(optimizer_params):
        """Return a new dict of Trainer options translated from ``optimizer_params``.

        The input may be a mapping or an iterable of ``(key, value)`` pairs
        and is never modified.
        """
        params = dict(optimizer_params)
        if 'weight_decay' in params:
            params['wd'] = params.pop('weight_decay')
        if 'learning_rate_decay' in params:
            min_learning_rate = params.pop('learning_rate_minimum', 1e-08)
            params['lr_scheduler'] = mx.lr_scheduler.FactorScheduler(
                params.pop('step_size'),
                factor=params.pop('learning_rate_decay'),
                stop_factor_lr=min_learning_rate)
        return params

    @staticmethod
    def _create_loss_function(loss, loss_params):
        """Build the Gluon loss block named by ``loss``; raise ValueError if unknown."""
        params = {} if loss_params is None else dict(loss_params)
        margin = params.get('margin', 1.0)
        sparse_label = params.get('sparse_label', True)
        if loss == 'softmax_cross_entropy':
            return mx.gluon.loss.SoftmaxCrossEntropyLoss(
                from_logits=params.get('from_logits', False), sparse_label=sparse_label)
        if loss == 'sigmoid_binary_cross_entropy':
            return mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
        if loss == 'cross_entropy':
            return CrossEntropyLoss(sparse_label=sparse_label)
        if loss == 'l2':
            return mx.gluon.loss.L2Loss()
        if loss == 'l1':
            return mx.gluon.loss.L1Loss()
        if loss == 'huber':
            return mx.gluon.loss.HuberLoss(rho=params.get('rho', 1))
        if loss == 'hinge':
            return mx.gluon.loss.HingeLoss(margin=margin)
        if loss == 'squared_hinge':
            return mx.gluon.loss.SquaredHingeLoss(margin=margin)
        if loss == 'logistic':
            return mx.gluon.loss.LogisticLoss(label_format=params.get('label_format', 'signed'))
        if loss == 'kullback_leibler':
            return mx.gluon.loss.KLDivLoss(from_logits=params.get('from_logits', True))
        if loss == 'log_cosh':
            return LogCoshLoss()
        message = "Invalid loss parameter '" + str(loss) + "'."
        logging.error(message)
        raise ValueError(message)

    def _compute_metric(self, data_iter, eval_metric, mx_context):
        """Reset ``data_iter`` and return the ``eval_metric`` score of network 0 over it."""
        data_iter.reset()
        metric = mx.metric.create(eval_metric)
        for batch in data_iter:
            data_ = batch.data[0].as_in_context(mx_context)
            labels = [batch.label[0].as_in_context(mx_context)]

            outputs = [self._networks[0](data_)]
            # Outputs with more than one dimension are reduced with argmax over
            # axis 1; 1-D outputs are treated as already argmax-ed.
            predictions = [mx.nd.argmax(output, axis=1) if output.ndim > 1 else output
                           for output in outputs]
            metric.update(preds=predictions, labels=labels)
        return metric.get()[1]

    def parameter_path(self, index):
        """Return the checkpoint path prefix for the network stored under ``index``."""
        return self._net_creator._model_dir_ + self._net_creator._model_prefix_ + '_' + str(index)