import mxnet as mx
import logging
import numpy as np
import time
import os
import shutil
from mxnet import gluon, autograd, nd


class CrossEntropyLoss(gluon.loss.Loss):
    # Plain cross-entropy over already-normalized probabilities (Gluon's
    # SoftmaxCrossEntropyLoss applies its own softmax, so it cannot be used here).
    def __init__(self, axis=-1, sparse_label=True, weight=None, batch_axis=0, **kwargs):
        super(CrossEntropyLoss, self).__init__(weight, batch_axis, **kwargs)
        self._axis = axis
        self._sparse_label = sparse_label

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        pred = F.log(pred)
        if self._sparse_label:
            loss = -F.pick(pred, label, axis=self._axis, keepdims=True)
        else:
            label = gluon.loss._reshape_like(F, label, pred)
            loss = -F.sum(pred * label, axis=self._axis, keepdims=True)
        loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
        return F.mean(loss, axis=self._batch_axis, exclude=True)


class LogCoshLoss(gluon.loss.Loss):
    # log(cosh(x)) behaves like x^2/2 near zero and like |x| for large errors,
    # giving a smooth, outlier-robust regression loss.
    def __init__(self, weight=None, batch_axis=0, **kwargs):
        super(LogCoshLoss, self).__init__(weight, batch_axis, **kwargs)

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        loss = F.log(F.cosh(pred - label))
        loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
        return F.mean(loss, axis=self._batch_axis, exclude=True)


class CNNSupervisedTrainer_mnist_mnistClassifier_net:
    def __init__(self, data_loader, net_constructor):
        self._data_loader = data_loader
        self._net_creator = net_constructor
        self._networks = {}

    def train(self, batch_size=64,
              num_epoch=10,
              eval_metric='acc',
              loss='softmax_cross_entropy',
              loss_params={},
              optimizer='adam',
              optimizer_params=(('learning_rate', 0.001),),
              load_checkpoint=True,
              context='gpu',
              checkpoint_period=5,
              normalize=True):
        if context == 'gpu':
            mx_context = mx.gpu()
        elif context == 'cpu':
            mx_context = mx.cpu()
        else:
            logging.error("Context argument is '" + context + "'. Only 'cpu' and 'gpu' are valid arguments.")

        # The default is a tuple of pairs (to avoid a mutable default argument);
        # convert to a dict so the key lookups and deletions below work.
        optimizer_params = dict(optimizer_params)

        # Map generator-level parameter names onto the names MXNet expects.
        if 'weight_decay' in optimizer_params:
            optimizer_params['wd'] = optimizer_params['weight_decay']
            del optimizer_params['weight_decay']
        if 'learning_rate_decay' in optimizer_params:
            min_learning_rate = 1e-08
            if 'learning_rate_minimum' in optimizer_params:
                min_learning_rate = optimizer_params['learning_rate_minimum']
                del optimizer_params['learning_rate_minimum']
            optimizer_params['lr_scheduler'] = mx.lr_scheduler.FactorScheduler(
                optimizer_params['step_size'],
                factor=optimizer_params['learning_rate_decay'],
                stop_factor_lr=min_learning_rate)
            del optimizer_params['step_size']
            del optimizer_params['learning_rate_decay']

        train_iter, test_iter, data_mean, data_std = self._data_loader.load_data(batch_size)

        if normalize:
            self._net_creator.construct(context=mx_context, data_mean=data_mean, data_std=data_std)
        else:
            self._net_creator.construct(context=mx_context)

        begin_epoch = 0
        if load_checkpoint:
            begin_epoch = self._net_creator.load(mx_context)
        else:
            # Not resuming: discard any stale checkpoints.
            if os.path.isdir(self._net_creator._model_dir_):
                shutil.rmtree(self._net_creator._model_dir_)

        self._networks = self._net_creator.networks

        try:
            os.makedirs(self._net_creator._model_dir_)
        except OSError:
            if not os.path.isdir(self._net_creator._model_dir_):
                raise

        # One trainer per network; all share the same optimizer settings.
        trainers = [mx.gluon.Trainer(network.collect_params(), optimizer, optimizer_params)
                    for network in self._networks.values()]

        # Select the loss function, falling back to defaults for absent loss_params.
        margin = loss_params['margin'] if 'margin' in loss_params else 1.0
        sparseLabel = loss_params['sparse_label'] if 'sparse_label' in loss_params else True
        if loss == 'softmax_cross_entropy':
            fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
            loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(from_logits=fromLogits, sparse_label=sparseLabel)
        elif loss == 'sigmoid_binary_cross_entropy':
            loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
        elif loss == 'cross_entropy':
            loss_function = CrossEntropyLoss(sparse_label=sparseLabel)
        elif loss == 'l2':
            loss_function = mx.gluon.loss.L2Loss()
        elif loss == 'l1':
            loss_function = mx.gluon.loss.L1Loss()
        elif loss == 'huber':
            rho = loss_params['rho'] if 'rho' in loss_params else 1
            loss_function = mx.gluon.loss.HuberLoss(rho=rho)
        elif loss == 'hinge':
            loss_function = mx.gluon.loss.HingeLoss(margin=margin)
        elif loss == 'squared_hinge':
            loss_function = mx.gluon.loss.SquaredHingeLoss(margin=margin)
        elif loss == 'logistic':
            labelFormat = loss_params['label_format'] if 'label_format' in loss_params else 'signed'
            loss_function = mx.gluon.loss.LogisticLoss(label_format=labelFormat)
        elif loss == 'kullback_leibler':
            fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else True
            loss_function = mx.gluon.loss.KLDivLoss(from_logits=fromLogits)
        elif loss == 'log_cosh':
            loss_function = LogCoshLoss()
        else:
            logging.error("Invalid loss parameter.")

        speed_period = 50
        tic = None

        for epoch in range(begin_epoch, begin_epoch + num_epoch):
            train_iter.reset()
            for batch_i, batch in enumerate(train_iter):
                image_ = batch.data[0].as_in_context(mx_context)
                predictions_label = batch.label[0].as_in_context(mx_context)

                with autograd.record():
                    predictions_ = self._networks[0](image_)
                    loss = loss_function(predictions_, predictions_label)

                loss.backward()

                for trainer in trainers:
                    trainer.step(batch_size)

                # Log throughput every speed_period batches.
                if tic is None:
                    tic = time.time()
                else:
                    if batch_i % speed_period == 0:
                        try:
                            speed = speed_period * batch_size / (time.time() - tic)
                        except ZeroDivisionError:
                            speed = float("inf")

                        logging.info("Epoch[%d] Batch[%d] Speed: %.2f samples/sec" % (epoch, batch_i, speed))

                        tic = time.time()

            tic = None

            # Evaluate on the training set.
            train_iter.reset()
            metric = mx.metric.create(eval_metric)
            for batch_i, batch in enumerate(train_iter):
                image_ = batch.data[0].as_in_context(mx_context)
                labels = [batch.label[0].as_in_context(mx_context)]

                predictions_ = self._networks[0](image_)
                predictions = [mx.nd.argmax(predictions_, axis=1)]

                metric.update(preds=predictions, labels=labels)
            train_metric_score = metric.get()[1]

            # Evaluate on the test set.
            test_iter.reset()
            metric = mx.metric.create(eval_metric)
            for batch_i, batch in enumerate(test_iter):
                image_ = batch.data[0].as_in_context(mx_context)
                labels = [batch.label[0].as_in_context(mx_context)]

                predictions_ = self._networks[0](image_)
                predictions = [mx.nd.argmax(predictions_, axis=1)]

                metric.update(preds=predictions, labels=labels)
            test_metric_score = metric.get()[1]

            logging.info("Epoch[%d] Train: %f, Test: %f" % (epoch, train_metric_score, test_metric_score))

            # Periodic checkpoint of all networks.
            if (epoch - begin_epoch) % checkpoint_period == 0:
                for i, network in self._networks.items():
                    network.save_parameters(self.parameter_path(i) + '-' + str(epoch).zfill(4) + '.params')

        # Final save: parameters plus an exported symbol/params pair for inference.
        for i, network in self._networks.items():
            network.save_parameters(self.parameter_path(i) + '-' + str(num_epoch + begin_epoch).zfill(4) + '.params')
            network.export(self.parameter_path(i) + '_newest', epoch=0)

    def parameter_path(self, index):
        return self._net_creator._model_dir_ + self._net_creator._model_prefix_ + '_' + str(index)
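

# ---------------------------------------------------------------------------
# Example usage (a minimal sketch, not part of the generated trainer). The
# loader and creator class names below follow this project's naming scheme
# but are assumptions here; substitute the classes actually emitted alongside
# this trainer for your model.
#
# from CNNDataLoader_mnist_mnistClassifier_net import CNNDataLoader_mnist_mnistClassifier_net
# from CNNCreator_mnist_mnistClassifier_net import CNNCreator_mnist_mnistClassifier_net
#
# if __name__ == "__main__":
#     logging.basicConfig(level=logging.INFO)
#     data_loader = CNNDataLoader_mnist_mnistClassifier_net()
#     net_constructor = CNNCreator_mnist_mnistClassifier_net()
#     trainer = CNNSupervisedTrainer_mnist_mnistClassifier_net(data_loader, net_constructor)
#     trainer.train(batch_size=64,
#                   num_epoch=10,
#                   context='cpu',
#                   optimizer='adam',
#                   optimizer_params={'learning_rate': 0.001})
# ---------------------------------------------------------------------------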