CNNSupervisedTrainer_Alexnet.py 20.2 KB
Newer Older
Nicola Gatto's avatar
Nicola Gatto committed
1 2 3 4 5 6
import mxnet as mx
import logging
import numpy as np
import time
import os
import shutil
Christian Fuß's avatar
Christian Fuß committed
7
import pickle
Sebastian N.'s avatar
Sebastian N. committed
8 9
import math
import sys
Nicola Gatto's avatar
Nicola Gatto committed
10 11
from mxnet import gluon, autograd, nd

Eyüp Harputlu's avatar
Eyüp Harputlu committed
12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27
class CrossEntropyLoss(gluon.loss.Loss):
    """Cross-entropy loss computed directly on probabilities.

    No softmax is applied: ``log`` is taken of ``pred`` as-is, so ``pred`` is
    expected to already be a probability distribution along ``axis``
    (presumably a softmax output of the network -- confirm with the caller).

    Parameters
    ----------
    axis : int
        Axis that holds the class dimension.
    sparse_label : bool
        True if ``label`` holds class indices; False if it is a dense
        distribution that gets reshaped to ``pred``'s shape.
    weight, batch_axis, **kwargs
        Forwarded unchanged to ``gluon.loss.Loss``.
    """

    def __init__(self, axis=-1, sparse_label=True, weight=None, batch_axis=0, **kwargs):
        super(CrossEntropyLoss, self).__init__(weight, batch_axis, **kwargs)
        self._axis = axis
        self._sparse_label = sparse_label

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        """Return the weighted loss averaged over every non-batch axis."""
        log_prob = F.log(pred)
        if not self._sparse_label:
            dense_label = gluon.loss._reshape_like(F, label, log_prob)
            per_sample = -F.sum(log_prob * dense_label, axis=self._axis, keepdims=True)
        else:
            per_sample = -F.pick(log_prob, label, axis=self._axis, keepdims=True)
        weighted = gluon.loss._apply_weighting(F, per_sample, self._weight, sample_weight)
        return F.mean(weighted, axis=self._batch_axis, exclude=True)

Eyüp Harputlu's avatar
Eyüp Harputlu committed
28 29 30 31 32 33 34 35 36
class LogCoshLoss(gluon.loss.Loss):
    """Log-cosh regression loss, ``log(cosh(pred - label))``.

    Evaluating ``cosh`` directly overflows to ``inf`` once the absolute
    residual exceeds roughly 89 in float32, which would make the loss inf and
    its gradient NaN. The residual is therefore clipped to
    ``_LINEAR_THRESHOLD`` before ``cosh`` and the excess is added back
    linearly: for ``|x| > T``, ``log(cosh(x)) = log(cosh(T)) + (|x| - T)`` up
    to an error of about ``exp(-2T)`` (negligible for ``T = 20``). Below the
    threshold the value is identical to the direct formula.

    Parameters
    ----------
    weight, batch_axis, **kwargs
        Forwarded unchanged to ``gluon.loss.Loss``.
    """

    # Residual magnitude above which log(cosh(x)) is continued linearly.
    _LINEAR_THRESHOLD = 20.0

    def __init__(self, weight=None, batch_axis=0, **kwargs):
        super(LogCoshLoss, self).__init__(weight, batch_axis, **kwargs)

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        """Return the weighted loss averaged over every non-batch axis."""
        # cosh is even, so working on |x| gives the same value; the gradient
        # still recovers tanh(x) through the abs.
        abs_diff = F.abs(pred - label)
        clipped = F.clip(abs_diff, 0.0, self._LINEAR_THRESHOLD)
        loss = F.log(F.cosh(clipped)) + F.relu(abs_diff - self._LINEAR_THRESHOLD)
        loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
        return F.mean(loss, axis=self._batch_axis, exclude=True)

