import mxnet as mx
import logging
import numpy as np
import time
import os
import shutil
import pickle
import math
import sys
from mxnet import gluon, autograd, nd

class CrossEntropyLoss(gluon.loss.Loss):
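    """Cross-entropy loss that expects `pred` to already contain probabilities; only the log is applied here."""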
    def __init__(self, axis=-1, sparse_label=True, weight=None, batch_axis=0, **kwargs):
        super(CrossEntropyLoss, self).__init__(weight, batch_axis, **kwargs)
        self._axis = axis
        self._sparse_label = sparse_label

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        pred = F.log(pred)
        if self._sparse_label:
            loss = -F.pick(pred, label, axis=self._axis, keepdims=True)
        else:
            label = gluon.loss._reshape_like(F, label, pred)
            loss = -F.sum(pred * label, axis=self._axis, keepdims=True)
        loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
        return F.mean(loss, axis=self._batch_axis, exclude=True)

class LogCoshLoss(gluon.loss.Loss):
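    """Log-cosh regression loss: log(cosh(pred - label)), smooth near zero and approximately linear for large errors."""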
    def __init__(self, weight=None, batch_axis=0, **kwargs):
        super(LogCoshLoss, self).__init__(weight, batch_axis, **kwargs)

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        loss = F.log(F.cosh(pred - label))
        loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
        return F.mean(loss, axis=self._batch_axis, exclude=True)

class SoftmaxCrossEntropyLossIgnoreIndices(gluon.loss.Loss):
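    """Softmax cross-entropy loss that masks out the loss contribution of selected indices, e.g. <pad> tokens."""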
    def __init__(self, axis=-1, ignore_indices=[], sparse_label=True, from_logits=False, weight=None, batch_axis=0, **kwargs):
        super(SoftmaxCrossEntropyLossIgnoreIndices, self).__init__(weight, batch_axis, **kwargs)
        self._axis = axis
        self._ignore_indices = ignore_indices
        self._sparse_label = sparse_label
        self._from_logits = from_logits

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        log_softmax = F.log_softmax
        pick = F.pick
        if not self._from_logits:
            pred = log_softmax(pred, self._axis)
        if self._sparse_label:
            loss = -pick(pred, label, axis=self._axis, keepdims=True)
        else:
            label = gluon.loss._reshape_like(F, label, pred)
            loss = -(pred * label).sum(axis=self._axis, keepdims=True)
        # ignore some indices for loss, e.g. <pad> tokens in NLP applications
        for i in self._ignore_indices:
            pred_argmax = mx.nd.argmax(pred, axis=1)
            loss = loss * mx.nd.logical_not(mx.nd.equal(pred_argmax, mx.nd.ones_like(pred_argmax) * i) * mx.nd.equal(pred_argmax, label))
        return loss.mean(axis=self._batch_axis, exclude=True)

@mx.metric.register
class BLEU(mx.metric.EvalMetric):
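    """Corpus-level BLEU metric over up to 4-grams; n-gram orders without matches are smoothed to avoid log(0)."""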
    N = 4

    def __init__(self, exclude=None, name='bleu', output_names=None, label_names=None):
        super(BLEU, self).__init__(name=name, output_names=output_names, label_names=label_names)

        self._exclude = exclude or []

        self._match_counts = [0 for _ in range(self.N)]
        self._counts = [0 for _ in range(self.N)]

        self._size_ref = 0
        self._size_hyp = 0

    def update(self, labels, preds):
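        # accumulate n-gram match statistics and reference/hypothesis lengths over the batch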
        labels, preds = mx.metric.check_label_shapes(labels, preds, True)

        new_labels = self._convert(labels)
        new_preds = self._convert(preds)

        for label, pred in zip(new_labels, new_preds):
            reference = [word for word in label if word not in self._exclude]
            hypothesis = [word for word in pred if word not in self._exclude]

            self._size_ref += len(reference)
            self._size_hyp += len(hypothesis)

            for n in range(self.N):
                reference_ngrams = self._get_ngrams(reference, n + 1)
                hypothesis_ngrams = self._get_ngrams(hypothesis, n + 1)

                match_count = 0

                for ngram in hypothesis_ngrams:
                    if ngram in reference_ngrams:
                        reference_ngrams.remove(ngram)

                        match_count += 1

                self._match_counts[n] += match_count
                self._counts[n] += len(hypothesis_ngrams)

    def get(self):
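        # combine the smoothed modified n-gram precisions with the brevity penalty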
        precisions = [sys.float_info.min for n in range(self.N)]

        i = 1

        for n in range(self.N):
            match_counts = self._match_counts[n]
            counts = self._counts[n]

            if counts != 0:
                if match_counts == 0:
                    i *= 2
                    match_counts = 1 / i

                precisions[n] = match_counts / counts

        bleu = self._get_brevity_penalty() * math.exp(sum(map(math.log, precisions)) / self.N)

        return (self.name, bleu)

    def calculate(self):
        precisions = [sys.float_info.min for n in range(self.N)]

        i = 1

        for n in range(self.N):
            match_counts = self._match_counts[n]
            counts = self._counts[n]

            if counts != 0:
                if match_counts == 0:
                    i *= 2
                    match_counts = 1 / i

                precisions[n] = match_counts / counts

        return self._get_brevity_penalty() * math.exp(sum(map(math.log, precisions)) / self.N)

    def _get_brevity_penalty(self):
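        # standard BLEU brevity penalty: 1 if the hypothesis is at least as long as the reference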
        if self._size_hyp >= self._size_ref:
            return 1
        else:
            if self._size_hyp > 0:
                size_hyp = self._size_hyp
            else:
                size_hyp = 1

            return math.exp(1 - (self._size_ref / size_hyp))

    @staticmethod
    def _get_ngrams(sentence, n):
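        # collect all n-grams (as sub-lists) of the given sentence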
        ngrams = []

        if len(sentence) >= n:
            for i in range(len(sentence) - n + 1):
                ngrams.append(sentence[i:i+n])

        return ngrams

    @staticmethod
    def _convert(nd_list):
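        # transpose a list of batched NDArrays into one Python list of scalars per sample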
        if len(nd_list) == 0:
            return []

        new_list = [[] for _ in range(nd_list[0].shape[0])]

        for element in nd_list:
            for i in range(element.shape[0]):
                new_list[i].append(element[i].asscalar())

        return new_list


class CNNSupervisedTrainer_Alexnet:
    def __init__(self, data_loader, net_constructor):
        self._data_loader = data_loader
        self._net_creator = net_constructor
        self._networks = {}

    def train(self, batch_size=64,
              num_epoch=10,
              eval_metric='acc',
              eval_metric_params={},
              eval_train=False,
              loss='softmax_cross_entropy',
              loss_params={},
              optimizer='adam',
              optimizer_params=(('learning_rate', 0.001),),
              load_checkpoint=True,
              checkpoint_period=5,
              log_period=50,
              context='gpu',
              save_attention_image=False,
              use_teacher_forcing=False,
              normalize=True,
              shuffle_data=False,
              clip_global_grad_norm=None,
              preprocessing=False):
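        """Train the network(s) with the given configuration and evaluate them on the test data after every epoch."""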
        if context == 'gpu':
            mx_context = mx.gpu()
        elif context == 'cpu':
            mx_context = mx.cpu()
        else:
            logging.error("Context argument is '" + context + "'. Only 'cpu' and 'gpu are valid arguments'.")

        if preprocessing:
            preproc_lib = "CNNPreprocessor_Alexnet_executor"
            train_iter, test_iter, data_mean, data_std, train_images, test_images = self._data_loader.load_preprocessed_data(batch_size, preproc_lib, shuffle_data)
        else:
            train_iter, test_iter, data_mean, data_std, train_images, test_images = self._data_loader.load_data(batch_size, shuffle_data)

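        # map generator-style optimizer options onto MXNet's parameter names and set up an optional learning rate decay schedule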
        if 'weight_decay' in optimizer_params:
            optimizer_params['wd'] = optimizer_params['weight_decay']
            del optimizer_params['weight_decay']
        if 'learning_rate_decay' in optimizer_params:
            min_learning_rate = 1e-08
            if 'learning_rate_minimum' in optimizer_params:
                min_learning_rate = optimizer_params['learning_rate_minimum']
                del optimizer_params['learning_rate_minimum']
            optimizer_params['lr_scheduler'] = mx.lr_scheduler.FactorScheduler(
                                                   optimizer_params['step_size'],
                                                   factor=optimizer_params['learning_rate_decay'],
                                                   stop_factor_lr=min_learning_rate)
            del optimizer_params['step_size']
            del optimizer_params['learning_rate_decay']

        if normalize:
            self._net_creator.construct(context=mx_context, data_mean=data_mean, data_std=data_std)
        else:
            self._net_creator.construct(context=mx_context)

        begin_epoch = 0
        if load_checkpoint:
            begin_epoch = self._net_creator.load(mx_context)
        else:
            if os.path.isdir(self._net_creator._model_dir_):
                shutil.rmtree(self._net_creator._model_dir_)

        self._networks = self._net_creator.networks

        try:
            os.makedirs(self._net_creator._model_dir_)
        except OSError:
            if not os.path.isdir(self._net_creator._model_dir_):
                raise

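        # create one trainer per network that actually has trainable parameters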
        trainers = [mx.gluon.Trainer(network.collect_params(), optimizer, optimizer_params) for network in self._networks.values() if len(network.collect_params().values()) != 0]

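        # read optional loss parameters and select the configured loss function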
        margin = loss_params['margin'] if 'margin' in loss_params else 1.0
        sparseLabel = loss_params['sparse_label'] if 'sparse_label' in loss_params else True
        ignore_indices = [loss_params['ignore_indices']] if 'ignore_indices' in loss_params else []
        if loss == 'softmax_cross_entropy':
            fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
            loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(from_logits=fromLogits, sparse_label=sparseLabel)
        elif loss == 'softmax_cross_entropy_ignore_indices':
            fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
            loss_function = SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel)
        elif loss == 'sigmoid_binary_cross_entropy':
            loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
        elif loss == 'cross_entropy':
            loss_function = CrossEntropyLoss(sparse_label=sparseLabel)
        elif loss == 'l2':
            loss_function = mx.gluon.loss.L2Loss()
        elif loss == 'l1':
            loss_function = mx.gluon.loss.L1Loss()
        elif loss == 'huber':
            rho = loss_params['rho'] if 'rho' in loss_params else 1
            loss_function = mx.gluon.loss.HuberLoss(rho=rho)
        elif loss == 'hinge':
            loss_function = mx.gluon.loss.HingeLoss(margin=margin)
        elif loss == 'squared_hinge':
            loss_function = mx.gluon.loss.SquaredHingeLoss(margin=margin)
        elif loss == 'logistic':
            labelFormat = loss_params['label_format'] if 'label_format' in loss_params else 'signed'
            loss_function = mx.gluon.loss.LogisticLoss(label_format=labelFormat)
        elif loss == 'kullback_leibler':
            fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else True
            loss_function = mx.gluon.loss.KLDivLoss(from_logits=fromLogits)
        elif loss == 'log_cosh':
            loss_function = LogCoshLoss()
        else:
            logging.error("Invalid loss parameter.")

        tic = None

        for epoch in range(begin_epoch, begin_epoch + num_epoch):
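            # re-load (and re-shuffle) the data at the beginning of each epoch if requested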
            if shuffle_data:
                if preprocessing:
                    preproc_lib = "CNNPreprocessor_Alexnet_executor"
                    train_iter, test_iter, data_mean, data_std, train_images, test_images = self._data_loader.load_preprocessed_data(batch_size, preproc_lib, shuffle_data)
                else:
                    train_iter, test_iter, data_mean, data_std, train_images, test_images = self._data_loader.load_data(batch_size, shuffle_data)

            global_loss_train = 0.0
            train_batches = 0

            loss_total = 0
            train_iter.reset()
            for batch_i, batch in enumerate(train_iter):
                with autograd.record():
                    labels = [batch.label[i].as_in_context(mx_context) for i in range(1)]

                    data_ = batch.data[0].as_in_context(mx_context)

                    predictions_ = mx.nd.zeros((batch_size, 10,), ctx=mx_context)


                    nd.waitall()

                    lossList = []

                    predictions_ = self._networks[0](data_)

                    lossList.append(loss_function(predictions_, labels[0]))

                    loss = 0
                    for element in lossList:
                        loss = loss + element

                loss.backward()

                loss_total += loss.sum().asscalar()

                global_loss_train += loss.sum().asscalar()
                train_batches += 1

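                # optionally clip the global gradient norm across all networks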
                if clip_global_grad_norm:
                    grads = []

                    for network in self._networks.values():
                        grads.extend([param.grad(mx_context) for param in network.collect_params().values()])

                    gluon.utils.clip_global_norm(grads, clip_global_grad_norm)

                for trainer in trainers:
                    trainer.step(batch_size)

                if tic is None:
                    tic = time.time()
                else:
                    if batch_i % log_period == 0:
                        try:
                            speed = log_period * batch_size / (time.time() - tic)
                        except ZeroDivisionError:
                            speed = float("inf")

                        loss_avg = loss_total / (batch_size * log_period)
                        loss_total = 0

                        logging.info("Epoch[%d] Batch[%d] Speed: %.2f samples/sec Loss: %.5f" % (epoch, batch_i, speed, loss_avg))

                        tic = time.time()

            global_loss_train /= (train_batches * batch_size)

            tic = None

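            # optionally evaluate the metric on the training data as well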

            if eval_train:
                train_iter.reset()
                metric = mx.metric.create(eval_metric, **eval_metric_params)
                for batch_i, batch in enumerate(train_iter):
                    labels = [batch.label[i].as_in_context(mx_context) for i in range(1)]

                    data_ = batch.data[0].as_in_context(mx_context)

                    predictions_ = mx.nd.zeros((batch_size, 10,), ctx=mx_context)


                    nd.waitall()

                    outputs = []
                    lossList = []
                    attentionList = []
                    predictions_ = self._networks[0](data_)

                    outputs.append(predictions_)
                    lossList.append(loss_function(predictions_, labels[0]))

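                    # note: this branch expects save_attention_image to be the string "True" and an attention
                    # network that fills attentionList; for this network, attentionList stays empty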

                    if save_attention_image == "True":
                        import matplotlib
                        matplotlib.use('Agg')
                        import matplotlib.pyplot as plt
                        logging.getLogger('matplotlib').setLevel(logging.ERROR)

                        if os.path.isfile('src/test/resources/training_data/Show_attend_tell/dict.pkl'):
                            with open('src/test/resources/training_data/Show_attend_tell/dict.pkl', 'rb') as f:
                                index_to_word = pickle.load(f)

                        plt.clf()
                        fig = plt.figure(figsize=(15,15))
                        max_length = len(labels)-1

                        ax = fig.add_subplot(max_length//3, max_length//4, 1)
                        ax.imshow(train_images[0+batch_size*(batch_i)].transpose(1,2,0))

                        for l in range(max_length):
                            attention = attentionList[l]
                            attention = mx.nd.slice_axis(attention, axis=0, begin=0, end=1).squeeze()
                            attention_resized = np.resize(attention.asnumpy(), (8, 8))
                            ax = fig.add_subplot(max_length//3, max_length//4, l+2)
                            if int(labels[l+1][0].asscalar()) > len(index_to_word):
                                ax.set_title("<unk>")
                            elif index_to_word[int(labels[l+1][0].asscalar())] == "<end>":
                                ax.set_title(".")
                                img = ax.imshow(train_images[0+batch_size*(batch_i)].transpose(1,2,0))
                                ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
                                break
                            else:
                                ax.set_title(index_to_word[int(labels[l+1][0].asscalar())])
                            img = ax.imshow(train_images[0+batch_size*(batch_i)].transpose(1,2,0))
                            ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())

                        plt.tight_layout()
                        target_dir = 'target/attention_images'
                        if not os.path.exists(target_dir):
                            os.makedirs(target_dir)
                        plt.savefig(target_dir + '/attention_train.png')
                        plt.close()

                    predictions = []
                    for output in outputs:
                        if mx.nd.shape_array(mx.nd.squeeze(output)).size > 1:
                            predictions.append(mx.nd.argmax(output, axis=1))
                        else:
                            predictions.append(output)

                    metric.update(preds=predictions, labels=labels)
                train_metric_score = metric.get()[1]
            else:
                train_metric_score = 0

            global_loss_test = 0.0
            test_batches = 0

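            # evaluate on the test data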
            test_iter.reset()
            metric = mx.metric.create(eval_metric, **eval_metric_params)
            for batch_i, batch in enumerate(test_iter):
                if True:
                    labels = [batch.label[i].as_in_context(mx_context) for i in range(1)]

                    data_ = batch.data[0].as_in_context(mx_context)

                    predictions_ = mx.nd.zeros((batch_size, 10,), ctx=mx_context)


                    nd.waitall()

                    outputs = []
                    lossList = []
                    attentionList = []
                    predictions_ = self._networks[0](data_)

                    outputs.append(predictions_)
                    lossList.append(loss_function(predictions_, labels[0]))

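                    # same caveat as in the training branch: expects the string "True" and a filled attentionList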

                    if save_attention_image == "True":
                        if not eval_train:
                            import matplotlib
                            matplotlib.use('Agg')
                            import matplotlib.pyplot as plt
                            logging.getLogger('matplotlib').setLevel(logging.ERROR)

                            if os.path.isfile('src/test/resources/training_data/Show_attend_tell/dict.pkl'):
                                with open('src/test/resources/training_data/Show_attend_tell/dict.pkl', 'rb') as f:
                                    index_to_word = pickle.load(f)

                        plt.clf()
                        fig = plt.figure(figsize=(15,15))
                        max_length = len(labels)-1

                        ax = fig.add_subplot(max_length//3, max_length//4, 1)
                        ax.imshow(test_images[0+batch_size*(batch_i)].transpose(1,2,0))

                        for l in range(max_length):
                            attention = attentionList[l]
                            attention = mx.nd.slice_axis(attention, axis=0, begin=0, end=1).squeeze()
                            attention_resized = np.resize(attention.asnumpy(), (8, 8))
                            ax = fig.add_subplot(max_length//3, max_length//4, l+2)
                            if int(mx.nd.slice_axis(outputs[l+1], axis=0, begin=0, end=1).squeeze().asscalar()) > len(index_to_word):
                                ax.set_title("<unk>")
                            elif index_to_word[int(mx.nd.slice_axis(outputs[l+1], axis=0, begin=0, end=1).squeeze().asscalar())] == "<end>":
                                ax.set_title(".")
                                img = ax.imshow(test_images[0+batch_size*(batch_i)].transpose(1,2,0))
                                ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
                                break
                            else:
                                ax.set_title(index_to_word[int(mx.nd.slice_axis(outputs[l+1], axis=0, begin=0, end=1).squeeze().asscalar())])
                            img = ax.imshow(test_images[0+batch_size*(batch_i)].transpose(1,2,0))
                            ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())

                        plt.tight_layout()
                        target_dir = 'target/attention_images'
                        if not os.path.exists(target_dir):
                            os.makedirs(target_dir)
                        plt.savefig(target_dir + '/attention_test.png')
                        plt.close()
                loss = 0
                for element in lossList:
                    loss = loss + element

                global_loss_test += loss.sum().asscalar()
                test_batches += 1

                predictions = []
                for output in outputs:
                    if mx.nd.shape_array(mx.nd.squeeze(output)).size > 1:
                        predictions.append(mx.nd.argmax(output, axis=1))
                    else:
                        # argmax was already applied
                        predictions.append(output)

                metric.update(preds=predictions, labels=labels)
            test_metric_score = metric.get()[1]

            global_loss_test /= (test_batches * batch_size)

            logging.info("Epoch[%d] Train metric: %f, Test metric: %f, Train loss: %f, Test loss: %f" % (epoch, train_metric_score, test_metric_score, global_loss_train, global_loss_test))

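            # periodic checkpoint of all networks' parameters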
            if (epoch - begin_epoch) % checkpoint_period == 0:
                for i, network in self._networks.items():
                    network.save_parameters(self.parameter_path(i) + '-' + str(epoch).zfill(4) + '.params')

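        # save the final parameters and export the newest model as symbol and params files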
        for i, network in self._networks.items():
            network.save_parameters(self.parameter_path(i) + '-' + str(num_epoch + begin_epoch + 1).zfill(4) + '.params')
            network.export(self.parameter_path(i) + '_newest', epoch=0)

    def parameter_path(self, index):
        return self._net_creator._model_dir_ + self._net_creator._model_prefix_ + '_' + str(index)