# CNNSupervisedTrainer_CifarClassifierNetwork.py
import mxnet as mx
import logging
import numpy as np
import time
import os
import shutil
import pickle
import math
import sys
from mxnet import gluon, autograd, nd

class CrossEntropyLoss(gluon.loss.Loss):
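    """Cross-entropy loss that expects `pred` to already be probabilities (e.g. softmax output), not raw logits."""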
    def __init__(self, axis=-1, sparse_label=True, weight=None, batch_axis=0, **kwargs):
        super(CrossEntropyLoss, self).__init__(weight, batch_axis, **kwargs)
        self._axis = axis
        self._sparse_label = sparse_label

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        pred = F.log(pred)
        if self._sparse_label:
            loss = -F.pick(pred, label, axis=self._axis, keepdims=True)
        else:
            label = gluon.loss._reshape_like(F, label, pred)
            loss = -F.sum(pred * label, axis=self._axis, keepdims=True)
        loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
        return F.mean(loss, axis=self._batch_axis, exclude=True)

class LogCoshLoss(gluon.loss.Loss):
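    """Log-cosh regression loss, log(cosh(pred - label)): roughly quadratic for small errors, linear for large ones."""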
    def __init__(self, weight=None, batch_axis=0, **kwargs):
        super(LogCoshLoss, self).__init__(weight, batch_axis, **kwargs)

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        loss = F.log(F.cosh(pred - label))
        loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
        return F.mean(loss, axis=self._batch_axis, exclude=True)

class SoftmaxCrossEntropyLossIgnoreIndices(gluon.loss.Loss):
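    """Softmax cross-entropy that zeroes the loss where the predicted class is an ignored index and matches the label (e.g. <pad> tokens)."""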
    def __init__(self, axis=-1, ignore_indices=[], sparse_label=True, from_logits=False, weight=None, batch_axis=0, **kwargs):
        super(SoftmaxCrossEntropyLossIgnoreIndices, self).__init__(weight, batch_axis, **kwargs)
        self._axis = axis
        self._ignore_indices = ignore_indices
        self._sparse_label = sparse_label
        self._from_logits = from_logits

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        log_softmax = F.log_softmax
        pick = F.pick
        if not self._from_logits:
            pred = log_softmax(pred, self._axis)
        if self._sparse_label:
            loss = -pick(pred, label, axis=self._axis, keepdims=True)
        else:
            label = gluon.loss._reshape_like(F, label, pred)
            loss = -(pred * label).sum(axis=self._axis, keepdims=True)
        # ignore some indices for loss, e.g. <pad> tokens in NLP applications
        for i in self._ignore_indices:
            loss = loss * F.logical_not(F.equal(F.argmax(pred, axis=1), F.ones_like(F.argmax(pred, axis=1)) * i) * F.equal(F.argmax(pred, axis=1), label))
        return loss.mean(axis=self._batch_axis, exclude=True)

class DiceLoss(gluon.loss.Loss):
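    """Cross-entropy plus a -log(Dice score) term computed from the argmax predictions; the smoothed Dice score rewards overlap with the label mask."""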
    def __init__(self, axis=-1, sparse_label=True, from_logits=False, weight=None,
                 batch_axis=0, **kwargs):
        super(DiceLoss, self).__init__(weight, batch_axis, **kwargs)
        self._axis = axis
        self._sparse_label = sparse_label
        self._from_logits = from_logits

    def dice_loss(self, F, pred, label):
        smooth = 1.
        pred_y = F.argmax(pred, axis=self._axis)
        intersection = pred_y * label
        score = (2 * F.mean(intersection, axis=self._batch_axis, exclude=True) + smooth) \
            / (F.mean(label, axis=self._batch_axis, exclude=True) + F.mean(pred_y, axis=self._batch_axis, exclude=True) + smooth)

        return -F.log(score)

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        if not self._from_logits:
            pred = F.log_softmax(pred, self._axis)
        if self._sparse_label:
            loss = -F.pick(pred, label, axis=self._axis, keepdims=True)
        else:
            label = gluon.loss._reshape_like(F, label, pred)
            loss = -F.sum(pred*label, axis=self._axis, keepdims=True)
        loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
        diceloss = self.dice_loss(F, pred, label)
        return F.mean(loss, axis=self._batch_axis, exclude=True) + diceloss

@mx.metric.register
class BLEU(mx.metric.EvalMetric):
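    """Corpus-level BLEU score with modified n-gram precisions up to N=4 and a brevity penalty."""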
    N = 4

    def __init__(self, exclude=None, name='bleu', output_names=None, label_names=None):
        super(BLEU, self).__init__(name=name, output_names=output_names, label_names=label_names)

        self._exclude = exclude or []

        self._match_counts = [0 for _ in range(self.N)]
        self._counts = [0 for _ in range(self.N)]

        self._size_ref = 0
        self._size_hyp = 0

    def update(self, labels, preds):
        labels, preds = mx.metric.check_label_shapes(labels, preds, True)

        new_labels = self._convert(labels)
        new_preds = self._convert(preds)

        for label, pred in zip(new_labels, new_preds):
            reference = [word for word in label if word not in self._exclude]
            hypothesis = [word for word in pred if word not in self._exclude]

            self._size_ref += len(reference)
            self._size_hyp += len(hypothesis)

            for n in range(self.N):
                reference_ngrams = self._get_ngrams(reference, n + 1)
                hypothesis_ngrams = self._get_ngrams(hypothesis, n + 1)

                match_count = 0

                for ngram in hypothesis_ngrams:
                    if ngram in reference_ngrams:
                        reference_ngrams.remove(ngram)

                        match_count += 1

                self._match_counts[n] += match_count
                self._counts[n] += len(hypothesis_ngrams)

    def get(self):
        precisions = [sys.float_info.min for n in range(self.N)]

        i = 1

        for n in range(self.N):
            match_counts = self._match_counts[n]
            counts = self._counts[n]

            if counts != 0:
                if match_counts == 0:
                    i *= 2
                    match_counts = 1 / i

                if (match_counts / counts) > 0:
                    precisions[n] = match_counts / counts

        bleu = self._get_brevity_penalty() * math.exp(sum(map(math.log, precisions)) / self.N)

        return (self.name, bleu)

    def calculate(self):
        precisions = [sys.float_info.min for n in range(self.N)]

        i = 1

        for n in range(self.N):
            match_counts = self._match_counts[n]
            counts = self._counts[n]

            if counts != 0:
                if match_counts == 0:
                    i *= 2
                    match_counts = 1 / i

                precisions[n] = match_counts / counts

        return self._get_brevity_penalty() * math.exp(sum(map(math.log, precisions)) / self.N)

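    # brevity penalty: 1 if the hypothesis corpus is at least as long as the reference, else exp(1 - ref_len / hyp_len)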
    def _get_brevity_penalty(self):
        if self._size_hyp >= self._size_ref:
            return 1
        else:
            if self._size_hyp > 0:
                size_hyp = self._size_hyp
            else:
                size_hyp = 1

            return math.exp(1 - (self._size_ref / size_hyp))

    @staticmethod
    def _get_ngrams(sentence, n):
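        # e.g. _get_ngrams([1, 2, 3], 2) -> [[1, 2], [2, 3]]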
        ngrams = []

        if len(sentence) >= n:
            for i in range(len(sentence) - n + 1):
                ngrams.append(sentence[i:i+n])

        return ngrams

    @staticmethod
    def _convert(nd_list):
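        # transpose a list of per-timestep batch NDArrays into one Python list of scalars per sample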
        if len(nd_list) == 0:
            return []

        new_list = [[] for _ in range(nd_list[0].shape[0])]

        for element in nd_list:
            for i in range(element.shape[0]):
                new_list[i].append(element[i].asscalar())

        return new_list



class CNNSupervisedTrainer_CifarClassifierNetwork:
    def __init__(self, data_loader, net_constructor):
        self._data_loader = data_loader
        self._net_creator = net_constructor
        self._networks = {}

    def train(self, batch_size=64,
              num_epoch=10,
              eval_metric='acc',
              eval_metric_params={},
              eval_train=False,
              loss='softmax_cross_entropy',
              loss_params={},
              optimizer='adam',
              optimizer_params=(('learning_rate', 0.001),),
              load_checkpoint=True,
              checkpoint_period=5,
              log_period=50,
              context='gpu',
              save_attention_image=False,
              use_teacher_forcing=False,
              normalize=True,
              shuffle_data=False,
              clip_global_grad_norm=None,
              preprocessing=False):
        if context == 'gpu':
            mx_context = mx.gpu()
        elif context == 'cpu':
            mx_context = mx.cpu()
        else:
            logging.error("Context argument is '" + context + "'. Only 'cpu' and 'gpu' are valid arguments.")

        if preprocessing:
            preproc_lib = "CNNPreprocessor_CifarClassifierNetwork_executor"
            train_iter, test_iter, data_mean, data_std, train_images, test_images = self._data_loader.load_preprocessed_data(batch_size, preproc_lib, shuffle_data)
        else:
            train_iter, test_iter, data_mean, data_std, train_images, test_images = self._data_loader.load_data(batch_size, shuffle_data)

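        # map generic training-configuration keys onto the argument names the Gluon optimizer expects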
        if 'weight_decay' in optimizer_params:
            optimizer_params['wd'] = optimizer_params['weight_decay']
            del optimizer_params['weight_decay']
        if 'learning_rate_decay' in optimizer_params:
            min_learning_rate = 1e-08
            if 'learning_rate_minimum' in optimizer_params:
                min_learning_rate = optimizer_params['learning_rate_minimum']
                del optimizer_params['learning_rate_minimum']
            optimizer_params['lr_scheduler'] = mx.lr_scheduler.FactorScheduler(
                                                   optimizer_params['step_size'],
                                                   factor=optimizer_params['learning_rate_decay'],
                                                   stop_factor_lr=min_learning_rate)
            del optimizer_params['step_size']
            del optimizer_params['learning_rate_decay']

        if normalize:
            self._net_creator.construct(context=mx_context, data_mean=data_mean, data_std=data_std)
        else:
            self._net_creator.construct(context=mx_context)

        begin_epoch = 0
        if load_checkpoint:
            begin_epoch = self._net_creator.load(mx_context)
        else:
            if os.path.isdir(self._net_creator._model_dir_):
                shutil.rmtree(self._net_creator._model_dir_)

        self._networks = self._net_creator.networks

        try:
            os.makedirs(self._net_creator._model_dir_)
        except OSError:
            if not os.path.isdir(self._net_creator._model_dir_):
                raise

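        # one Gluon Trainer per sub-network that actually has trainable parameters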
        trainers = [mx.gluon.Trainer(network.collect_params(), optimizer, optimizer_params) for network in self._networks.values() if len(network.collect_params().values()) != 0]

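        # defaults for loss hyperparameters; each branch below reads what it needs from loss_params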
        margin = loss_params['margin'] if 'margin' in loss_params else 1.0
        sparseLabel = loss_params['sparse_label'] if 'sparse_label' in loss_params else True
        ignore_indices = [loss_params['ignore_indices']] if 'ignore_indices' in loss_params else []
        loss_axis = loss_params['loss_axis'] if 'loss_axis' in loss_params else -1
        batch_axis = loss_params['batch_axis'] if 'batch_axis' in loss_params else 0
        if loss == 'softmax_cross_entropy':
            fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
            loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(axis=loss_axis, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis)
        elif loss == 'softmax_cross_entropy_ignore_indices':
            fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
            loss_function = SoftmaxCrossEntropyLossIgnoreIndices(axis=loss_axis, ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis)
        elif loss == 'sigmoid_binary_cross_entropy':
            loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
        elif loss == 'cross_entropy':
            loss_function = CrossEntropyLoss(axis=loss_axis, sparse_label=sparseLabel, batch_axis=batch_axis)
        elif loss == 'dice_loss':
            fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
            dice_weight = loss_params['dice_weight'] if 'dice_weight' in loss_params else None
            loss_function = DiceLoss(axis=loss_axis, from_logits=fromLogits, weight=dice_weight, sparse_label=sparseLabel, batch_axis=batch_axis)
        elif loss == 'l2':
            loss_function = mx.gluon.loss.L2Loss()
        elif loss == 'l1':
            loss_function = mx.gluon.loss.L1Loss()
        elif loss == 'huber':
            rho = loss_params['rho'] if 'rho' in loss_params else 1
            loss_function = mx.gluon.loss.HuberLoss(rho=rho)
        elif loss == 'hinge':
            loss_function = mx.gluon.loss.HingeLoss(margin=margin)
        elif loss == 'squared_hinge':
            loss_function = mx.gluon.loss.SquaredHingeLoss(margin=margin)
        elif loss == 'logistic':
            labelFormat = loss_params['label_format'] if 'label_format' in loss_params else 'signed'
            loss_function = mx.gluon.loss.LogisticLoss(label_format=labelFormat)
        elif loss == 'kullback_leibler':
            fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else True
            loss_function = mx.gluon.loss.KLDivLoss(from_logits=fromLogits)
        elif loss == 'log_cosh':
            loss_function = LogCoshLoss()
        else:
            logging.error("Invalid loss parameter.")

        tic = None

        for epoch in range(begin_epoch, begin_epoch + num_epoch):
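            # reload (and reshuffle) the data at the start of every epoch when shuffling is enabled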
            if shuffle_data:
                if preprocessing:
                    preproc_lib = "CNNPreprocessor_CifarClassifierNetwork_executor"
                    train_iter, test_iter, data_mean, data_std, train_images, test_images = self._data_loader.load_preprocessed_data(batch_size, preproc_lib, shuffle_data)
                else:
                    train_iter, test_iter, data_mean, data_std, train_images, test_images = self._data_loader.load_data(batch_size, shuffle_data)

            global_loss_train = 0.0
            train_batches = 0

            loss_total = 0
            train_iter.reset()
            for batch_i, batch in enumerate(train_iter):
                with autograd.record():
                    labels = [batch.label[i].as_in_context(mx_context) for i in range(1)]

                    data_ = batch.data[0].as_in_context(mx_context)

                    softmax_ = mx.nd.zeros((batch_size, 10,), ctx=mx_context)

                    nd.waitall()

                    lossList = []

                    softmax_ = self._networks[0](data_)

                    lossList.append(loss_function(softmax_, labels[0]))

                    loss = 0
                    for element in lossList:
                        loss = loss + element

                loss.backward()

                loss_total += loss.sum().asscalar()

                global_loss_train += loss.sum().asscalar()
                train_batches += 1

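                # optionally rescale all gradients so their global L2 norm does not exceed the configured threshold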
                if clip_global_grad_norm:
                    grads = []

                    for network in self._networks.values():
                        grads.extend([param.grad(mx_context) for param in network.collect_params().values()])

                    gluon.utils.clip_global_norm(grads, clip_global_grad_norm)

                for trainer in trainers:
                    trainer.step(batch_size)

                if tic is None:
                    tic = time.time()
                else:
                    if batch_i % log_period == 0:
                        try:
                            speed = log_period * batch_size / (time.time() - tic)
                        except ZeroDivisionError:
                            speed = float("inf")

                        loss_avg = loss_total / (batch_size * log_period)
                        loss_total = 0

                        logging.info("Epoch[%d] Batch[%d] Speed: %.2f samples/sec Loss: %.5f" % (epoch, batch_i, speed, loss_avg))

                        tic = time.time()

            global_loss_train /= (train_batches * batch_size)

            tic = None

            if eval_train:
                train_iter.reset()
                metric = mx.metric.create(eval_metric, **eval_metric_params)
                for batch_i, batch in enumerate(train_iter):
                    labels = [batch.label[i].as_in_context(mx_context) for i in range(1)]

                    data_ = batch.data[0].as_in_context(mx_context)

                    softmax_ = mx.nd.zeros((batch_size, 10,), ctx=mx_context)

                    nd.waitall()

                    outputs = []
                    lossList = []
                    attentionList = []
                    softmax_ = self._networks[0](data_)

                    outputs.append(softmax_)
                    lossList.append(loss_function(softmax_, labels[0]))

                    if save_attention_image == "True":
                        import matplotlib
                        matplotlib.use('Agg')
                        import matplotlib.pyplot as plt
                        logging.getLogger('matplotlib').setLevel(logging.ERROR)

                        if(os.path.isfile('src/test/resources/training_data/Show_attend_tell/dict.pkl')):
                            with open('src/test/resources/training_data/Show_attend_tell/dict.pkl', 'rb') as f:
                                dict = pickle.load(f)

                        plt.clf()
                        fig = plt.figure(figsize=(15,15))
                        max_length = len(labels)-1

                        ax = fig.add_subplot(max_length//3, max_length//4, 1)
                        ax.imshow(train_images[0+batch_size*(batch_i)].transpose(1,2,0))

                        for l in range(max_length):
                            attention = attentionList[l]
                            attention = mx.nd.slice_axis(attention, axis=0, begin=0, end=1).squeeze()
                            attention_resized = np.resize(attention.asnumpy(), (8, 8))
                            ax = fig.add_subplot(max_length//3, max_length//4, l+2)
                            if int(labels[l+1][0].asscalar()) > len(dict):
                                ax.set_title("<unk>")
                            elif dict[int(labels[l+1][0].asscalar())] == "<end>":
                                ax.set_title(".")
                                img = ax.imshow(train_images[0+batch_size*(batch_i)].transpose(1,2,0))
                                ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
                                break
                            else:
                                ax.set_title(dict[int(labels[l+1][0].asscalar())])
                            img = ax.imshow(train_images[0+batch_size*(batch_i)].transpose(1,2,0))
                            ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())

                        plt.tight_layout()
                        target_dir = 'target/attention_images'
                        if not os.path.exists(target_dir):
                            os.makedirs(target_dir)
                        plt.savefig(target_dir + '/attention_train.png')
                        plt.close()

                    predictions = []
                    for output_name in outputs:
                        if mx.nd.shape_array(mx.nd.squeeze(output_name)).size > 1:
                            predictions.append(mx.nd.argmax(output_name, axis=1))
                        else:
                            predictions.append(output_name)

                    metric.update(preds=predictions, labels=labels)
                train_metric_score = metric.get()[1]
            else:
                train_metric_score = 0
            global_loss_test = 0.0
            test_batches = 0

            test_iter.reset()
            metric = mx.metric.create(eval_metric, **eval_metric_params)
            for batch_i, batch in enumerate(test_iter):
                if True:
                    labels = [batch.label[i].as_in_context(mx_context) for i in range(1)]

                    data_ = batch.data[0].as_in_context(mx_context)

                    softmax_ = mx.nd.zeros((batch_size, 10,), ctx=mx_context)

                    nd.waitall()

                    outputs = []
                    lossList = []
                    attentionList = []
                    softmax_ = self._networks[0](data_)

                    outputs.append(softmax_)
                    lossList.append(loss_function(softmax_, labels[0]))

                    if save_attention_image == "True":
                        if not eval_train:
                            import matplotlib
                            matplotlib.use('Agg')
                            import matplotlib.pyplot as plt
                            logging.getLogger('matplotlib').setLevel(logging.ERROR)

                            if(os.path.isfile('src/test/resources/training_data/Show_attend_tell/dict.pkl')):
                                with open('src/test/resources/training_data/Show_attend_tell/dict.pkl', 'rb') as f:
                                    dict = pickle.load(f)

                        plt.clf()
                        fig = plt.figure(figsize=(15,15))
                        max_length = len(labels)-1

                        ax = fig.add_subplot(max_length//3, max_length//4, 1)
                        ax.imshow(test_images[0+batch_size*(batch_i)].transpose(1,2,0))

                        for l in range(max_length):
                            attention = attentionList[l]
                            attention = mx.nd.slice_axis(attention, axis=0, begin=0, end=1).squeeze()
                            attention_resized = np.resize(attention.asnumpy(), (8, 8))
                            ax = fig.add_subplot(max_length//3, max_length//4, l+2)
                            if int(mx.nd.slice_axis(outputs[l+1], axis=0, begin=0, end=1).squeeze().asscalar()) > len(dict):
                                ax.set_title("<unk>")
                            elif dict[int(mx.nd.slice_axis(outputs[l+1], axis=0, begin=0, end=1).squeeze().asscalar())] == "<end>":
                                ax.set_title(".")
                                img = ax.imshow(test_images[0+batch_size*(batch_i)].transpose(1,2,0))
                                ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
                                break
                            else:
                                ax.set_title(dict[int(mx.nd.slice_axis(outputs[l+1], axis=0, begin=0, end=1).squeeze().asscalar())])
                            img = ax.imshow(test_images[0+batch_size*(batch_i)].transpose(1,2,0))
                            ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())

                        plt.tight_layout()
                        target_dir = 'target/attention_images'
                        if not os.path.exists(target_dir):
                            os.makedirs(target_dir)
                        plt.savefig(target_dir + '/attention_test.png')
                        plt.close()
                loss = 0
                for element in lossList:
                    loss = loss + element

                global_loss_test += loss.sum().asscalar()
                test_batches += 1

                predictions = []
                for output_name in outputs:
                    predictions.append(output_name)

                metric.update(preds=predictions, labels=labels)
            test_metric_score = metric.get()[1]

            global_loss_test /= (test_batches * batch_size)

            logging.info("Epoch[%d] Train metric: %f, Test metric: %f, Train loss: %f, Test loss: %f" % (epoch, train_metric_score, test_metric_score, global_loss_train, global_loss_test))

            if (epoch - begin_epoch) % checkpoint_period == 0:
                for i, network in self._networks.items():
                    network.save_parameters(self.parameter_path(i) + '-' + str(epoch).zfill(4) + '.params')

        for i, network in self._networks.items():
            network.save_parameters(self.parameter_path(i) + '-' + str(num_epoch + begin_epoch + 1).zfill(4) + '.params')
            network.export(self.parameter_path(i) + '_newest', epoch=0)

    def parameter_path(self, index):
        return self._net_creator._model_dir_ + self._net_creator._model_prefix_ + '_' + str(index)
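
# Minimal usage sketch (kept as a comment, since the companion modules are generated separately;
# the CNNDataLoader_CifarClassifierNetwork and CNNCreator_CifarClassifierNetwork names are
# assumptions about the surrounding generated code, not definitions from this file):
#
#   from CNNDataLoader_CifarClassifierNetwork import CNNDataLoader_CifarClassifierNetwork
#   from CNNCreator_CifarClassifierNetwork import CNNCreator_CifarClassifierNetwork
#
#   trainer = CNNSupervisedTrainer_CifarClassifierNetwork(
#       CNNDataLoader_CifarClassifierNetwork(), CNNCreator_CifarClassifierNetwork())
#   trainer.train(batch_size=64, num_epoch=10, context='cpu',
#                 optimizer='adam', optimizer_params={'learning_rate': 0.001})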