Sebastian N.'s avatar
Sebastian N. committed
37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56
class SoftmaxCrossEntropyLossIgnoreIndices(gluon.loss.Loss):
    """Softmax cross-entropy that masks selected class indices.

    For every index ``i`` in ``ignore_indices`` the loss of a sample is set to
    zero when the predicted class (argmax of ``pred`` along ``axis``) equals
    ``i`` *and* equals the target class. For example, correctly predicted
    ``<pad>`` tokens in NLP applications do not contribute to the loss.

    Parameters
    ----------
    axis : int
        Class axis of ``pred``.
    ignore_indices : iterable of int, optional
        Class indices to mask as described above. ``None`` means no masking.
    sparse_label : bool
        True if ``label`` holds class indices; False if it is a dense
        distribution with ``pred``'s shape (the target class is then its argmax).
    from_logits : bool
        If True, ``pred`` is used as log-probabilities and no ``log_softmax``
        is applied.
    weight, batch_axis, **kwargs
        Forwarded unchanged to ``gluon.loss.Loss``.
    """

    def __init__(self, axis=-1, ignore_indices=None, sparse_label=True, from_logits=False, weight=None, batch_axis=0, **kwargs):
        super(SoftmaxCrossEntropyLossIgnoreIndices, self).__init__(weight, batch_axis, **kwargs)
        self._axis = axis
        # Own copy, so later changes to the caller's list cannot alter this loss.
        self._ignore_indices = list(ignore_indices) if ignore_indices is not None else []
        self._sparse_label = sparse_label
        self._from_logits = from_logits

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        """Return the (masked, weighted) loss averaged over every non-batch axis."""
        if not self._from_logits:
            pred = F.log_softmax(pred, axis=self._axis)
        if self._sparse_label:
            loss = -F.pick(pred, label, axis=self._axis, keepdims=True)
        else:
            label = gluon.loss._reshape_like(F, label, pred)
            loss = -F.sum(pred * label, axis=self._axis, keepdims=True)

        if self._ignore_indices:
            # keepdims=True gives the mask exactly the shape of ``loss``; a
            # mask without the kept axis would broadcast to (batch, batch).
            predicted = F.argmax(pred, axis=self._axis, keepdims=True)
            if self._sparse_label:
                target = F.reshape_like(label, predicted)
            else:
                target = F.argmax(label, axis=self._axis, keepdims=True)
            correct = F.broadcast_equal(predicted, target)
            for index in self._ignore_indices:
                # 1 - (a * b) is logical_not on 0/1 masks and works for both
                # NDArray and Symbol, so the block stays hybridizable.
                loss = loss * (1 - (predicted == index) * correct)

        loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
        return F.mean(loss, axis=self._batch_axis, exclude=True)

