import mxnet as mx
import logging
import numpy as np
import time
import os
import shutil
import pickle
from mxnet import gluon, autograd, nd

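# Cross-entropy loss that expects class probabilities as input and applies
# the log itself (unlike SoftmaxCrossEntropyLoss, which also applies softmax).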
class CrossEntropyLoss(gluon.loss.Loss):
    def __init__(self, axis=-1, sparse_label=True, weight=None, batch_axis=0, **kwargs):
        super(CrossEntropyLoss, self).__init__(weight, batch_axis, **kwargs)
        self._axis = axis
        self._sparse_label = sparse_label

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        pred = F.log(pred)
        if self._sparse_label:
            loss = -F.pick(pred, label, axis=self._axis, keepdims=True)
        else:
            label = gluon.loss._reshape_like(F, label, pred)
            loss = -F.sum(pred * label, axis=self._axis, keepdims=True)
        loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
        return F.mean(loss, axis=self._batch_axis, exclude=True)

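# Log-cosh loss: approximately quadratic for small residuals and linear for
# large ones, which makes it less sensitive to outliers than plain L2.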
class LogCoshLoss(gluon.loss.Loss):
    def __init__(self, weight=None, batch_axis=0, **kwargs):
        super(LogCoshLoss, self).__init__(weight, batch_axis, **kwargs)

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        loss = F.log(F.cosh(pred - label))
        loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
        return F.mean(loss, axis=self._batch_axis, exclude=True)


class CNNSupervisedTrainer_Alexnet:
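    # Recursive beam search sketch over the network's output distribution:
    # at each step the `width` most probable tokens are expanded, and once
    # maxLength is reached the highest-probability sequence is kept in bestOutput.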
    def applyBeamSearch(self, input, length, width, maxLength, currProb, netIndex, bestOutput):
        bestProb = 0.0
        while length < maxLength:
            length += 1
            batchIndex = 0
            for batchEntry in input:
                top_k_indices = mx.nd.topk(batchEntry, axis=0, k=width)
                top_k_values = mx.nd.topk(batchEntry, ret_typ='value', axis=0, k=width)
                for index in range(top_k_indices.size):

                    #print mx.nd.array(top_k_indices[index])
                    #print top_k_values[index]
                    if length == 1:
                        #print mx.nd.array(top_k_indices[index])
                        result = self.applyBeamSearch(self._networks[netIndex](mx.nd.array(top_k_indices[index])), length, width, maxLength,
                            currProb * top_k_values[index], netIndex, self._networks[netIndex](mx.nd.array(top_k_indices[index])))
                    else:
                        result = self.applyBeamSearch(self._networks[netIndex](mx.nd.array(top_k_indices[index])), length, width, maxLength,
                            currProb * top_k_values[index], netIndex, bestOutput)

                    if length == maxLength:
                        #print currProb
                        if currProb > bestProb:
                            bestProb = currProb
                            bestOutput[batchIndex] = result[batchIndex]
                            #print "new bestOutput: ", bestOutput

                batchIndex += 1
        #print bestOutput
        #print bestProb
        return bestOutput


    def __init__(self, data_loader, net_constructor):
        self._data_loader = data_loader
        self._net_creator = net_constructor
        self._networks = {}

    def train(self, batch_size=64,
              num_epoch=10,
              eval_metric='acc',
              loss='softmax_cross_entropy',
              loss_params={},
              optimizer='adam',
              optimizer_params=(('learning_rate', 0.001),),
              load_checkpoint=True,
              context='gpu',
              checkpoint_period=5,
              normalize=True):
        if context == 'gpu':
            mx_context = mx.gpu()
        elif context == 'cpu':
            mx_context = mx.cpu()
        else:
            logging.error("Context argument is '" + context + "'. Only 'cpu' and 'gpu are valid arguments'.")

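        # Translate generic optimizer options to MXNet's names: 'weight_decay'
        # becomes 'wd', and 'learning_rate_decay' (with 'step_size' and an
        # optional 'learning_rate_minimum') becomes a FactorScheduler passed
        # as 'lr_scheduler'.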
        if 'weight_decay' in optimizer_params:
            optimizer_params['wd'] = optimizer_params['weight_decay']
            del optimizer_params['weight_decay']
        if 'learning_rate_decay' in optimizer_params:
            min_learning_rate = 1e-08
            if 'learning_rate_minimum' in optimizer_params:
                min_learning_rate = optimizer_params['learning_rate_minimum']
                del optimizer_params['learning_rate_minimum']
            optimizer_params['lr_scheduler'] = mx.lr_scheduler.FactorScheduler(
                                                   optimizer_params['step_size'],
                                                   factor=optimizer_params['learning_rate_decay'],
                                                   stop_factor_lr=min_learning_rate)
            del optimizer_params['step_size']
            del optimizer_params['learning_rate_decay']


        train_iter, test_iter, data_mean, data_std = self._data_loader.load_data(batch_size)

        if normalize:
            self._net_creator.construct(context=mx_context, data_mean=data_mean, data_std=data_std)
        else:
            self._net_creator.construct(context=mx_context)

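        # Resume from the newest checkpoint if requested; otherwise delete any
        # existing model directory so training starts from scratch.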
        begin_epoch = 0
        if load_checkpoint:
            begin_epoch = self._net_creator.load(mx_context)
        else:
            if os.path.isdir(self._net_creator._model_dir_):
                shutil.rmtree(self._net_creator._model_dir_)

        self._networks = self._net_creator.networks

        try:
            os.makedirs(self._net_creator._model_dir_)
        except OSError:
            if not os.path.isdir(self._net_creator._model_dir_):
                raise

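        # One Gluon Trainer per network, all sharing the same optimizer settings.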
        trainers = [mx.gluon.Trainer(network.collect_params(), optimizer, optimizer_params) for network in self._networks.values()]

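        # Resolve the configured loss name into a Gluon loss object; per-loss
        # hyperparameters are read from loss_params, with defaults where absent.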
        margin = loss_params['margin'] if 'margin' in loss_params else 1.0
        sparseLabel = loss_params['sparse_label'] if 'sparse_label' in loss_params else True
        if loss == 'softmax_cross_entropy':
            fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
            loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(from_logits=fromLogits, sparse_label=sparseLabel)
        elif loss == 'sigmoid_binary_cross_entropy':
            loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
        elif loss == 'cross_entropy':
            loss_function = CrossEntropyLoss(sparse_label=sparseLabel)
        elif loss == 'l2':
            loss_function = mx.gluon.loss.L2Loss()
        elif loss == 'l1':
            loss_function = mx.gluon.loss.L1Loss()
        elif loss == 'huber':
            rho = loss_params['rho'] if 'rho' in loss_params else 1
            loss_function = mx.gluon.loss.HuberLoss(rho=rho)
        elif loss == 'hinge':
            loss_function = mx.gluon.loss.HingeLoss(margin=margin)
        elif loss == 'squared_hinge':
            loss_function = mx.gluon.loss.SquaredHingeLoss(margin=margin)
        elif loss == 'logistic':
            labelFormat = loss_params['label_format'] if 'label_format' in loss_params else 'signed'
            loss_function = mx.gluon.loss.LogisticLoss(label_format=labelFormat)
        elif loss == 'kullback_leibler':
            fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else True
            loss_function = mx.gluon.loss.KLDivLoss(from_logits=fromLogits)
        elif loss == 'log_cosh':
            loss_function = LogCoshLoss()
        else:
            logging.error("Invalid loss parameter.")

        speed_period = 50
        tic = None

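        # Main training loop: one forward/backward pass per batch, followed by
        # a parameter update through every trainer.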
        for epoch in range(begin_epoch, begin_epoch + num_epoch):
            train_iter.reset()
            for batch_i, batch in enumerate(train_iter):
                data_ = batch.data[0].as_in_context(mx_context)
                predictions_label = batch.label[0].as_in_context(mx_context)

                outputs = []

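                # The forward pass and loss computation are recorded so that
                # loss.backward() can compute gradients afterwards.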
                with autograd.record():
                    predictions_ = mx.nd.zeros((batch_size, 10,), ctx=mx_context)

                    lossList = []
                    predictions_ = self._networks[0](data_)
                    lossList.append(loss_function(predictions_, predictions_label))

                    loss = 0
                    for element in lossList:
                        loss = loss + element

                loss.backward()

                for trainer in trainers:
                    trainer.step(batch_size)

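                # Log throughput every speed_period batches.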
                if tic is None:
                    tic = time.time()
                else:
                    if batch_i % speed_period == 0:
                        try:
                            speed = speed_period * batch_size / (time.time() - tic)
                        except ZeroDivisionError:
                            speed = float("inf")

                        logging.info("Epoch[%d] Batch[%d] Speed: %.2f samples/sec" % (epoch, batch_i, speed))

                        tic = time.time()

            tic = None

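            # Evaluation on the training set: forward passes only; multi-class
            # outputs are reduced with argmax before updating the metric.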
            train_iter.reset()
            metric = mx.metric.create(eval_metric)
            for batch_i, batch in enumerate(train_iter):
                data_ = batch.data[0].as_in_context(mx_context)

                labels = [
                    batch.label[0].as_in_context(mx_context)
                ]

                outputs = []

                if True:
                    predictions_ = mx.nd.zeros((batch_size, 10,), ctx=mx_context)

                    predictions_ = self._networks[0](data_)
                    outputs.append(predictions_)

                predictions = []
                for output_name in outputs:
                    if mx.nd.shape_array(output_name).size > 1:
                        predictions.append(mx.nd.argmax(output_name, axis=1))
                    else:
                        # argmax already applied
                        predictions.append(output_name)

                # Compute BLEU and NIST scores if the data folder contains a
                # dictionary (i.e. the dataset is an NLP dataset).
                if os.path.isfile('data/Alexnet/dict.pkl'):
                    with open('data/Alexnet/dict.pkl', 'rb') as f:
                        vocab = pickle.load(f)

                    import nltk.translate.bleu_score
                    import nltk.translate.nist_score

                    prediction = []
                    for index in range(batch_size):
                        sentence = ''
                        for entry in predictions:
                            sentence += vocab[int(entry[index].asscalar())] + ' '
                        prediction.append(sentence)

                    for index in range(batch_size):
                        sentence = ''
                        for batchEntry in batch.label:
                            sentence += vocab[int(batchEntry[index].asscalar())] + ' '
                        print("############################")
                        print("label: ", sentence)
                        print("prediction: ", prediction[index])

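                        # Note: nltk's sentence_bleu/sentence_nist expect token
                        # lists; raw strings are scored character by character.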
                        BLEUscore = nltk.translate.bleu_score.sentence_bleu([sentence], prediction[index])
                        NISTscore = nltk.translate.nist_score.sentence_nist([sentence], prediction[index])
                        print("BLEU: ", BLEUscore)
                        print("NIST: ", NISTscore)
                        print("############################")

                metric.update(preds=predictions, labels=labels)
            train_metric_score = metric.get()[1]

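            # Same evaluation pass on the held-out test set.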
            test_iter.reset()
            metric = mx.metric.create(eval_metric)
            for batch_i, batch in enumerate(test_iter):
                data_ = batch.data[0].as_in_context(mx_context)

                labels = [
                    batch.label[0].as_in_context(mx_context)
                ]

                outputs = []

                if True:
                    predictions_ = mx.nd.zeros((batch_size, 10,), ctx=mx_context)

                    predictions_ = self._networks[0](data_)
                    outputs.append(predictions_)

                predictions = []
                for output_name in outputs:
                    if mx.nd.shape_array(output_name).size > 1:
                        predictions.append(mx.nd.argmax(output_name, axis=1))
                    else:
                        # argmax already applied
                        predictions.append(output_name)

                metric.update(preds=predictions, labels=labels)
            test_metric_score = metric.get()[1]

            logging.info("Epoch[%d] Train: %f, Test: %f" % (epoch, train_metric_score, test_metric_score))


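            # Save a checkpoint every checkpoint_period epochs.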
            if (epoch - begin_epoch) % checkpoint_period == 0:
                for i, network in self._networks.items():
                    network.save_parameters(self.parameter_path(i) + '-' + str(epoch).zfill(4) + '.params')
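        # Save the final parameters and export the newest symbol/params pair
        # for later inference.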
        for i, network in self._networks.items():
            network.save_parameters(self.parameter_path(i) + '-' + str(num_epoch + begin_epoch).zfill(4) + '.params')
            network.export(self.parameter_path(i) + '_newest', epoch=0)

    def parameter_path(self, index):
        return self._net_creator._model_dir_ + self._net_creator._model_prefix_ + '_' + str(index)
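
# Minimal usage sketch (illustrative; `data_loader` and `net_creator` stand in
# for this project's generated data loader and network creator objects):
#
#   trainer = CNNSupervisedTrainer_Alexnet(data_loader, net_creator)
#   trainer.train(batch_size=64, num_epoch=10, context='cpu')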