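"""Supervised trainer for the mnist_mnistClassifier_net model (MXNet Gluon).

Defines custom Gluon loss functions, evaluation metrics (accuracy with an
ignored label, corpus-level BLEU) and the CNNSupervisedTrainer class that
drives the training loop, evaluation and checkpointing. A usage sketch is
given at the end of the file.
"""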
import mxnet as mx
import logging
import numpy as np
import time
import os
import shutil
import pickle
import math
import sys
from mxnet import gluon, autograd, nd

class CrossEntropyLoss(gluon.loss.Loss):
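    """Cross-entropy loss for predictions that are already probabilities.

    Unlike mx.gluon.loss.SoftmaxCrossEntropyLoss, no softmax is applied here;
    `pred` is only passed through log before the label is picked.

    Illustrative usage (shapes assumed):
        pred = mx.nd.softmax(mx.nd.random.normal(shape=(4, 10)))
        label = mx.nd.array([0, 1, 2, 3])
        sample_losses = CrossEntropyLoss()(pred, label)  # shape (4,)
    """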
    def __init__(self, axis=-1, sparse_label=True, weight=None, batch_axis=0, **kwargs):
        super(CrossEntropyLoss, self).__init__(weight, batch_axis, **kwargs)
        self._axis = axis
        self._sparse_label = sparse_label

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        pred = F.log(pred)
        if self._sparse_label:
            loss = -F.pick(pred, label, axis=self._axis, keepdims=True)
        else:
            label = gluon.loss._reshape_like(F, label, pred)
            loss = -F.sum(pred * label, axis=self._axis, keepdims=True)
        loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
        return F.mean(loss, axis=self._batch_axis, exclude=True)

class LogCoshLoss(gluon.loss.Loss):
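    """Log-cosh regression loss: mean of log(cosh(pred - label)).

    Approximately quadratic for small residuals and linear for large ones,
    i.e. a smooth, outlier-robust alternative to plain L2.
    """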
    def __init__(self, weight=None, batch_axis=0, **kwargs):
        super(LogCoshLoss, self).__init__(weight, batch_axis, **kwargs)

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        loss = F.log(F.cosh(pred - label))
        loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
        return F.mean(loss, axis=self._batch_axis, exclude=True)

class SoftmaxCrossEntropyLossIgnoreIndices(gluon.loss.Loss):
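    """Softmax cross-entropy that can mask out certain class indices.

    Positions that are correctly predicted as one of `ignore_indices`
    (e.g. <pad> tokens in NLP applications) contribute zero loss.
    """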
    def __init__(self, axis=-1, ignore_indices=[], sparse_label=True, from_logits=False, weight=None, batch_axis=0, **kwargs):
        super(SoftmaxCrossEntropyLossIgnoreIndices, self).__init__(weight, batch_axis, **kwargs)
        self._axis = axis
        self._ignore_indices = ignore_indices
        self._sparse_label = sparse_label
        self._from_logits = from_logits

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        log_softmax = F.log_softmax
        pick = F.pick
        if not self._from_logits:
            pred = log_softmax(pred, self._axis)
        if self._sparse_label:
            loss = -pick(pred, label, axis=self._axis, keepdims=True)
        else:
            label = gluon.loss._reshape_like(F, label, pred)
            loss = -(pred * label).sum(axis=self._axis, keepdims=True)
        # ignore some indices for loss, e.g. <pad> tokens in NLP applications
        # (assumes sparse labels)
        for i in self._ignore_indices:
            # mask out positions that are correctly predicted as an ignored index;
            # F ops (rather than mx.nd) keep hybrid_forward hybridization-safe, and
            # the mask is expanded so it broadcasts against the keepdims loss
            pred_indices = F.argmax(pred, axis=self._axis)
            mask = F.logical_not(F.equal(pred_indices, i) * F.equal(pred_indices, label))
            loss = loss * F.expand_dims(mask, axis=self._axis)
        return loss.mean(axis=self._batch_axis, exclude=True)

class DiceLoss(gluon.loss.Loss):
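    """Softmax cross-entropy combined with a negative-log Dice score.

    dice_loss() compares the argmax prediction with the label via a smoothed
    Dice coefficient; its -log is added to the cross-entropy term.
    """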
    def __init__(self, axis=-1, sparse_label=True, from_logits=False, weight=None,
                 batch_axis=0, **kwargs):
        super(DiceLoss, self).__init__(weight, batch_axis, **kwargs)
        self._axis = axis
        self._sparse_label = sparse_label
        self._from_logits = from_logits

    def dice_loss(self, F, pred, label):
        smooth = 1.
        pred_y = F.argmax(pred, axis=self._axis)
        intersection = pred_y * label
        score = (2 * F.mean(intersection, axis=self._batch_axis, exclude=True) + smooth) \
            / (F.mean(label, axis=self._batch_axis, exclude=True) + F.mean(pred_y, axis=self._batch_axis, exclude=True) + smooth)

        return -F.log(score)

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        if not self._from_logits:
            pred = F.log_softmax(pred, self._axis)
        if self._sparse_label:
            loss = -F.pick(pred, label, axis=self._axis, keepdims=True)
        else:
            label = gluon.loss._reshape_like(F, label, pred)
            loss = -F.sum(pred*label, axis=self._axis, keepdims=True)
        loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
        diceloss = self.dice_loss(F, pred, label)
        return F.mean(loss, axis=self._batch_axis, exclude=True) + diceloss

class SoftmaxCrossEntropyLossIgnoreLabel(gluon.loss.Loss):
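    """Softmax cross-entropy that skips positions labelled `ignore_label`.

    The summed loss is normalised by the number of valid positions, so
    ignored positions (default label 255) affect neither loss nor gradient.
    """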
    def __init__(self, axis=-1, from_logits=False, weight=None,
                 batch_axis=0, ignore_label=255, **kwargs):
        super(SoftmaxCrossEntropyLossIgnoreLabel, self).__init__(weight, batch_axis, **kwargs)
        self._axis = axis
        self._from_logits = from_logits
        self._ignore_label = ignore_label

    def hybrid_forward(self, F, output, label, sample_weight=None):
        if not self._from_logits:
            output = F.log_softmax(output, axis=self._axis)

        valid_label_map = (label != self._ignore_label)
        loss = -(F.pick(output, label, axis=self._axis, keepdims=True) * valid_label_map )

        loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
        return F.sum(loss) / F.sum(valid_label_map)

@mx.metric.register
class ACCURACY_IGNORE_LABEL(mx.metric.EvalMetric):
    """Ignores a label when computing accuracy.
    """
    def __init__(self, axis=1, metric_ignore_label=255, name='accuracy',
                 output_names=None, label_names=None):
        super(ACCURACY_IGNORE_LABEL, self).__init__(
            name, axis=axis,
            output_names=output_names, label_names=label_names)
        self.axis = axis
        self.ignore_label = metric_ignore_label

    def update(self, labels, preds):
        mx.metric.check_label_shapes(labels, preds)

        for label, pred_label in zip(labels, preds):
            if pred_label.shape != label.shape:
                pred_label = mx.nd.argmax(pred_label, axis=self.axis, keepdims=True)
            label = label.astype('int32')
            pred_label = pred_label.astype('int32').as_in_context(label.context)

            mx.metric.check_label_shapes(label, pred_label)

            correct = mx.nd.sum( (label == pred_label) * (label != self.ignore_label) ).asscalar()
            total = mx.nd.sum( (label != self.ignore_label) ).asscalar()

            self.sum_metric += correct
            self.num_inst += total

@mx.metric.register
class BLEU(mx.metric.EvalMetric):
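    """Corpus-level BLEU score over up to 4-gram precisions.

    Batched NDArrays are converted to per-sample token sequences; tokens in
    `exclude` are dropped before n-gram counting, and a brevity penalty is
    applied in get()/calculate().
    """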
    N = 4

    def __init__(self, exclude=None, name='bleu', output_names=None, label_names=None):
        super(BLEU, self).__init__(name=name, output_names=output_names, label_names=label_names)

        self._exclude = exclude or []

        self._match_counts = [0 for _ in range(self.N)]
        self._counts = [0 for _ in range(self.N)]

        self._size_ref = 0
        self._size_hyp = 0

    def update(self, labels, preds):
        labels, preds = mx.metric.check_label_shapes(labels, preds, True)

        new_labels = self._convert(labels)
        new_preds = self._convert(preds)

        for label, pred in zip(new_labels, new_preds):
            reference = [word for word in label if word not in self._exclude]
            hypothesis = [word for word in pred if word not in self._exclude]

            self._size_ref += len(reference)
            self._size_hyp += len(hypothesis)

            for n in range(self.N):
                reference_ngrams = self._get_ngrams(reference, n + 1)
                hypothesis_ngrams = self._get_ngrams(hypothesis, n + 1)

                match_count = 0

                for ngram in hypothesis_ngrams:
                    if ngram in reference_ngrams:
                        reference_ngrams.remove(ngram)

                        match_count += 1

                self._match_counts[n] += match_count
                self._counts[n] += len(hypothesis_ngrams)

    def get(self):
        return (self.name, self.calculate())

    def calculate(self):
        precisions = [sys.float_info.min for n in range(self.N)]

        i = 1

        for n in range(self.N):
            match_counts = self._match_counts[n]
            counts = self._counts[n]

            if counts != 0:
                if match_counts == 0:
                    i *= 2
                    match_counts = 1 / i

                precisions[n] = match_counts / counts

        return self._get_brevity_penalty() * math.exp(sum(map(math.log, precisions)) / self.N)

    def _get_brevity_penalty(self):
        if self._size_hyp >= self._size_ref:
            return 1
        else:
            if self._size_hyp > 0:
                size_hyp = self._size_hyp
            else:
                size_hyp = 1

            return math.exp(1 - (self._size_ref / size_hyp))

    @staticmethod
    def _get_ngrams(sentence, n):
        ngrams = []

        if len(sentence) >= n:
            for i in range(len(sentence) - n + 1):
                ngrams.append(sentence[i:i+n])

        return ngrams

    @staticmethod
    def _convert(nd_list):
        if len(nd_list) == 0:
            return []

        new_list = [[] for _ in range(nd_list[0].shape[0])]

        for element in nd_list:
            for i in range(element.shape[0]):
                new_list[i].append(element[i].asscalar())

        return new_list



class CNNSupervisedTrainer_mnist_mnistClassifier_net:
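    """Supervised training driver for mnist_mnistClassifier_net.

    Wires a data loader and a network constructor into a Gluon training loop
    with configurable loss, optimizer, metrics and checkpointing; see train()
    and the usage sketch at the end of the file.
    """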
    def __init__(self, data_loader, net_constructor):
        self._data_loader = data_loader
        self._net_creator = net_constructor
        self._networks = {}

    def train(self, batch_size=64,
              num_epoch=10,
              eval_metric='acc',
              eval_metric_params={},
              eval_train=False,
              loss='softmax_cross_entropy',
              loss_params={},
              optimizer='adam',
              optimizer_params={'learning_rate': 0.001},
              load_checkpoint=True,
              checkpoint_period=5,
              load_pretrained=False,
              log_period=50,
              context='gpu',
              save_attention_image=False,
              use_teacher_forcing=False,
              normalize=True,
              shuffle_data=False,
              clip_global_grad_norm=None,
              preprocessing=False):
        if context == 'gpu':
            mx_context = mx.gpu()
        elif context == 'cpu':
            mx_context = mx.cpu()
        else:
            logging.error("Context argument is '" + context + "'. Only 'cpu' and 'gpu' are valid arguments.")
            # abort here, otherwise mx_context would be used uninitialised below
            sys.exit(1)

        if preprocessing:
            preproc_lib = "CNNPreprocessor_mnist_mnistClassifier_net_executor"
            train_iter, test_iter, data_mean, data_std, train_images, test_images = self._data_loader.load_preprocessed_data(batch_size, preproc_lib, shuffle_data)
        else:
            train_iter, test_iter, data_mean, data_std, train_images, test_images = self._data_loader.load_data(batch_size, shuffle_data)

        if 'weight_decay' in optimizer_params:
            optimizer_params['wd'] = optimizer_params['weight_decay']
            del optimizer_params['weight_decay']
        if 'learning_rate_decay' in optimizer_params:
            min_learning_rate = 1e-08
            if 'learning_rate_minimum' in optimizer_params:
                min_learning_rate = optimizer_params['learning_rate_minimum']
                del optimizer_params['learning_rate_minimum']
            optimizer_params['lr_scheduler'] = mx.lr_scheduler.FactorScheduler(
                                                   optimizer_params['step_size'],
                                                   factor=optimizer_params['learning_rate_decay'],
                                                   stop_factor_lr=min_learning_rate)
            del optimizer_params['step_size']
            del optimizer_params['learning_rate_decay']
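
        # Example: optimizer_params={'learning_rate': 0.001, 'weight_decay': 0.01,
        # 'learning_rate_decay': 0.9, 'step_size': 1000} gives the optimizer
        # wd=0.01 plus a FactorScheduler that multiplies the learning rate by
        # 0.9 every 1000 updates, bounded below by 'learning_rate_minimum'
        # (default 1e-08).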

        if normalize:
            self._net_creator.construct(context=mx_context, data_mean=data_mean, data_std=data_std)
        else:
            self._net_creator.construct(context=mx_context)

        begin_epoch = 0
        if load_checkpoint:
            begin_epoch = self._net_creator.load(mx_context)
        elif load_pretrained:
            self._net_creator.load_pretrained_weights(mx_context)
        else:
            if os.path.isdir(self._net_creator._model_dir_):
                shutil.rmtree(self._net_creator._model_dir_)

        self._networks = self._net_creator.networks

        try:
            os.makedirs(self._net_creator._model_dir_)
        except OSError:
            if not os.path.isdir(self._net_creator._model_dir_):
                raise

        trainers = [mx.gluon.Trainer(network.collect_params(), optimizer, optimizer_params) for network in self._networks.values() if len(network.collect_params().values()) != 0]

        margin = loss_params['margin'] if 'margin' in loss_params else 1.0
        sparseLabel = loss_params['sparse_label'] if 'sparse_label' in loss_params else True
        ignore_indices = [loss_params['ignore_indices']] if 'ignore_indices' in loss_params else []
        loss_axis = loss_params['loss_axis'] if 'loss_axis' in loss_params else -1
        batch_axis = loss_params['batch_axis'] if 'batch_axis' in loss_params else 0
        if loss == 'softmax_cross_entropy':
            fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
            loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(axis=loss_axis, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis)
        elif loss == 'softmax_cross_entropy_ignore_indices':
            fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
            loss_function = SoftmaxCrossEntropyLossIgnoreIndices(axis=loss_axis, ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis)
        elif loss == 'sigmoid_binary_cross_entropy':
            loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
        elif loss == 'cross_entropy':
            loss_function = CrossEntropyLoss(axis=loss_axis, sparse_label=sparseLabel, batch_axis=batch_axis)
        elif loss == 'dice_loss':
            loss_weight = loss_params['loss_weight'] if 'loss_weight' in loss_params else None
            loss_function = DiceLoss(axis=loss_axis, weight=loss_weight, sparse_label=sparseLabel, batch_axis=batch_axis)
        elif loss == 'softmax_cross_entropy_ignore_label':
            loss_weight = loss_params['loss_weight'] if 'loss_weight' in loss_params else None
            loss_ignore_label = loss_params['loss_ignore_label'] if 'loss_ignore_label' in loss_params else None
            loss_function = SoftmaxCrossEntropyLossIgnoreLabel(axis=loss_axis, ignore_label=loss_ignore_label, weight=loss_weight, batch_axis=batch_axis)
        elif loss == 'l2':
            loss_function = mx.gluon.loss.L2Loss()
        elif loss == 'l1':
            loss_function = mx.gluon.loss.L1Loss()
        elif loss == 'huber':
            rho = loss_params['rho'] if 'rho' in loss_params else 1
            loss_function = mx.gluon.loss.HuberLoss(rho=rho)
        elif loss == 'hinge':
            loss_function = mx.gluon.loss.HingeLoss(margin=margin)
        elif loss == 'squared_hinge':
            loss_function = mx.gluon.loss.SquaredHingeLoss(margin=margin)
        elif loss == 'logistic':
            labelFormat = loss_params['label_format'] if 'label_format' in loss_params else 'signed'
            loss_function = mx.gluon.loss.LogisticLoss(label_format=labelFormat)
        elif loss == 'kullback_leibler':
            fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else True
            loss_function = mx.gluon.loss.KLDivLoss(from_logits=fromLogits)
        elif loss == 'log_cosh':
            loss_function = LogCoshLoss()
        else:
            logging.error("Invalid loss parameter.")

        tic = None

        for epoch in range(begin_epoch, begin_epoch + num_epoch):
            if shuffle_data:
                if preprocessing:
                    preproc_lib = "CNNPreprocessor_mnist_mnistClassifier_net_executor"
                    train_iter, test_iter, data_mean, data_std, train_images, test_images = self._data_loader.load_preprocessed_data(batch_size, preproc_lib, shuffle_data)
                else:
                    train_iter, test_iter, data_mean, data_std, train_images, test_images = self._data_loader.load_data(batch_size, shuffle_data)

            global_loss_train = 0.0
            train_batches = 0

            loss_total = 0
            train_iter.reset()
            for batch_i, batch in enumerate(train_iter):
                with autograd.record():
                    labels = [batch.label[i].as_in_context(mx_context) for i in range(1)]

                    image_ = batch.data[0].as_in_context(mx_context)

                    predictions_ = mx.nd.zeros((batch_size, 10,), ctx=mx_context)


                    nd.waitall()

                    lossList = []

                    predictions_ = self._networks[0](image_)

                    lossList.append(loss_function(predictions_, labels[0]))

                    loss = 0
                    for element in lossList:
                        loss = loss + element

                loss.backward()

                loss_total += loss.sum().asscalar()

                global_loss_train += loss.sum().asscalar()
                train_batches += 1

                if clip_global_grad_norm:
                    grads = []

                    for network in self._networks.values():
                        grads.extend([param.grad(mx_context) for param in network.collect_params().values()])

                    gluon.utils.clip_global_norm(grads, clip_global_grad_norm)

                for trainer in trainers:
                    trainer.step(batch_size)

                if tic is None:
                    tic = time.time()
                else:
                    if batch_i % log_period == 0:
                        try:
                            speed = log_period * batch_size / (time.time() - tic)
                        except ZeroDivisionError:
                            speed = float("inf")

                        loss_avg = loss_total / (batch_size * log_period)
                        loss_total = 0

                        logging.info("Epoch[%d] Batch[%d] Speed: %.2f samples/sec Loss: %.5f" % (epoch, batch_i, speed, loss_avg))

                        tic = time.time()

            global_loss_train /= (train_batches * batch_size)

            tic = None


            if eval_train:
                train_iter.reset()
                metric = mx.metric.create(eval_metric, **eval_metric_params)
                for batch_i, batch in enumerate(train_iter):
                    labels = [batch.label[i].as_in_context(mx_context) for i in range(1)]

                    image_ = batch.data[0].as_in_context(mx_context)

                    predictions_ = mx.nd.zeros((batch_size, 10,), ctx=mx_context)


                    nd.waitall()

                    outputs = []
                    lossList = []
                    attentionList = []
                    predictions_ = self._networks[0](image_)

                    outputs.append(predictions_)
                    lossList.append(loss_function(predictions_, labels[0]))


                    # attention visualisation; only meaningful for networks that
                    # actually populate attentionList
                    if save_attention_image:
                        import matplotlib
                        matplotlib.use('Agg')
                        import matplotlib.pyplot as plt
                        logging.getLogger('matplotlib').setLevel(logging.ERROR)

                        if os.path.isfile('src/test/resources/training_data/Show_attend_tell/dict.pkl'):
                            with open('src/test/resources/training_data/Show_attend_tell/dict.pkl', 'rb') as f:
                                dict = pickle.load(f)

                        plt.clf()
                        fig = plt.figure(figsize=(15,15))
                        max_length = len(labels)-1

                        ax = fig.add_subplot(max_length//3, max_length//4, 1)
                        ax.imshow(train_images[0+batch_size*(batch_i)].transpose(1,2,0))

                        for l in range(max_length):
                            attention = attentionList[l]
                            attention = mx.nd.slice_axis(attention, axis=0, begin=0, end=1).squeeze()
                            attention_resized = np.resize(attention.asnumpy(), (8, 8))
                            ax = fig.add_subplot(max_length//3, max_length//4, l+2)
                            if int(labels[l+1][0].asscalar()) > len(dict):
                                ax.set_title("<unk>")
                            elif dict[int(labels[l+1][0].asscalar())] == "<end>":
                                ax.set_title(".")
                                img = ax.imshow(train_images[0+batch_size*(batch_i)].transpose(1,2,0))
                                ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
                                break
                            else:
                                ax.set_title(dict[int(labels[l+1][0].asscalar())])
                            img = ax.imshow(train_images[0+batch_size*(batch_i)].transpose(1,2,0))
                            ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())

                        plt.tight_layout()
                        target_dir = 'target/attention_images'
                        if not os.path.exists(target_dir):
                            os.makedirs(target_dir)
                        plt.savefig(target_dir + '/attention_train.png')
                        plt.close()

                    predictions = []
                    for output_name in outputs:
                        if mx.nd.shape_array(mx.nd.squeeze(output_name)).size > 1:
                            predictions.append(mx.nd.argmax(output_name, axis=1))
                        else:
                            predictions.append(output_name)

                    metric.update(preds=predictions, labels=labels)
                train_metric_score = metric.get()[1]
            else:
                train_metric_score = 0

            global_loss_test = 0.0
            test_batches = 0

            test_iter.reset()
            metric = mx.metric.create(eval_metric, **eval_metric_params)
            for batch_i, batch in enumerate(test_iter):
                labels = [batch.label[i].as_in_context(mx_context) for i in range(1)]

                image_ = batch.data[0].as_in_context(mx_context)

                predictions_ = mx.nd.zeros((batch_size, 10,), ctx=mx_context)

                nd.waitall()

                outputs = []
                lossList = []
                attentionList = []
                predictions_ = self._networks[0](image_)

                outputs.append(predictions_)
                lossList.append(loss_function(predictions_, labels[0]))

                # attention visualisation; only meaningful for networks that
                # actually populate attentionList
                if save_attention_image:
                    if not eval_train:
                        import matplotlib
                        matplotlib.use('Agg')
                        import matplotlib.pyplot as plt
                        logging.getLogger('matplotlib').setLevel(logging.ERROR)

                        if os.path.isfile('src/test/resources/training_data/Show_attend_tell/dict.pkl'):
                            with open('src/test/resources/training_data/Show_attend_tell/dict.pkl', 'rb') as f:
                                dict = pickle.load(f)

                    plt.clf()
                    fig = plt.figure(figsize=(15,15))
                    max_length = len(labels)-1

                    ax = fig.add_subplot(max_length//3, max_length//4, 1)
                    ax.imshow(test_images[0+batch_size*(batch_i)].transpose(1,2,0))

                    for l in range(max_length):
                        attention = attentionList[l]
                        attention = mx.nd.slice_axis(attention, axis=0, begin=0, end=1).squeeze()
                        attention_resized = np.resize(attention.asnumpy(), (8, 8))
                        ax = fig.add_subplot(max_length//3, max_length//4, l+2)
                        if int(mx.nd.slice_axis(outputs[l+1], axis=0, begin=0, end=1).squeeze().asscalar()) > len(dict):
                            ax.set_title("<unk>")
                        elif dict[int(mx.nd.slice_axis(outputs[l+1], axis=0, begin=0, end=1).squeeze().asscalar())] == "<end>":
                            ax.set_title(".")
                            img = ax.imshow(test_images[0+batch_size*(batch_i)].transpose(1,2,0))
                            ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
                            break
                        else:
                            ax.set_title(dict[int(mx.nd.slice_axis(outputs[l+1], axis=0, begin=0, end=1).squeeze().asscalar())])
                        img = ax.imshow(test_images[0+batch_size*(batch_i)].transpose(1,2,0))
                        ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())

                    plt.tight_layout()
                    target_dir = 'target/attention_images'
                    if not os.path.exists(target_dir):
                        os.makedirs(target_dir)
                    plt.savefig(target_dir + '/attention_test.png')
                    plt.close()

                loss = 0
                for element in lossList:
                    loss = loss + element

                global_loss_test += loss.sum().asscalar()
                test_batches += 1

                predictions = []
                for output_name in outputs:
                    predictions.append(output_name)

                metric.update(preds=predictions, labels=labels)
            test_metric_score = metric.get()[1]

            global_loss_test /= (test_batches * batch_size)

            logging.info("Epoch[%d] Train metric: %f, Test metric: %f, Train loss: %f, Test loss: %f" % (epoch, train_metric_score, test_metric_score, global_loss_train, global_loss_test))

            if (epoch - begin_epoch) % checkpoint_period == 0:
                for i, network in self._networks.items():
                    network.save_parameters(self.parameter_path(i) + '-' + str(epoch).zfill(4) + '.params')

        for i, network in self._networks.items():
            network.save_parameters(self.parameter_path(i) + '-' + str(num_epoch + begin_epoch + 1).zfill(4) + '.params')
            network.export(self.parameter_path(i) + '_newest', epoch=0)

    def parameter_path(self, index):
        return self._net_creator._model_dir_ + self._net_creator._model_prefix_ + '_' + str(index)
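

# Hypothetical usage sketch: the loader/creator class names below are assumed
# to follow the generator's naming scheme and are not defined in this file.
#
#   data_loader = CNNDataLoader_mnist_mnistClassifier_net()
#   net_constructor = CNNCreator_mnist_mnistClassifier_net()
#   trainer = CNNSupervisedTrainer_mnist_mnistClassifier_net(data_loader, net_constructor)
#   trainer.train(batch_size=64, num_epoch=10, context='cpu')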