Sebastian N.'s avatar
Sebastian N. committed
60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116
@mx.metric.register
class BLEU(mx.metric.EvalMetric):
    """Corpus-level BLEU score (n-grams up to ``N``) as an MXNet metric.

    Statistics are accumulated over all :meth:`update` calls until
    :meth:`reset`:

    - clipped n-gram matches and hypothesis n-gram totals for n = 1..N;
    - total reference and hypothesis lengths.

    :meth:`get` combines them into ``BP * exp(mean_n(log p_n))``, where BP is
    the brevity penalty.

    Smoothing:

    - An order with hypothesis n-grams but zero matches gets a pseudo-count of
      1/2, 1/4, ... (halved for each further zero-match order).
    - An order with no hypothesis n-grams at all contributes
      ``sys.float_info.min``, which drives the score to (almost) zero.

    Parameters
    ----------
    exclude : list, optional
        Token values removed from labels and predictions before counting
        (e.g. padding ids).
    name, output_names, label_names
        Forwarded to ``mx.metric.EvalMetric``.
    """
    N = 4

    def __init__(self, exclude=None, name='bleu', output_names=None, label_names=None):
        super(BLEU, self).__init__(name=name, output_names=output_names, label_names=label_names)

        self._exclude = exclude or []
        self._reset_statistics()

    def reset(self):
        """Reset the base-class state and the accumulated BLEU statistics."""
        super(BLEU, self).reset()
        self._reset_statistics()

    def _reset_statistics(self):
        # Per-order clipped matches and number of hypothesis n-grams.
        self._match_counts = [0] * self.N
        self._counts = [0] * self.N
        # Total token counts used by the brevity penalty.
        self._size_ref = 0
        self._size_hyp = 0

    def update(self, labels, preds):
        """Accumulate n-gram statistics for one batch.

        ``labels`` and ``preds`` are lists of NDArrays. ``_convert`` regroups
        them so that each sample becomes a list with one token per list
        element.
        """
        from collections import Counter

        labels, preds = mx.metric.check_label_shapes(labels, preds, True)

        for label, pred in zip(self._convert(labels), self._convert(preds)):
            reference = [word for word in label if word not in self._exclude]
            hypothesis = [word for word in pred if word not in self._exclude]

            self._size_ref += len(reference)
            self._size_hyp += len(hypothesis)

            for n in range(self.N):
                reference_ngrams = Counter(tuple(g) for g in self._get_ngrams(reference, n + 1))
                hypothesis_ngrams = Counter(tuple(g) for g in self._get_ngrams(hypothesis, n + 1))
                # Counter intersection keeps min(hyp count, ref count) per
                # n-gram: each reference n-gram can be matched at most once.
                self._match_counts[n] += sum((hypothesis_ngrams & reference_ngrams).values())
                self._counts[n] += sum(hypothesis_ngrams.values())

    def get(self):
        """Return ``(name, bleu)`` for the statistics accumulated so far."""
        return (self.name, self.calculate())

    def calculate(self):
        """Return the BLEU score for the statistics accumulated so far."""
        precisions = [sys.float_info.min] * self.N

        smoothing = 1
        for n in range(self.N):
            match_counts = self._match_counts[n]
            counts = self._counts[n]

            if counts == 0:
                continue
            if match_counts == 0:
                smoothing *= 2
                match_counts = 1.0 / smoothing

            precisions[n] = match_counts / float(counts)

        log_mean = sum(math.log(p) for p in precisions) / self.N
        return self._get_brevity_penalty() * math.exp(log_mean)

    def _get_brevity_penalty(self):
        """Return ``exp(1 - ref_len / hyp_len)`` for short hypotheses, else 1."""
        if self._size_hyp >= self._size_ref:
            return 1
        if self._size_hyp == 0:
            # Limit of exp(1 - r/c) as c -> 0: an empty hypothesis scores 0.
            return 0.0
        return math.exp(1 - float(self._size_ref) / self._size_hyp)

    @staticmethod
    def _get_ngrams(sentence, n):
        """Return all contiguous length-``n`` slices of ``sentence`` (lists)."""
        # range() is empty when the sentence is shorter than n.
        return [sentence[i:i + n] for i in range(len(sentence) - n + 1)]

    @staticmethod
    def _convert(nd_list):
        """Regroup a list of per-position NDArrays into per-sample token lists."""
        if len(nd_list) == 0:
            return []

        new_list = [[] for _ in range(nd_list[0].shape[0])]

        for element in nd_list:
            for i in range(element.shape[0]):
                new_list[i].append(element[i].asscalar())

        return new_list
Christian Fuß's avatar
Christian Fuß committed
170

Sebastian N.'s avatar
Sebastian N. committed
171 172


