<#-- (c) https://github.com/MontiCore/monticore -->
import mxnet as mx
import logging
import numpy as np
import time
import os
import shutil
import pickle
import math
import sys
from mxnet import gluon, autograd, nd

class CrossEntropyLoss(gluon.loss.Loss):
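    # Cross-entropy for networks whose outputs are already probabilities:
    # unlike SoftmaxCrossEntropyLoss, only F.log is applied to `pred`
    # instead of log_softmax.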
    def __init__(self, axis=-1, sparse_label=True, weight=None, batch_axis=0, **kwargs):
        super(CrossEntropyLoss, self).__init__(weight, batch_axis, **kwargs)
        self._axis = axis
        self._sparse_label = sparse_label

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        pred = F.log(pred)
        if self._sparse_label:
            loss = -F.pick(pred, label, axis=self._axis, keepdims=True)
        else:
            label = gluon.loss._reshape_like(F, label, pred)
            loss = -F.sum(pred * label, axis=self._axis, keepdims=True)
        loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
        return F.mean(loss, axis=self._batch_axis, exclude=True)

class LogCoshLoss(gluon.loss.Loss):
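    # Log-cosh regression loss: behaves like (pred - label)^2 / 2 for small
    # residuals and like |pred - label| - log(2) for large ones, combining
    # L2-like smoothness with L1-like robustness to outliers.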
    def __init__(self, weight=None, batch_axis=0, **kwargs):
        super(LogCoshLoss, self).__init__(weight, batch_axis, **kwargs)

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        loss = F.log(F.cosh(pred - label))
        loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
        return F.mean(loss, axis=self._batch_axis, exclude=True)

class SoftmaxCrossEntropyLossIgnoreIndices(gluon.loss.Loss):
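    # Softmax cross-entropy that additionally zeroes the loss at positions
    # predicted as one of the given ignore_indices (e.g. <pad> tokens) where
    # the prediction matches the label.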
    def __init__(self, axis=-1, ignore_indices=[], sparse_label=True, from_logits=False, weight=None, batch_axis=0, **kwargs):
        super(SoftmaxCrossEntropyLossIgnoreIndices, self).__init__(weight, batch_axis, **kwargs)
        self._axis = axis
        self._ignore_indices = ignore_indices
        self._sparse_label = sparse_label
        self._from_logits = from_logits

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        log_softmax = F.log_softmax
        pick = F.pick
        if not self._from_logits:
            pred = log_softmax(pred, self._axis)
        if self._sparse_label:
            loss = -pick(pred, label, axis=self._axis, keepdims=True)
        else:
            label = gluon.loss._reshape_like(F, label, pred)
            loss = -(pred * label).sum(axis=self._axis, keepdims=True)
        # ignore some indices for loss, e.g. <pad> tokens in NLP applications
        for i in self._ignore_indices:
            pred_index = mx.nd.argmax(pred, axis=1)
            ignore_mask = mx.nd.equal(pred_index, mx.nd.ones_like(pred_index) * i) * mx.nd.equal(pred_index, label)
            loss = loss * mx.nd.logical_not(ignore_mask)
        return loss.mean(axis=self._batch_axis, exclude=True)

class DiceLoss(gluon.loss.Loss):
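    # Cross-entropy combined with a negative-log soft Dice term; the Dice
    # score compares the argmax predictions against the labels and is
    # smoothed by 1 to avoid division by zero.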
    def __init__(self, axis=-1, sparse_label=True, from_logits=False, weight=None,
                 batch_axis=0, **kwargs):
        super(DiceLoss, self).__init__(weight, batch_axis, **kwargs)
        self._axis = axis
        self._sparse_label = sparse_label
        self._from_logits = from_logits

    def dice_loss(self, F, pred, label):
        smooth = 1.
        pred_y = F.argmax(pred, axis=self._axis)
        intersection = pred_y * label
        score = (2 * F.mean(intersection, axis=self._batch_axis, exclude=True) + smooth) \
            / (F.mean(label, axis=self._batch_axis, exclude=True) + F.mean(pred_y, axis=self._batch_axis, exclude=True) + smooth)

        return -F.log(score)

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        if not self._from_logits:
            pred = F.log_softmax(pred, self._axis)
        if self._sparse_label:
            loss = -F.pick(pred, label, axis=self._axis, keepdims=True)
        else:
            label = gluon.loss._reshape_like(F, label, pred)
            loss = -F.sum(pred*label, axis=self._axis, keepdims=True)
        loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
        diceloss = self.dice_loss(F, pred, label)
        return F.mean(loss, axis=self._batch_axis, exclude=True) + diceloss

@mx.metric.register
class BLEU(mx.metric.EvalMetric):
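    # Corpus-level BLEU: update() accumulates modified n-gram matches for
    # orders 1..N together with hypothesis/reference lengths; get() combines
    # the geometric mean of the precisions with a brevity penalty.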
    N = 4

    def __init__(self, exclude=None, name='bleu', output_names=None, label_names=None):
        super(BLEU, self).__init__(name=name, output_names=output_names, label_names=label_names)

        self._exclude = exclude or []

        self._match_counts = [0 for _ in range(self.N)]
        self._counts = [0 for _ in range(self.N)]

        self._size_ref = 0
        self._size_hyp = 0

    def update(self, labels, preds):
        labels, preds = mx.metric.check_label_shapes(labels, preds, True)

        new_labels = self._convert(labels)
        new_preds = self._convert(preds)

        for label, pred in zip(new_labels, new_preds):
            reference = [word for word in label if word not in self._exclude]
            hypothesis = [word for word in pred if word not in self._exclude]

            self._size_ref += len(reference)
            self._size_hyp += len(hypothesis)

            for n in range(self.N):
                reference_ngrams = self._get_ngrams(reference, n + 1)
                hypothesis_ngrams = self._get_ngrams(hypothesis, n + 1)

                match_count = 0

                for ngram in hypothesis_ngrams:
                    if ngram in reference_ngrams:
                        reference_ngrams.remove(ngram)

                        match_count += 1

                self._match_counts[n] += match_count
                self._counts[n] += len(hypothesis_ngrams)

    def get(self):
        precisions = [sys.float_info.min for n in range(self.N)]

        i = 1

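        # orders with zero matches receive geometrically decaying pseudo-counts
        # (1/2, 1/4, ...) so the geometric mean does not collapse to zero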
        for n in range(self.N):
            match_counts = self._match_counts[n]
            counts = self._counts[n]

            if counts != 0:
                if match_counts == 0:
                    i *= 2
                    match_counts = 1 / i

                if (match_counts / counts) > 0:
                    precisions[n] = match_counts / counts

        bleu = self._get_brevity_penalty() * math.exp(sum(map(math.log, precisions)) / self.N)

        return (self.name, bleu)

    def calculate(self):
        precisions = [sys.float_info.min for n in range(self.N)]

        i = 1

        for n in range(self.N):
            match_counts = self._match_counts[n]
            counts = self._counts[n]

            if counts != 0:
                if match_counts == 0:
                    i *= 2
                    match_counts = 1 / i

                precisions[n] = match_counts / counts

        return self._get_brevity_penalty() * math.exp(sum(map(math.log, precisions)) / self.N)

    def _get_brevity_penalty(self):
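        # 1 if the hypothesis corpus is at least as long as the reference
        # corpus, otherwise exp(1 - size_ref / size_hyp) with size_hyp
        # clamped to at least 1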
        if self._size_hyp >= self._size_ref:
            return 1
        else:
            if self._size_hyp > 0:
                size_hyp = self._size_hyp
            else:
                size_hyp = 1

            return math.exp(1 - (self._size_ref / size_hyp))

    @staticmethod
    def _get_ngrams(sentence, n):
        ngrams = []

        if len(sentence) >= n:
            for i in range(len(sentence) - n + 1):
                ngrams.append(sentence[i:i+n])

        return ngrams

    @staticmethod
    def _convert(nd_list):
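        # transpose a list of per-step NDArrays with leading batch dimension
        # into one Python list of scalars per batch element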
        if len(nd_list) == 0:
            return []

        new_list = [[] for _ in range(nd_list[0].shape[0])]

        for element in nd_list:
            for i in range(element.shape[0]):
                new_list[i].append(element[i].asscalar())

        return new_list



class ${tc.fileNameWithoutEnding}:
    def __init__(self, data_loader, net_constructor):
        self._data_loader = data_loader
        self._net_creator = net_constructor
        self._networks = {}

    def train(self, batch_size=64,
              num_epoch=10,
              eval_metric='acc',
              eval_metric_params={},
              eval_train=False,
              loss='softmax_cross_entropy',
              loss_params={},
              optimizer='adam',
              optimizer_params=(('learning_rate', 0.001),),
              load_checkpoint=True,
              checkpoint_period=5,
              load_pretrained=False,
              log_period=50,
              context='gpu',
              save_attention_image=False,
              use_teacher_forcing=False,
              normalize=True,
              shuffle_data=False,
              clip_global_grad_norm=None,
              preprocessing=False):
        if context == 'gpu':
            mx_context = mx.gpu()
        elif context == 'cpu':
            mx_context = mx.cpu()
        else:
            logging.error("Context argument is '" + context + "'. Only 'cpu' and 'gpu are valid arguments'.")

        if preprocessing:
            preproc_lib = "CNNPreprocessor_${tc.fileNameWithoutEnding?keep_after("CNNSupervisedTrainer_")}_executor"
            train_iter, test_iter, data_mean, data_std, train_images, test_images = self._data_loader.load_preprocessed_data(batch_size, preproc_lib, shuffle_data)
        else:
            train_iter, test_iter, data_mean, data_std, train_images, test_images = self._data_loader.load_data(batch_size, shuffle_data)

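        # map the generator's hyperparameter names onto the argument names and
        # the FactorScheduler object that MXNet optimizers expect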
        if 'weight_decay' in optimizer_params:
            optimizer_params['wd'] = optimizer_params['weight_decay']
            del optimizer_params['weight_decay']
        if 'learning_rate_decay' in optimizer_params:
            min_learning_rate = 1e-08
            if 'learning_rate_minimum' in optimizer_params:
                min_learning_rate = optimizer_params['learning_rate_minimum']
                del optimizer_params['learning_rate_minimum']
            optimizer_params['lr_scheduler'] = mx.lr_scheduler.FactorScheduler(
                                                   optimizer_params['step_size'],
                                                   factor=optimizer_params['learning_rate_decay'],
                                                   stop_factor_lr=min_learning_rate)
            del optimizer_params['step_size']
            del optimizer_params['learning_rate_decay']

        if normalize:
            self._net_creator.construct(context=mx_context, data_mean=data_mean, data_std=data_std)
        else:
            self._net_creator.construct(context=mx_context)

        begin_epoch = 0
        if load_checkpoint:
            begin_epoch = self._net_creator.load(mx_context, load_pretrained=load_pretrained)
        else:
            if os.path.isdir(self._net_creator._model_dir_):
                shutil.rmtree(self._net_creator._model_dir_)

        self._networks = self._net_creator.networks

        try:
            os.makedirs(self._net_creator._model_dir_)
        except OSError:
            if not os.path.isdir(self._net_creator._model_dir_):
                raise

        trainers = [mx.gluon.Trainer(network.collect_params(), optimizer, optimizer_params)
                    for network in self._networks.values()
                    if len(network.collect_params().values()) != 0]

        margin = loss_params['margin'] if 'margin' in loss_params else 1.0
        sparseLabel = loss_params['sparse_label'] if 'sparse_label' in loss_params else True
        ignore_indices = [loss_params['ignore_indices']] if 'ignore_indices' in loss_params else []
        loss_axis = loss_params['loss_axis'] if 'loss_axis' in loss_params else -1
        batch_axis = loss_params['batch_axis'] if 'batch_axis' in loss_params else 0
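        # construct the requested loss function; the custom loss classes defined
        # above cover the cases without a direct Gluon counterpart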
        if loss == 'softmax_cross_entropy':
            fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
            loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(axis=loss_axis, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis)
        elif loss == 'softmax_cross_entropy_ignore_indices':
            fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
            loss_function = SoftmaxCrossEntropyLossIgnoreIndices(axis=loss_axis, ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis)
        elif loss == 'sigmoid_binary_cross_entropy':
            loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
        elif loss == 'cross_entropy':
            loss_function = CrossEntropyLoss(axis=loss_axis, sparse_label=sparseLabel, batch_axis=batch_axis)
        elif loss == 'dice_loss':
            loss_weight = loss_params['loss_weight'] if 'loss_weight' in loss_params else None
            loss_function = DiceLoss(axis=loss_axis, weight=loss_weight, sparse_label=sparseLabel, batch_axis=batch_axis)
        elif loss == 'l2':
            loss_function = mx.gluon.loss.L2Loss()
        elif loss == 'l1':
            loss_function = mx.gluon.loss.L1Loss()
        elif loss == 'huber':
            rho = loss_params['rho'] if 'rho' in loss_params else 1
            loss_function = mx.gluon.loss.HuberLoss(rho=rho)
        elif loss == 'hinge':
            loss_function = mx.gluon.loss.HingeLoss(margin=margin)
        elif loss == 'squared_hinge':
            loss_function = mx.gluon.loss.SquaredHingeLoss(margin=margin)
        elif loss == 'logistic':
            labelFormat = loss_params['label_format'] if 'label_format' in loss_params else 'signed'
            loss_function = mx.gluon.loss.LogisticLoss(label_format=labelFormat)
        elif loss == 'kullback_leibler':
            fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else True
            loss_function = mx.gluon.loss.KLDivLoss(from_logits=fromLogits)
        elif loss == 'log_cosh':
            loss_function = LogCoshLoss()
        else:
            logging.error("Invalid loss parameter.")

        tic = None

        for epoch in range(begin_epoch, begin_epoch + num_epoch):
            if shuffle_data:
                if preprocessing:
                    preproc_lib = "CNNPreprocessor_${tc.fileNameWithoutEnding?keep_after("CNNSupervisedTrainer_")}_executor"
                    train_iter, test_iter, data_mean, data_std, train_images, test_images = self._data_loader.load_preprocessed_data(batch_size, preproc_lib, shuffle_data)
                else:
                    train_iter, test_iter, data_mean, data_std, train_images, test_images = self._data_loader.load_data(batch_size, shuffle_data)

            global_loss_train = 0.0
            train_batches = 0

            loss_total = 0
            train_iter.reset()
            for batch_i, batch in enumerate(train_iter):
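                # run the forward pass under autograd so the summed loss can be
                # backpropagated below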
                with autograd.record():
<#include "pythonExecuteTrain.ftl">

                    loss = 0
                    for element in lossList:
                        loss = loss + element

                loss.backward()

                loss_total += loss.sum().asscalar()

                global_loss_train += loss.sum().asscalar()
                train_batches += 1

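                # optionally rescale all gradients so that their global L2 norm
                # does not exceed clip_global_grad_norm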
                if clip_global_grad_norm:
                    grads = []

                    for network in self._networks.values():
                        grads.extend([param.grad(mx_context) for param in network.collect_params().values()])

                    gluon.utils.clip_global_norm(grads, clip_global_grad_norm)

                for trainer in trainers:
                    trainer.step(batch_size)

                if tic is None:
                    tic = time.time()
                else:
                    if batch_i % log_period == 0:
                        try:
                            speed = log_period * batch_size / (time.time() - tic)
                        except ZeroDivisionError:
                            speed = float("inf")

                        loss_avg = loss_total / (batch_size * log_period)
                        loss_total = 0

                        logging.info("Epoch[%d] Batch[%d] Speed: %.2f samples/sec Loss: %.5f" % (epoch, batch_i, speed, loss_avg))

                        tic = time.time()

            global_loss_train /= (train_batches * batch_size)

            tic = None


            if eval_train:
                train_iter.reset()
                metric = mx.metric.create(eval_metric, **eval_metric_params)
                for batch_i, batch in enumerate(train_iter):
<#include "pythonExecuteTest.ftl">


<#include "saveAttentionImageTrain.ftl">


                    predictions = []
                    for output_name in outputs:
                        if mx.nd.shape_array(mx.nd.squeeze(output_name)).size > 1:
                            predictions.append(mx.nd.argmax(output_name, axis=1))
                        else:
                            predictions.append(output_name)

                    metric.update(preds=predictions, labels=labels)
                train_metric_score = metric.get()[1]
            else:
                train_metric_score = 0

            global_loss_test = 0.0
            test_batches = 0

            test_iter.reset()
            metric = mx.metric.create(eval_metric, **eval_metric_params)
            for batch_i, batch in enumerate(test_iter):
                if True: <#-- Fix indentation -->
<#include "pythonExecuteTest.ftl">


<#include "saveAttentionImageTest.ftl">

                loss = 0
                for element in lossList:
                    loss = loss + element

                global_loss_test += loss.sum().asscalar()
                test_batches += 1

                predictions = []
                for output_name in outputs:
                    predictions.append(output_name)

                metric.update(preds=predictions, labels=labels)
            test_metric_score = metric.get()[1]

            global_loss_test /= (test_batches * batch_size)

            logging.info("Epoch[%d] Train metric: %f, Test metric: %f, Train loss: %f, Test loss: %f" % (epoch, train_metric_score, test_metric_score, global_loss_train, global_loss_test))

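            # snapshot the parameters of every network each checkpoint_period
            # epochs so that training can be resumed via load_checkpoint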
            if (epoch - begin_epoch) % checkpoint_period == 0:
                for i, network in self._networks.items():
                    network.save_parameters(self.parameter_path(i) + '-' + str(epoch).zfill(4) + '.params')

        for i, network in self._networks.items():
            network.save_parameters(self.parameter_path(i) + '-' + str(num_epoch + begin_epoch + 1).zfill(4) + '.params')
            network.export(self.parameter_path(i) + '_newest', epoch=0)

    def parameter_path(self, index):
        return self._net_creator._model_dir_ + self._net_creator._model_prefix_ + '_' + str(index)