import mxnet as mx
import logging
import numpy as np
import time
import os
import shutil
import pickle
import math
import sys
from mxnet import gluon, autograd, nd

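# Plain cross-entropy loss; unlike SoftmaxCrossEntropyLoss it assumes `pred`
# already holds probabilities, so it applies log directly without a softmax.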
class CrossEntropyLoss(gluon.loss.Loss):
    def __init__(self, axis=-1, sparse_label=True, weight=None, batch_axis=0, **kwargs):
        super(CrossEntropyLoss, self).__init__(weight, batch_axis, **kwargs)
        self._axis = axis
        self._sparse_label = sparse_label

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        pred = F.log(pred)
        if self._sparse_label:
            loss = -F.pick(pred, label, axis=self._axis, keepdims=True)
        else:
            label = gluon.loss._reshape_like(F, label, pred)
            loss = -F.sum(pred * label, axis=self._axis, keepdims=True)
        loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
        return F.mean(loss, axis=self._batch_axis, exclude=True)

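# Log-cosh regression loss: approximately quadratic for small residuals and
# linear for large ones, making it less sensitive to outliers than L2.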
class LogCoshLoss(gluon.loss.Loss):
    def __init__(self, weight=None, batch_axis=0, **kwargs):
        super(LogCoshLoss, self).__init__(weight, batch_axis, **kwargs)

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        loss = F.log(F.cosh(pred - label))
        loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
        return F.mean(loss, axis=self._batch_axis, exclude=True)

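# Softmax cross-entropy that masks out the loss contribution of selected
# indices, e.g. <pad> tokens in sequence tasks.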
class SoftmaxCrossEntropyLossIgnoreIndices(gluon.loss.Loss):
    def __init__(self, axis=-1, ignore_indices=[], sparse_label=True, from_logits=False, weight=None, batch_axis=0, **kwargs):
        super(SoftmaxCrossEntropyLossIgnoreIndices, self).__init__(weight, batch_axis, **kwargs)
        self._axis = axis
        self._ignore_indices = ignore_indices
        self._sparse_label = sparse_label
        self._from_logits = from_logits

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        log_softmax = F.log_softmax
        pick = F.pick
        if not self._from_logits:
            pred = log_softmax(pred, self._axis)
        if self._sparse_label:
            loss = -pick(pred, label, axis=self._axis, keepdims=True)
        else:
            label = gluon.loss._reshape_like(F, label, pred)
            loss = -(pred * label).sum(axis=self._axis, keepdims=True)
        # ignore some indices for loss, e.g. <pad> tokens in NLP applications
        for i in self._ignore_indices:
            argmax = mx.nd.argmax(pred, axis=1)
            loss = loss * mx.nd.logical_not(mx.nd.equal(argmax, mx.nd.ones_like(argmax) * i) * mx.nd.equal(argmax, label))
        return loss.mean(axis=self._batch_axis, exclude=True)

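# Corpus-level BLEU (up to 4-grams) registered as a custom MXNet metric;
# n-gram orders without matches are smoothed with successively halved
# pseudo-counts.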
@mx.metric.register
class BLEU(mx.metric.EvalMetric):
    N = 4

    def __init__(self, exclude=None, name='bleu', output_names=None, label_names=None):
        super(BLEU, self).__init__(name=name, output_names=output_names, label_names=label_names)

        self._exclude = exclude or []

        self._match_counts = [0 for _ in range(self.N)]
        self._counts = [0 for _ in range(self.N)]

        self._size_ref = 0
        self._size_hyp = 0

    def update(self, labels, preds):
        labels, preds = mx.metric.check_label_shapes(labels, preds, True)

        new_labels = self._convert(labels)
        new_preds = self._convert(preds)

        for label, pred in zip(new_labels, new_preds):
            reference = [word for word in label if word not in self._exclude]
            hypothesis = [word for word in pred if word not in self._exclude]

            self._size_ref += len(reference)
            self._size_hyp += len(hypothesis)

            for n in range(self.N):
                reference_ngrams = self._get_ngrams(reference, n + 1)
                hypothesis_ngrams = self._get_ngrams(hypothesis, n + 1)

                match_count = 0

                for ngram in hypothesis_ngrams:
                    if ngram in reference_ngrams:
                        reference_ngrams.remove(ngram)

                        match_count += 1

                self._match_counts[n] += match_count
                self._counts[n] += len(hypothesis_ngrams)

    def get(self):
        precisions = [sys.float_info.min for n in range(self.N)]

        i = 1

        for n in range(self.N):
            match_counts = self._match_counts[n]
            counts = self._counts[n]

            if counts != 0:
                if match_counts == 0:
                    i *= 2
                    match_counts = 1 / i

                if (match_counts / counts) > 0:
                    precisions[n] = match_counts / counts

        bleu = self._get_brevity_penalty() * math.exp(sum(map(math.log, precisions)) / self.N)

        return (self.name, bleu)

    def calculate(self):
        precisions = [sys.float_info.min for n in range(self.N)]

        i = 1

        for n in range(self.N):
            match_counts = self._match_counts[n]
            counts = self._counts[n]

            if counts != 0:
                if match_counts == 0:
                    i *= 2
                    match_counts = 1 / i

                precisions[n] = match_counts / counts

        return self._get_brevity_penalty() * math.exp(sum(map(math.log, precisions)) / self.N)

    def _get_brevity_penalty(self):
        if self._size_hyp >= self._size_ref:
            return 1
        else:
            return math.exp(1 - (self._size_ref / self._size_hyp))

    @staticmethod
    def _get_ngrams(sentence, n):
        ngrams = []

        if len(sentence) >= n:
            for i in range(len(sentence) - n + 1):
                ngrams.append(sentence[i:i+n])

        return ngrams

    @staticmethod
    def _convert(nd_list):
        if len(nd_list) == 0:
            return []

        new_list = [[] for _ in range(nd_list[0].shape[0])]

        for element in nd_list:
            for i in range(element.shape[0]):
                new_list[i].append(element[i].asscalar())

        return new_list


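# Generated trainer for the Alexnet network: builds the network(s) via the
# net constructor and runs a supervised Gluon training loop.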
class CNNSupervisedTrainer_Alexnet:
    def __init__(self, data_loader, net_constructor):
        self._data_loader = data_loader
        self._net_creator = net_constructor
        self._networks = {}

    def train(self, batch_size=64,
              num_epoch=10,
              eval_metric='acc',
              eval_metric_params={},
              eval_train=False,
              loss='softmax_cross_entropy',
              loss_params={},
              optimizer='adam',
              optimizer_params=(('learning_rate', 0.001),),
              load_checkpoint=True,
              checkpoint_period=5,
              log_period=50,
              context='gpu',
              save_attention_image=False,
              use_teacher_forcing=False,
              normalize=True,
              shuffle_data=False,
              clip_global_grad_norm=None,
              preprocessing=False):
        if context == 'gpu':
            mx_context = mx.gpu()
        elif context == 'cpu':
            mx_context = mx.cpu()
        else:
            logging.error("Context argument is '" + context + "'. Only 'cpu' and 'gpu are valid arguments'.")

        if preprocessing:
            preproc_lib = "CNNPreprocessor_Alexnet_executor"
            train_iter, test_iter, data_mean, data_std, train_images, test_images = self._data_loader.load_preprocessed_data(batch_size, preproc_lib, shuffle_data)
        else:
            train_iter, test_iter, data_mean, data_std, train_images, test_images = self._data_loader.load_data(batch_size, shuffle_data)

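        # translate generator-style optimizer options into the argument names
        # and scheduler objects that MXNet expects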
        if 'weight_decay' in optimizer_params:
            optimizer_params['wd'] = optimizer_params['weight_decay']
            del optimizer_params['weight_decay']
        if 'learning_rate_decay' in optimizer_params:
            min_learning_rate = 1e-08
            if 'learning_rate_minimum' in optimizer_params:
                min_learning_rate = optimizer_params['learning_rate_minimum']
                del optimizer_params['learning_rate_minimum']
            optimizer_params['lr_scheduler'] = mx.lr_scheduler.FactorScheduler(
                                                   optimizer_params['step_size'],
                                                   factor=optimizer_params['learning_rate_decay'],
                                                   stop_factor_lr=min_learning_rate)
            del optimizer_params['step_size']
            del optimizer_params['learning_rate_decay']

        if normalize:
            self._net_creator.construct(context=mx_context, data_mean=data_mean, data_std=data_std)
        else:
            self._net_creator.construct(context=mx_context)

        begin_epoch = 0
        if load_checkpoint:
            begin_epoch = self._net_creator.load(mx_context)
        else:
            if os.path.isdir(self._net_creator._model_dir_):
                shutil.rmtree(self._net_creator._model_dir_)

        self._networks = self._net_creator.networks

        try:
            os.makedirs(self._net_creator._model_dir_)
        except OSError:
            if not os.path.isdir(self._net_creator._model_dir_):
                raise

        trainers = [mx.gluon.Trainer(network.collect_params(), optimizer, optimizer_params) for network in self._networks.values() if len(network.collect_params().values()) != 0]
        margin = loss_params['margin'] if 'margin' in loss_params else 1.0
        sparseLabel = loss_params['sparse_label'] if 'sparse_label' in loss_params else True
        ignore_indices = [loss_params['ignore_indices']] if 'ignore_indices' in loss_params else []
        if loss == 'softmax_cross_entropy':
            fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
            loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(from_logits=fromLogits, sparse_label=sparseLabel)
        elif loss == 'softmax_cross_entropy_ignore_indices':
            fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
            loss_function = SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel)
        elif loss == 'sigmoid_binary_cross_entropy':
            loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
        elif loss == 'cross_entropy':
            loss_function = CrossEntropyLoss(sparse_label=sparseLabel)
        elif loss == 'l2':
            loss_function = mx.gluon.loss.L2Loss()
        elif loss == 'l1':
            loss_function = mx.gluon.loss.L1Loss()
        elif loss == 'huber':
            rho = loss_params['rho'] if 'rho' in loss_params else 1
            loss_function = mx.gluon.loss.HuberLoss(rho=rho)
        elif loss == 'hinge':
            loss_function = mx.gluon.loss.HingeLoss(margin=margin)
        elif loss == 'squared_hinge':
            loss_function = mx.gluon.loss.SquaredHingeLoss(margin=margin)
        elif loss == 'logistic':
            labelFormat = loss_params['label_format'] if 'label_format' in loss_params else 'signed'
            loss_function = mx.gluon.loss.LogisticLoss(label_format=labelFormat)
        elif loss == 'kullback_leibler':
            fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else True
            loss_function = mx.gluon.loss.KLDivLoss(from_logits=fromLogits)
        elif loss == 'log_cosh':
            loss_function = LogCoshLoss()
        else:
            logging.error("Invalid loss parameter.")

        tic = None

        for epoch in range(begin_epoch, begin_epoch + num_epoch):
            if shuffle_data:
                if preprocessing:
                    preproc_lib = "CNNPreprocessor_Alexnet_executor"
                    train_iter, test_iter, data_mean, data_std, train_images, test_images = self._data_loader.load_preprocessed_data(batch_size, preproc_lib, shuffle_data)
                else:
                    train_iter, test_iter, data_mean, data_std, train_images, test_images = self._data_loader.load_data(batch_size, shuffle_data)

            global_loss_train = 0.0
            train_batches = 0

            loss_total = 0
            train_iter.reset()
            for batch_i, batch in enumerate(train_iter):
                with autograd.record():
                    labels = [batch.label[i].as_in_context(mx_context) for i in range(1)]

                    data_ = batch.data[0].as_in_context(mx_context)

                    predictions_ = mx.nd.zeros((batch_size, 10,), ctx=mx_context)

                    nd.waitall()

                    lossList = []

                    predictions_ = self._networks[0](data_)

                    lossList.append(loss_function(predictions_, labels[0]))

                    loss = 0
                    for element in lossList:
                        loss = loss + element

                loss.backward()
                loss_total += loss.sum().asscalar()

                global_loss_train += float(loss.mean().asscalar())
                train_batches += 1

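                # optionally rescale all gradients so that their global norm
                # does not exceed clip_global_grad_norm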
                if clip_global_grad_norm:
                    grads = []

                    for network in self._networks.values():
                        grads.extend([param.grad(mx_context) for param in network.collect_params().values()])

                    gluon.utils.clip_global_norm(grads, clip_global_grad_norm)

                for trainer in trainers:
                    trainer.step(batch_size)

                if tic is None:
                    tic = time.time()
                else:
                    if batch_i % log_period == 0:
                        try:
                            speed = log_period * batch_size / (time.time() - tic)
                        except ZeroDivisionError:
                            speed = float("inf")

                        loss_avg = loss_total / (batch_size * log_period)
                        loss_total = 0

                        logging.info("Epoch[%d] Batch[%d] Speed: %.2f samples/sec Loss: %.5f" % (epoch, batch_i, speed, loss_avg))

                        tic = time.time()

            if train_batches > 0:
                global_loss_train /= train_batches

            tic = None

            if eval_train:
                train_iter.reset()
                metric = mx.metric.create(eval_metric, **eval_metric_params)
                for batch_i, batch in enumerate(train_iter):
                    labels = [batch.label[i].as_in_context(mx_context) for i in range(1)]

                    data_ = batch.data[0].as_in_context(mx_context)

                    predictions_ = mx.nd.zeros((batch_size, 10,), ctx=mx_context)

                    nd.waitall()

                    outputs = []
                    lossList = []
                    attentionList = []

                    predictions_ = self._networks[0](data_)

                    outputs.append(predictions_)
                    lossList.append(loss_function(predictions_, labels[0]))

                    if save_attention_image == "True":
                        import matplotlib
                        matplotlib.use('Agg')
                        import matplotlib.pyplot as plt
                        logging.getLogger('matplotlib').setLevel(logging.ERROR)

                        if(os.path.isfile('src/test/resources/training_data/Show_attend_tell/dict.pkl')):
                            with open('src/test/resources/training_data/Show_attend_tell/dict.pkl', 'rb') as f:
                                dict = pickle.load(f)

                        plt.clf()
                        fig = plt.figure(figsize=(15,15))
                        max_length = len(labels)-1

                        ax = fig.add_subplot(max_length//3, max_length//4, 1)
                        ax.imshow(train_images[0+batch_size*(batch_i)].transpose(1,2,0))

                        for l in range(max_length):
                            attention = attentionList[l]
                            attention = mx.nd.slice_axis(attention, axis=0, begin=0, end=1).squeeze()
                            attention_resized = np.resize(attention.asnumpy(), (8, 8))
                            ax = fig.add_subplot(max_length//3, max_length//4, l+2)
                            if int(labels[l+1][0].asscalar()) > len(dict):
                                ax.set_title("<unk>")
                            elif dict[int(labels[l+1][0].asscalar())] == "<end>":
                                ax.set_title(".")
                                img = ax.imshow(train_images[0+batch_size*(batch_i)].transpose(1,2,0))
                                ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
                                break
                            else:
                                ax.set_title(dict[int(labels[l+1][0].asscalar())])
                            img = ax.imshow(train_images[0+batch_size*(batch_i)].transpose(1,2,0))
                            ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())

                        plt.tight_layout()
                        target_dir = 'target/attention_images'
                        if not os.path.exists(target_dir):
                            os.makedirs(target_dir)
                        plt.savefig(target_dir + '/attention_train.png')
                        plt.close()

                    predictions = []
                    for output_name in outputs:
                        if mx.nd.shape_array(mx.nd.squeeze(output_name)).size > 1:
                            predictions.append(mx.nd.argmax(output_name, axis=1))
                        else:
                            predictions.append(output_name)
                    metric.update(preds=predictions, labels=labels)
                train_metric_score = metric.get()[1]
            else:
                train_metric_score = 0
            global_loss_test = 0.0
            test_batches = 0

            test_iter.reset()
            metric = mx.metric.create(eval_metric, **eval_metric_params)
            for batch_i, batch in enumerate(test_iter):
                if True:
                    labels = [batch.label[i].as_in_context(mx_context) for i in range(1)]

                    data_ = batch.data[0].as_in_context(mx_context)

                    predictions_ = mx.nd.zeros((batch_size, 10,), ctx=mx_context)

                    nd.waitall()

                    outputs = []
                    lossList = []
                    attentionList = []

                    predictions_ = self._networks[0](data_)

                    outputs.append(predictions_)
                    lossList.append(loss_function(predictions_, labels[0]))

                    if save_attention_image == "True":
                        if not eval_train:
                            import matplotlib
                            matplotlib.use('Agg')
                            import matplotlib.pyplot as plt
                            logging.getLogger('matplotlib').setLevel(logging.ERROR)

                            if(os.path.isfile('src/test/resources/training_data/Show_attend_tell/dict.pkl')):
                                with open('src/test/resources/training_data/Show_attend_tell/dict.pkl', 'rb') as f:
                                    dict = pickle.load(f)

                        plt.clf()
                        fig = plt.figure(figsize=(15,15))
                        max_length = len(labels)-1

                        ax = fig.add_subplot(max_length//3, max_length//4, 1)
                        ax.imshow(test_images[0+batch_size*(batch_i)].transpose(1,2,0))

                        for l in range(max_length):
                            attention = attentionList[l]
                            attention = mx.nd.slice_axis(attention, axis=0, begin=0, end=1).squeeze()
                            attention_resized = np.resize(attention.asnumpy(), (8, 8))
                            ax = fig.add_subplot(max_length//3, max_length//4, l+2)
                            if int(mx.nd.slice_axis(outputs[l+1], axis=0, begin=0, end=1).squeeze().asscalar()) > len(dict):
                                ax.set_title("<unk>")
                            elif dict[int(mx.nd.slice_axis(outputs[l+1], axis=0, begin=0, end=1).squeeze().asscalar())] == "<end>":
                                ax.set_title(".")
                                img = ax.imshow(test_images[0+batch_size*(batch_i)].transpose(1,2,0))
                                ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
                                break
                            else:
                                ax.set_title(dict[int(mx.nd.slice_axis(outputs[l+1], axis=0, begin=0, end=1).squeeze().asscalar())])
                            img = ax.imshow(test_images[0+batch_size*(batch_i)].transpose(1,2,0))
                            ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())

                        plt.tight_layout()
                        target_dir = 'target/attention_images'
                        if not os.path.exists(target_dir):
                            os.makedirs(target_dir)
                        plt.savefig(target_dir + '/attention_test.png')
                        plt.close()
                loss = 0
                for element in lossList:
                    loss = loss + element

                global_loss_test += float(loss.mean().asscalar())
                test_batches += 1
                predictions = []
                for output_name in outputs:
                    if mx.nd.shape_array(mx.nd.squeeze(output_name)).size > 1:
                        predictions.append(mx.nd.argmax(output_name, axis=1))
                    else:
                        # argmax already applied
                        predictions.append(output_name)
                metric.update(preds=predictions, labels=labels)
            test_metric_score = metric.get()[1]

            if test_batches > 0:
                global_loss_test /= test_batches
            logging.info("Epoch[%d] Train metric: %f, Test metric: %f, Train loss: %f, Test loss: %f" % (epoch, train_metric_score, test_metric_score, global_loss_train, global_loss_test))
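            # periodic checkpoint so training can be resumed via load_checkpoint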
            if (epoch - begin_epoch) % checkpoint_period == 0:
                for i, network in self._networks.items():
                    network.save_parameters(self.parameter_path(i) + '-' + str(epoch).zfill(4) + '.params')
        for i, network in self._networks.items():
            network.save_parameters(self.parameter_path(i) + '-' + str(num_epoch + begin_epoch).zfill(4) + '.params')
            network.export(self.parameter_path(i) + '_newest', epoch=0)
    def parameter_path(self, index):
        return self._net_creator._model_dir_ + self._net_creator._model_prefix_ + '_' + str(index)
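

# A minimal usage sketch (hypothetical: the data loader and creator classes
# are assumed to be generated by the same toolchain alongside this trainer):
#
#   from CNNDataLoader_Alexnet import CNNDataLoader_Alexnet
#   from CNNCreator_Alexnet import CNNCreator_Alexnet
#
#   trainer = CNNSupervisedTrainer_Alexnet(CNNDataLoader_Alexnet(), CNNCreator_Alexnet())
#   trainer.train(batch_size=64, num_epoch=10, context='cpu')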