173
class CNNSupervisedTrainer_Alexnet:
    """Supervised training loop for the generated Alexnet network(s).

    ``data_loader`` must provide ``load_data(batch_size)``. ``net_constructor``
    must provide:

    - ``construct(...)``;
    - ``load(context)``;
    - ``networks``;
    - the ``_model_dir_`` and ``_model_prefix_`` attributes used for
      checkpoint paths.
    """

    def __init__(self, data_loader, net_constructor):
        self._data_loader = data_loader
        self._net_creator = net_constructor
        self._networks = {}

    def train(self, batch_size=64,
              num_epoch=10,
              eval_metric='acc',
              eval_metric_params=None,
              eval_train=False,
              loss='softmax_cross_entropy',
              loss_params=None,
              optimizer='adam',
              optimizer_params=(('learning_rate', 0.001),),
              load_checkpoint=True,
              checkpoint_period=5,
              log_period=50,
              context='gpu',
              save_attention_image=False,
              use_teacher_forcing=False,
              normalize=True):
        """Train the network, evaluating and checkpointing after every epoch.

        Parameters
        ----------
        batch_size : int
            Batch size passed to the data loader and to ``Trainer.step``.
        num_epoch : int
            Number of epochs to run, starting at the (possibly resumed) epoch.
        eval_metric, eval_metric_params
            Name and keyword arguments for ``mx.metric.create``.
        eval_train : bool
            Also evaluate the metric on the training set after each epoch.
        loss, loss_params
            Loss name (see ``_build_loss_function``) and its options.
        optimizer, optimizer_params
            Optimizer name and parameters (a mapping or a sequence of
            ``(key, value)`` pairs). The caller's object is not modified.
        load_checkpoint : bool
            Resume from the epoch returned by ``net_constructor.load``. If
            False, the model directory is deleted first.
        checkpoint_period : int
            Save parameters every ``checkpoint_period`` epochs.
        log_period : int
            Log speed and average loss every ``log_period`` batches.
        context : str
            ``'gpu'`` or ``'cpu'``.
        save_attention_image
            Only the string ``"True"`` enables attention plots.
        use_teacher_forcing
            Unused by this network; kept for interface compatibility.
        normalize : bool
            Pass the data mean/std from the loader to ``construct``.

        Raises
        ------
        ValueError
            If ``context`` or ``loss`` is not a supported value.
        """
        eval_metric_params = dict(eval_metric_params or {})
        loss_params = dict(loss_params or {})

        # Validate all arguments before anything touches the model directory.
        mx_context = self._get_context(context)
        optimizer_params = self._prepare_optimizer_params(optimizer_params)
        loss_function = self._build_loss_function(loss, loss_params)

        train_iter, test_iter, data_mean, data_std, train_images, test_images = \
            self._data_loader.load_data(batch_size)

        if normalize:
            self._net_creator.construct(context=mx_context, data_mean=data_mean, data_std=data_std)
        else:
            self._net_creator.construct(context=mx_context)

        begin_epoch = 0
        if load_checkpoint:
            begin_epoch = self._net_creator.load(mx_context)
        elif os.path.isdir(self._net_creator._model_dir_):
            shutil.rmtree(self._net_creator._model_dir_)

        self._networks = self._net_creator.networks

        try:
            os.makedirs(self._net_creator._model_dir_)
        except OSError:
            # An already existing directory is fine; anything else is not.
            if not os.path.isdir(self._net_creator._model_dir_):
                raise

        trainers = [mx.gluon.Trainer(network.collect_params(), optimizer, optimizer_params)
                    for network in self._networks.values()
                    if len(network.collect_params().values()) != 0]

        for epoch in range(begin_epoch, begin_epoch + num_epoch):
            self._train_epoch(epoch, train_iter, trainers, loss_function,
                              mx_context, batch_size, log_period)

            if eval_train:
                train_metric_score = self._evaluate(
                    train_iter, mx_context, eval_metric, eval_metric_params,
                    train_images, batch_size, save_attention_image,
                    'attention_train.png', titles_from_labels=True)
            else:
                train_metric_score = 0

            test_metric_score = self._evaluate(
                test_iter, mx_context, eval_metric, eval_metric_params,
                test_images, batch_size, save_attention_image,
                'attention_test.png', titles_from_labels=False)

            logging.info("Epoch[%d] Train: %f, Test: %f" % (epoch, train_metric_score, test_metric_score))

            if (epoch - begin_epoch) % checkpoint_period == 0:
                for i, network in self._networks.items():
                    network.save_parameters(self.parameter_path(i) + '-' + str(epoch).zfill(4) + '.params')

        for i, network in self._networks.items():
            network.save_parameters(self.parameter_path(i) + '-' + str(num_epoch + begin_epoch).zfill(4) + '.params')
            network.export(self.parameter_path(i) + '_newest', epoch=0)

    @staticmethod
    def _get_context(context):
        """Map ``'gpu'``/``'cpu'`` to an MXNet context; raise ValueError otherwise."""
        if context == 'gpu':
            return mx.gpu()
        if context == 'cpu':
            return mx.cpu()
        message = "Context argument is '%s'. Only 'cpu' and 'gpu' are valid arguments." % (context,)
        logging.error(message)
        raise ValueError(message)

    @staticmethod
    def _prepare_optimizer_params(optimizer_params):
        """Return a new dict of optimizer parameters in MXNet's vocabulary.

        ``weight_decay`` becomes ``wd``. ``learning_rate_decay`` and
        ``step_size`` (with the optional ``learning_rate_minimum``) become a
        ``FactorScheduler``. The input is copied, never mutated.
        """
        params = dict(optimizer_params or {})
        if 'weight_decay' in params:
            params['wd'] = params.pop('weight_decay')
        # Always removed: it is not an optimizer argument on its own.
        min_learning_rate = params.pop('learning_rate_minimum', 1e-08)
        if 'learning_rate_decay' in params:
            params['lr_scheduler'] = mx.lr_scheduler.FactorScheduler(
                params.pop('step_size'),
                factor=params.pop('learning_rate_decay'),
                stop_factor_lr=min_learning_rate)
        return params

    @staticmethod
    def _build_loss_function(loss, loss_params):
        """Instantiate the loss named ``loss``; raise ValueError for unknown names."""
        margin = loss_params.get('margin', 1.0)
        sparse_label = loss_params.get('sparse_label', True)
        ignore_indices = loss_params.get('ignore_indices')
        if ignore_indices is None:
            ignore_indices = []
        elif isinstance(ignore_indices, (list, tuple)):
            ignore_indices = list(ignore_indices)
        else:
            ignore_indices = [ignore_indices]

        if loss == 'softmax_cross_entropy':
            return mx.gluon.loss.SoftmaxCrossEntropyLoss(
                from_logits=loss_params.get('from_logits', False), sparse_label=sparse_label)
        if loss == 'softmax_cross_entropy_ignore_indices':
            return SoftmaxCrossEntropyLossIgnoreIndices(
                ignore_indices=ignore_indices,
                from_logits=loss_params.get('from_logits', False),
                sparse_label=sparse_label)
        if loss == 'sigmoid_binary_cross_entropy':
            return mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
        if loss == 'cross_entropy':
            return CrossEntropyLoss(sparse_label=sparse_label)
        if loss == 'l2':
            return mx.gluon.loss.L2Loss()
        if loss == 'l1':
            return mx.gluon.loss.L1Loss()
        if loss == 'huber':
            return mx.gluon.loss.HuberLoss(rho=loss_params.get('rho', 1))
        if loss == 'hinge':
            return mx.gluon.loss.HingeLoss(margin=margin)
        if loss == 'squared_hinge':
            return mx.gluon.loss.SquaredHingeLoss(margin=margin)
        if loss == 'logistic':
            return mx.gluon.loss.LogisticLoss(label_format=loss_params.get('label_format', 'signed'))
        if loss == 'kullback_leibler':
            return mx.gluon.loss.KLDivLoss(from_logits=loss_params.get('from_logits', True))
        if loss == 'log_cosh':
            return LogCoshLoss()
        logging.error("Invalid loss parameter.")
        raise ValueError("Invalid loss parameter: '%s'." % (loss,))

    def _train_epoch(self, epoch, train_iter, trainers, loss_function, mx_context, batch_size, log_period):
        """Run one optimisation pass over ``train_iter``, logging periodically."""
        tic = None
        loss_total = 0
        window_batches = 0  # batches accumulated in loss_total since last log

        train_iter.reset()
        for batch_i, batch in enumerate(train_iter):
            with autograd.record():
                labels = [batch.label[i].as_in_context(mx_context) for i in range(1)]
                data_ = batch.data[0].as_in_context(mx_context)
                nd.waitall()
                predictions_ = self._networks[0](data_)
                batch_loss = loss_function(predictions_, labels[0])

            batch_loss.backward()
            loss_total += batch_loss.sum().asscalar()
            window_batches += 1

            for trainer in trainers:
                trainer.step(batch_size)

            if tic is None:
                tic = time.time()
            elif batch_i % log_period == 0:
                try:
                    speed = log_period * batch_size / (time.time() - tic)
                except ZeroDivisionError:
                    speed = float("inf")

                loss_avg = loss_total / (batch_size * window_batches)
                loss_total = 0
                window_batches = 0

                logging.info("Epoch[%d] Batch[%d] Speed: %.2f samples/sec Loss: %.5f" % (epoch, batch_i, speed, loss_avg))

                tic = time.time()

    def _evaluate(self, data_iter, mx_context, eval_metric, eval_metric_params,
                  images, batch_size, save_attention_image, image_name, titles_from_labels):
        """Return the metric value of the network over ``data_iter``.

        If attention images are enabled, plot titles come from the labels
        (``titles_from_labels=True``) or from the network outputs.
        """
        data_iter.reset()
        metric = mx.metric.create(eval_metric, **eval_metric_params)
        for batch_i, batch in enumerate(data_iter):
            labels = [batch.label[i].as_in_context(mx_context) for i in range(1)]
            data_ = batch.data[0].as_in_context(mx_context)
            nd.waitall()

            outputs = [self._networks[0](data_)]
            # This network produces no attention maps.
            attention_list = []

            if save_attention_image == "True":
                tokens = labels if titles_from_labels else outputs
                self._save_attention_image(images, batch_size * batch_i, tokens,
                                           attention_list, image_name)

            metric.update(preds=self._to_predictions(outputs), labels=labels)
        return metric.get()[1]

    @staticmethod
    def _to_predictions(outputs):
        """Convert network outputs into values the metric can compare to labels.

        Outputs that still have more than one dimension after squeezing are
        reduced with argmax over axis 1. Other outputs are passed through,
        because ArgMax was already applied.
        """
        predictions = []
        for output in outputs:
            if mx.nd.shape_array(mx.nd.squeeze(output)).size > 1:
                predictions.append(mx.nd.argmax(output, axis=1))
            else:
                predictions.append(output)
        return predictions

    @staticmethod
    def _save_attention_image(images, image_index, tokens, attention_list, file_name):
        """Overlay attention maps on ``images[image_index]`` and save the figure.

        ``tokens[l + 1]`` provides the title of the l-th attention map, looked
        up in a pickled token dictionary. Plotting is skipped with a warning
        when there are no attention maps or the dictionary file is missing.
        """
        max_length = len(tokens) - 1
        if max_length <= 0 or len(attention_list) < max_length:
            logging.warning("No attention maps available, skipping attention image '%s'.", file_name)
            return

        import matplotlib
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt
        logging.getLogger('matplotlib').setLevel(logging.ERROR)

        dict_path = 'src/test/resources/training_data/Show_attend_tell/dict.pkl'
        if not os.path.isfile(dict_path):
            logging.warning("Token dictionary '%s' not found, skipping attention image.", dict_path)
            return
        # NOTE(review): pickle.load can execute arbitrary code; only trusted files.
        with open(dict_path, 'rb') as f:
            vocabulary = pickle.load(f)

        image = images[image_index].transpose(1, 2, 0)
        # Keep the original (max_length//3 x max_length//4) layout when it has
        # room for all max_length + 1 panels; otherwise add rows as needed.
        ncols = max(1, max_length // 4)
        nrows = max(max_length // 3, -(-(max_length + 1) // ncols))

        plt.clf()
        fig = plt.figure(figsize=(15, 15))
        ax = fig.add_subplot(nrows, ncols, 1)
        ax.imshow(image)

        for l in range(max_length):
            attention = mx.nd.slice_axis(attention_list[l], axis=0, begin=0, end=1).squeeze()
            attention_resized = np.resize(attention.asnumpy(), (8, 8))
            ax = fig.add_subplot(nrows, ncols, l + 2)
            token = int(mx.nd.slice_axis(tokens[l + 1], axis=0, begin=0, end=1).squeeze().asscalar())
            reached_end = False
            if token > len(vocabulary):
                ax.set_title("<unk>")
            elif vocabulary[token] == "<end>":
                ax.set_title(".")
                reached_end = True
            else:
                ax.set_title(vocabulary[token])
            img = ax.imshow(image)
            ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
            if reached_end:
                break

        plt.tight_layout()
        target_dir = 'target/attention_images'
        if not os.path.exists(target_dir):
            os.makedirs(target_dir)
        plt.savefig(target_dir + '/' + file_name)
        plt.close()

    def parameter_path(self, index):
        """Return the checkpoint path prefix for network ``index``."""
        return self._net_creator._model_dir_ + self._net_creator._model_prefix_ + '_' + str(index)