import mxnet as mx
import logging
import numpy as np
import time
import os
import shutil
import pickle
import math
import sys
from mxnet import gluon, autograd, nd

class CrossEntropyLoss(gluon.loss.Loss):
    def __init__(self, axis=-1, sparse_label=True, weight=None, batch_axis=0, **kwargs):
        super(CrossEntropyLoss, self).__init__(weight, batch_axis, **kwargs)
        self._axis = axis
        self._sparse_label = sparse_label

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        pred = F.log(pred)
        if self._sparse_label:
            loss = -F.pick(pred, label, axis=self._axis, keepdims=True)
        else:
            label = gluon.loss._reshape_like(F, label, pred)
            loss = -F.sum(pred * label, axis=self._axis, keepdims=True)
        loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
        return F.mean(loss, axis=self._batch_axis, exclude=True)

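# Illustrative usage sketch (not called by the generated trainer): this custom
# CrossEntropyLoss expects class *probabilities* (e.g. softmax outputs) rather
# than raw logits, because hybrid_forward applies F.log directly to pred.
# The tensor values below are made up.
def _example_cross_entropy_loss():
    probs = nd.softmax(nd.random.uniform(shape=(4, 10)))  # 4 samples, 10 classes
    labels = nd.array([0, 3, 7, 9])                       # sparse class indices
    # returns one averaged loss value per sample in the batch
    return CrossEntropyLoss(sparse_label=True)(probs, labels)
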
class LogCoshLoss(gluon.loss.Loss):
    def __init__(self, weight=None, batch_axis=0, **kwargs):
        super(LogCoshLoss, self).__init__(weight, batch_axis, **kwargs)

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        loss = F.log(F.cosh(pred - label))
        loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
        return F.mean(loss, axis=self._batch_axis, exclude=True)

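# Illustrative usage sketch (not called by the generated trainer): log(cosh(x))
# behaves like x^2/2 for small residuals and like |x| - log(2) for large ones,
# so LogCoshLoss is a smooth, outlier-robust alternative to L2Loss for
# regression targets. The values below are made up.
def _example_log_cosh_loss():
    pred = nd.array([[0.1], [2.0], [-3.5]])
    label = nd.array([[0.0], [0.0], [0.0]])
    return LogCoshLoss()(pred, label)  # one loss value per sample
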
class SoftmaxCrossEntropyLossIgnoreIndices(gluon.loss.Loss):
    def __init__(self, axis=-1, ignore_indices=[], sparse_label=True, from_logits=False, weight=None, batch_axis=0, **kwargs):
        super(SoftmaxCrossEntropyLossIgnoreIndices, self).__init__(weight, batch_axis, **kwargs)
        self._axis = axis
        self._ignore_indices = ignore_indices
        self._sparse_label = sparse_label
        self._from_logits = from_logits

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        log_softmax = F.log_softmax
        pick = F.pick
        if not self._from_logits:
            pred = log_softmax(pred, self._axis)
        if self._sparse_label:
            loss = -pick(pred, label, axis=self._axis, keepdims=True)
        else:
            label = gluon.loss._reshape_like(F, label, pred)
            loss = -(pred * label).sum(axis=self._axis, keepdims=True)
        # ignore some indices for loss, e.g. <pad> tokens in NLP applications
        for i in self._ignore_indices:
            # mask out entries where an ignored index was (correctly) predicted,
            # so those positions do not contribute to the loss
            loss = loss * mx.nd.logical_not(
                mx.nd.equal(mx.nd.argmax(pred, axis=1), mx.nd.ones_like(mx.nd.argmax(pred, axis=1)) * i)
                * mx.nd.equal(mx.nd.argmax(pred, axis=1), label))

        return loss.mean(axis=self._batch_axis, exclude=True)

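# Illustrative usage sketch (not called by the generated trainer): with
# ignore_indices=[0], loss entries are masked out wherever the argmax prediction
# equals index 0 *and* matches the label, e.g. so correctly predicted padding
# tokens (index 0 is an assumption here) do not contribute to the sequence loss.
def _example_ignore_indices_loss():
    logits = nd.random.uniform(shape=(4, 10))
    labels = nd.array([0, 3, 7, 9])
    loss_fn = SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=[0])
    return loss_fn(logits, labels)  # per-sample loss with ignored entries masked
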
@mx.metric.register
class BLEU(mx.metric.EvalMetric):
    N = 4

    def __init__(self, exclude=None, name='bleu', output_names=None, label_names=None):
        super(BLEU, self).__init__(name=name, output_names=output_names, label_names=label_names)

        self._exclude = exclude or []

        self._match_counts = [0 for _ in range(self.N)]
        self._counts = [0 for _ in range(self.N)]

        self._size_ref = 0
        self._size_hyp = 0

    def update(self, labels, preds):
        labels, preds = mx.metric.check_label_shapes(labels, preds, True)

        new_labels = self._convert(labels)
        new_preds = self._convert(preds)

        for label, pred in zip(new_labels, new_preds):
            reference = [word for word in label if word not in self._exclude]
            hypothesis = [word for word in pred if word not in self._exclude]

            self._size_ref += len(reference)
            self._size_hyp += len(hypothesis)

            for n in range(self.N):
                reference_ngrams = self._get_ngrams(reference, n + 1)
                hypothesis_ngrams = self._get_ngrams(hypothesis, n + 1)

                match_count = 0

                for ngram in hypothesis_ngrams:
                    if ngram in reference_ngrams:
                        reference_ngrams.remove(ngram)

                        match_count += 1

                self._match_counts[n] += match_count
                self._counts[n] += len(hypothesis_ngrams)

    def get(self):
        precisions = [sys.float_info.min for n in range(self.N)]

        i = 1

        for n in range(self.N):
            match_counts = self._match_counts[n]
            counts = self._counts[n]

            if counts != 0:
                if match_counts == 0:
                    i *= 2
                    match_counts = 1 / i

                if (match_counts / counts) > 0:
                    precisions[n] = match_counts / counts

        bleu = self._get_brevity_penalty() * math.exp(sum(map(math.log, precisions)) / self.N)

        return (self.name, bleu)

    def calculate(self):
        precisions = [sys.float_info.min for n in range(self.N)]

        i = 1

        for n in range(self.N):
            match_counts = self._match_counts[n]
            counts = self._counts[n]

            if counts != 0:
                if match_counts == 0:
                    i *= 2
                    match_counts = 1 / i

                precisions[n] = match_counts / counts

        return self._get_brevity_penalty() * math.exp(sum(map(math.log, precisions)) / self.N)

    def _get_brevity_penalty(self):
        if self._size_hyp >= self._size_ref:
            return 1
        else:
            return math.exp(1 - (self._size_ref / self._size_hyp))

    @staticmethod
    def _get_ngrams(sentence, n):
        ngrams = []

        if len(sentence) >= n:
            for i in range(len(sentence) - n + 1):
                ngrams.append(sentence[i:i+n])

        return ngrams

    @staticmethod
    def _convert(nd_list):
        if len(nd_list) == 0:
            return []

        new_list = [[] for _ in range(nd_list[0].shape[0])]

        for element in nd_list:
            for i in range(element.shape[0]):
                new_list[i].append(element[i].asscalar())

        return new_list
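
# Illustrative usage sketch (not called by the generated trainer): BLEU is
# registered via @mx.metric.register, so it can be created by its lowercase
# name. update() takes one NDArray per decoding step, each of shape
# (batch_size,); _convert regroups them into one token sequence per sample.
# Token ids below are made up; 0 stands in for an excluded <pad> token.
def _example_bleu_metric():
    metric = mx.metric.create('bleu', exclude=[0])
    labels = [nd.array([2]), nd.array([3]), nd.array([4]), nd.array([5])]  # reference
    preds = [nd.array([2]), nd.array([3]), nd.array([4]), nd.array([6])]   # hypothesis
    metric.update(labels, preds)
    return metric.get()  # ('bleu', score)
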
class CNNSupervisedTrainer_CifarClassifierNetwork:
    def __init__(self, data_loader, net_constructor):
        self._data_loader = data_loader
        self._net_creator = net_constructor
        self._networks = {}

    def train(self, batch_size=64,
              num_epoch=10,
              eval_metric='acc',
              eval_metric_params={},
              eval_train=False,
              loss='softmax_cross_entropy',
              loss_params={},
              optimizer='adam',
              optimizer_params=(('learning_rate', 0.001),),
              load_checkpoint=True,
              checkpoint_period=5,
              log_period=50,
              context='gpu',
              save_attention_image=False,
              use_teacher_forcing=False,
              normalize=True,
              preprocessing=False):
        if context == 'gpu':
            mx_context = mx.gpu()
        elif context == 'cpu':
            mx_context = mx.cpu()
        else:
            logging.error("Context argument is '" + context + "'. Only 'cpu' and 'gpu' are valid arguments.")

        if preprocessing:
            preproc_lib = "CNNPreprocessor_CifarClassifierNetwork_executor"
            train_iter, test_iter, data_mean, data_std, train_images, test_images = self._data_loader.load_preprocessed_data(batch_size, preproc_lib)
        else:
            train_iter, test_iter, data_mean, data_std, train_images, test_images = self._data_loader.load_data(batch_size)

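        # map generator-style option names onto MXNet's optimizer arguments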
        if 'weight_decay' in optimizer_params:
            optimizer_params['wd'] = optimizer_params['weight_decay']
            del optimizer_params['weight_decay']
        if 'learning_rate_decay' in optimizer_params:
            min_learning_rate = 1e-08
            if 'learning_rate_minimum' in optimizer_params:
                min_learning_rate = optimizer_params['learning_rate_minimum']
                del optimizer_params['learning_rate_minimum']
            optimizer_params['lr_scheduler'] = mx.lr_scheduler.FactorScheduler(
                                                   optimizer_params['step_size'],
                                                   factor=optimizer_params['learning_rate_decay'],
                                                   stop_factor_lr=min_learning_rate)
            del optimizer_params['step_size']
            del optimizer_params['learning_rate_decay']

        if normalize:
            self._net_creator.construct(context=mx_context, data_mean=data_mean, data_std=data_std)
        else:
            self._net_creator.construct(context=mx_context)

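        # either resume from the newest checkpoint or remove stale model files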
        begin_epoch = 0
        if load_checkpoint:
            begin_epoch = self._net_creator.load(mx_context)
        else:
            if os.path.isdir(self._net_creator._model_dir_):
                shutil.rmtree(self._net_creator._model_dir_)

        self._networks = self._net_creator.networks

        try:
            os.makedirs(self._net_creator._model_dir_)
        except OSError:
            if not os.path.isdir(self._net_creator._model_dir_):
                raise

        # one Trainer per sub-network that actually has parameters
        trainers = [mx.gluon.Trainer(network.collect_params(), optimizer, optimizer_params)
                    for network in self._networks.values() if len(network.collect_params().values()) != 0]

        margin = loss_params['margin'] if 'margin' in loss_params else 1.0
        sparseLabel = loss_params['sparse_label'] if 'sparse_label' in loss_params else True
        ignore_indices = [loss_params['ignore_indices']] if 'ignore_indices' in loss_params else []
        if loss == 'softmax_cross_entropy':
            fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
            loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(from_logits=fromLogits, sparse_label=sparseLabel)
        elif loss == 'softmax_cross_entropy_ignore_indices':
            fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
            loss_function = SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel)
        elif loss == 'sigmoid_binary_cross_entropy':
            loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
        elif loss == 'cross_entropy':
            loss_function = CrossEntropyLoss(sparse_label=sparseLabel)
        elif loss == 'l2':
            loss_function = mx.gluon.loss.L2Loss()
        elif loss == 'l1':
            loss_function = mx.gluon.loss.L1Loss()
        elif loss == 'huber':
            rho = loss_params['rho'] if 'rho' in loss_params else 1
            loss_function = mx.gluon.loss.HuberLoss(rho=rho)
        elif loss == 'hinge':
            loss_function = mx.gluon.loss.HingeLoss(margin=margin)
        elif loss == 'squared_hinge':
            loss_function = mx.gluon.loss.SquaredHingeLoss(margin=margin)
        elif loss == 'logistic':
            labelFormat = loss_params['label_format'] if 'label_format' in loss_params else 'signed'
            loss_function = mx.gluon.loss.LogisticLoss(label_format=labelFormat)
        elif loss == 'kullback_leibler':
            fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else True
            loss_function = mx.gluon.loss.KLDivLoss(from_logits=fromLogits)
        elif loss == 'log_cosh':
            loss_function = LogCoshLoss()
        else:
            logging.error("Invalid loss parameter.")

        tic = None

        for epoch in range(begin_epoch, begin_epoch + num_epoch):

            loss_total = 0
            train_iter.reset()
            for batch_i, batch in enumerate(train_iter):
                with autograd.record():
                    labels = [batch.label[i].as_in_context(mx_context) for i in range(1)]

                    data_ = batch.data[0].as_in_context(mx_context)

                    softmax_ = mx.nd.zeros((batch_size, 10,), ctx=mx_context)

                    nd.waitall()

                    lossList = []

                    softmax_ = self._networks[0](data_)

                    lossList.append(loss_function(softmax_, labels[0]))

                    # accumulate the losses of all network outputs
                    loss = 0
                    for element in lossList:
                        loss = loss + element

                loss.backward()

                loss_total += loss.sum().asscalar()

                # trainer.step normalizes the gradients by batch_size
                for trainer in trainers:
                    trainer.step(batch_size)

                if tic is None:
                    tic = time.time()
                else:
                    if batch_i % log_period == 0:
                        try:
                            speed = log_period * batch_size / (time.time() - tic)
                        except ZeroDivisionError:
                            speed = float("inf")

                        loss_avg = loss_total / (batch_size * log_period)
                        loss_total = 0

                        logging.info("Epoch[%d] Batch[%d] Speed: %.2f samples/sec Loss: %.5f" % (epoch, batch_i, speed, loss_avg))

                        tic = time.time()

            tic = None

            if eval_train:
                train_iter.reset()
                metric = mx.metric.create(eval_metric, **eval_metric_params)
                for batch_i, batch in enumerate(train_iter):
                    labels = [batch.label[i].as_in_context(mx_context) for i in range(1)]

                    data_ = batch.data[0].as_in_context(mx_context)

                    softmax_ = mx.nd.zeros((batch_size, 10,), ctx=mx_context)

                    nd.waitall()

                    outputs = []
                    attentionList = []
                    softmax_ = self._networks[0](data_)

                    outputs.append(softmax_)

                    # attention plotting is only meaningful for attention networks
                    # that actually fill attentionList
                    if save_attention_image:
                        import matplotlib
                        matplotlib.use('Agg')
                        import matplotlib.pyplot as plt
                        logging.getLogger('matplotlib').setLevel(logging.ERROR)

                        if os.path.isfile('src/test/resources/training_data/Show_attend_tell/dict.pkl'):
                            with open('src/test/resources/training_data/Show_attend_tell/dict.pkl', 'rb') as f:
                                vocab_dict = pickle.load(f)

                        plt.clf()
                        fig = plt.figure(figsize=(15, 15))
                        max_length = len(labels) - 1

                        ax = fig.add_subplot(max_length // 3, max_length // 4, 1)
                        ax.imshow(train_images[0 + batch_size * batch_i].transpose(1, 2, 0))

                        for l in range(max_length):
                            attention = attentionList[l]
                            attention = mx.nd.slice_axis(attention, axis=0, begin=0, end=1).squeeze()
                            attention_resized = np.resize(attention.asnumpy(), (8, 8))
                            ax = fig.add_subplot(max_length // 3, max_length // 4, l + 2)

                            if int(labels[l + 1][0].asscalar()) > len(vocab_dict):
                                ax.set_title("<unk>")
                            elif vocab_dict[int(labels[l + 1][0].asscalar())] == "<end>":
                                ax.set_title(".")
                                img = ax.imshow(train_images[0 + batch_size * batch_i].transpose(1, 2, 0))
                                ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
                                break
                            else:
                                ax.set_title(vocab_dict[int(labels[l + 1][0].asscalar())])
                            img = ax.imshow(train_images[0 + batch_size * batch_i].transpose(1, 2, 0))
                            ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())

                        plt.tight_layout()
                        target_dir = 'target/attention_images'
                        if not os.path.exists(target_dir):
                            os.makedirs(target_dir)
                        plt.savefig(target_dir + '/attention_train.png')
                        plt.close()

                    predictions = []
                    for output_name in outputs:
                        # networks may emit class scores or already arg-maxed indices
                        if mx.nd.shape_array(mx.nd.squeeze(output_name)).size > 1:
                            predictions.append(mx.nd.argmax(output_name, axis=1))
                        else:
                            predictions.append(output_name)

                    metric.update(preds=predictions, labels=labels)
                train_metric_score = metric.get()[1]
            else:
                train_metric_score = 0

            test_iter.reset()
            metric = mx.metric.create(eval_metric, **eval_metric_params)
            for batch_i, batch in enumerate(test_iter):
                if True:
                    labels = [batch.label[i].as_in_context(mx_context) for i in range(1)]

                    data_ = batch.data[0].as_in_context(mx_context)

                    softmax_ = mx.nd.zeros((batch_size, 10,), ctx=mx_context)

                    nd.waitall()

                    outputs = []
                    attentionList = []
                    softmax_ = self._networks[0](data_)

                    outputs.append(softmax_)

                    # attention plotting is only meaningful for attention networks
                    # that actually fill attentionList
                    if save_attention_image:
                        if not eval_train:
                            import matplotlib
                            matplotlib.use('Agg')
                            import matplotlib.pyplot as plt
                            logging.getLogger('matplotlib').setLevel(logging.ERROR)

                            if os.path.isfile('src/test/resources/training_data/Show_attend_tell/dict.pkl'):
                                with open('src/test/resources/training_data/Show_attend_tell/dict.pkl', 'rb') as f:
                                    vocab_dict = pickle.load(f)

                        plt.clf()
                        fig = plt.figure(figsize=(15, 15))
                        max_length = len(labels) - 1

                        ax = fig.add_subplot(max_length // 3, max_length // 4, 1)
                        ax.imshow(test_images[0 + batch_size * batch_i].transpose(1, 2, 0))

                        for l in range(max_length):
                            attention = attentionList[l]
                            attention = mx.nd.slice_axis(attention, axis=0, begin=0, end=1).squeeze()
                            attention_resized = np.resize(attention.asnumpy(), (8, 8))
                            ax = fig.add_subplot(max_length // 3, max_length // 4, l + 2)

                            if int(mx.nd.slice_axis(outputs[l + 1], axis=0, begin=0, end=1).squeeze().asscalar()) > len(vocab_dict):
                                ax.set_title("<unk>")
                            elif vocab_dict[int(mx.nd.slice_axis(outputs[l + 1], axis=0, begin=0, end=1).squeeze().asscalar())] == "<end>":
                                ax.set_title(".")
                                img = ax.imshow(test_images[0 + batch_size * batch_i].transpose(1, 2, 0))
                                ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
                                break
                            else:
                                ax.set_title(vocab_dict[int(mx.nd.slice_axis(outputs[l + 1], axis=0, begin=0, end=1).squeeze().asscalar())])
                            img = ax.imshow(test_images[0 + batch_size * batch_i].transpose(1, 2, 0))
                            ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())

                        plt.tight_layout()
                        target_dir = 'target/attention_images'
                        if not os.path.exists(target_dir):
                            os.makedirs(target_dir)
                        plt.savefig(target_dir + '/attention_test.png')
                        plt.close()

                predictions = []
                for output_name in outputs:
                    if mx.nd.shape_array(mx.nd.squeeze(output_name)).size > 1:
                        predictions.append(mx.nd.argmax(output_name, axis=1))
                    # argmax already applied otherwise
                    else:
                        predictions.append(output_name)

                metric.update(preds=predictions, labels=labels)
            test_metric_score = metric.get()[1]

            logging.info("Epoch[%d] Train: %f, Test: %f" % (epoch, train_metric_score, test_metric_score))

            if (epoch - begin_epoch) % checkpoint_period == 0:
                for i, network in self._networks.items():
                    network.save_parameters(self.parameter_path(i) + '-' + str(epoch).zfill(4) + '.params')

        for i, network in self._networks.items():
            network.save_parameters(self.parameter_path(i) + '-' + str(num_epoch + begin_epoch).zfill(4) + '.params')
            network.export(self.parameter_path(i) + '_newest', epoch=0)

    def parameter_path(self, index):
        return self._net_creator._model_dir_ + self._net_creator._model_prefix_ + '_' + str(index)
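
# Illustrative usage sketch (the loader and creator class names below are
# assumptions based on the generator's naming scheme, not verified imports):
#
#     data_loader = CNNDataLoader_CifarClassifierNetwork()
#     net_constructor = CNNCreator_CifarClassifierNetwork()
#     trainer = CNNSupervisedTrainer_CifarClassifierNetwork(data_loader, net_constructor)
#     trainer.train(batch_size=64, num_epoch=10, context='gpu',
#                   optimizer='adam', optimizer_params=(('learning_rate', 0.001),))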