import mxnet as mx
import logging
import numpy as np
import time
import os
import shutil
from mxnet import gluon, autograd, nd

class CrossEntropyLoss(gluon.loss.Loss):
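    """Cross-entropy loss that expects class probabilities as predictions; the
    log is applied inside hybrid_forward, unlike SoftmaxCrossEntropyLoss,
    which computes the softmax itself."""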
    def __init__(self, axis=-1, sparse_label=True, weight=None, batch_axis=0, **kwargs):
        super(CrossEntropyLoss, self).__init__(weight, batch_axis, **kwargs)
        self._axis = axis
        self._sparse_label = sparse_label

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        pred = F.log(pred)
        if self._sparse_label:
            loss = -F.pick(pred, label, axis=self._axis, keepdims=True)
        else:
            label = gluon.loss._reshape_like(F, label, pred)
            loss = -F.sum(pred * label, axis=self._axis, keepdims=True)
        loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
        return F.mean(loss, axis=self._batch_axis, exclude=True)

class LogCoshLoss(gluon.loss.Loss):
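    """Log-cosh regression loss: log(cosh(pred - label)). It is roughly
    quadratic for small errors and linear for large ones, so it is less
    sensitive to outliers than L2."""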
    def __init__(self, weight=None, batch_axis=0, **kwargs):
        super(LogCoshLoss, self).__init__(weight, batch_axis, **kwargs)

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        loss = F.log(F.cosh(pred - label))
        loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
        return F.mean(loss, axis=self._batch_axis, exclude=True)

class ${tc.fileNameWithoutEnding}:
    def applyBeamSearch(self, input, length, width, maxLength, currProb, netIndex, bestOutput):
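        """Recursively expand a beam of the `width` most probable tokens per
        batch entry, feeding each candidate back through network `netIndex`,
        and return the output with the highest joint probability."""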
        bestProb = 0.0
        while length < maxLength:
            length += 1
            batchIndex = 0
            for batchEntry in input:
                top_k_indices = mx.nd.topk(batchEntry, axis=0, k=width)
                top_k_values = mx.nd.topk(batchEntry, ret_typ='value', axis=0, k=width)
                for index in range(top_k_indices.size):

                    # print(mx.nd.array(top_k_indices[index]))
                    # print(top_k_values[index])
                    if length == 1:
                        # print(mx.nd.array(top_k_indices[index]))
                        result = self.applyBeamSearch(self._networks[netIndex](mx.nd.array(top_k_indices[index])), length, width, maxLength,
                            currProb * top_k_values[index], netIndex, self._networks[netIndex](mx.nd.array(top_k_indices[index])))
                    else:
                        result = self.applyBeamSearch(self._networks[netIndex](mx.nd.array(top_k_indices[index])), length, width, maxLength,
                            currProb * top_k_values[index], netIndex, bestOutput)

                    if length == maxLength:
                        # print(currProb)
                        if currProb > bestProb:
                            bestProb = currProb
                            bestOutput[batchIndex] = result[batchIndex]
                            # print("new bestOutput: ", bestOutput)

                batchIndex += 1
        # print(bestOutput)
        # print(bestProb)
        return bestOutput


    def __init__(self, data_loader, net_constructor):
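        """Keep references to the generated data loader and network creator;
        the networks themselves are built in train()."""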
        self._data_loader = data_loader
        self._net_creator = net_constructor
        self._networks = {}

    def train(self, batch_size=64,
              num_epoch=10,
              eval_metric='acc',
              loss='softmax_cross_entropy',
              loss_params={},
              optimizer='adam',
              optimizer_params={'learning_rate': 0.001},
              load_checkpoint=True,
              context='gpu',
              checkpoint_period=5,
              normalize=True):
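        """Train all generated networks with the configured loss and optimizer,
        log train and test scores for `eval_metric` after every epoch, and
        save parameter checkpoints every `checkpoint_period` epochs."""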
        if context == 'gpu':
            mx_context = mx.gpu()
        elif context == 'cpu':
            mx_context = mx.cpu()
        else:
            logging.error("Context argument is '" + context + "'. Only 'cpu' and 'gpu are valid arguments'.")

        if 'weight_decay' in optimizer_params:
            optimizer_params['wd'] = optimizer_params['weight_decay']
            del optimizer_params['weight_decay']
        if 'learning_rate_decay' in optimizer_params:
            min_learning_rate = 1e-08
            if 'learning_rate_minimum' in optimizer_params:
                min_learning_rate = optimizer_params['learning_rate_minimum']
                del optimizer_params['learning_rate_minimum']
            optimizer_params['lr_scheduler'] = mx.lr_scheduler.FactorScheduler(
                                                   optimizer_params['step_size'],
                                                   factor=optimizer_params['learning_rate_decay'],
                                                   stop_factor_lr=min_learning_rate)
            del optimizer_params['step_size']
            del optimizer_params['learning_rate_decay']


        train_iter, test_iter, data_mean, data_std = self._data_loader.load_data(batch_size)
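        # data_mean and data_std are dataset statistics computed by the data
        # loader; with normalize=True they are passed to the network creator,
        # which presumably uses them to normalize the network input.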

        if normalize:
            self._net_creator.construct(context=mx_context, data_mean=data_mean, data_std=data_std)
        else:
            self._net_creator.construct(context=mx_context)

        begin_epoch = 0
        if load_checkpoint:
            begin_epoch = self._net_creator.load(mx_context)
        else:
            if os.path.isdir(self._net_creator._model_dir_):
                shutil.rmtree(self._net_creator._model_dir_)

        self._networks = self._net_creator.networks

        try:
            os.makedirs(self._net_creator._model_dir_)
        except OSError:
            if not os.path.isdir(self._net_creator._model_dir_):
                raise

        trainers = [mx.gluon.Trainer(network.collect_params(), optimizer, optimizer_params) for network in self._networks.values()]
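        # One gluon Trainer per generated network; all share the same optimizer
        # and optimizer parameters.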

        margin = loss_params['margin'] if 'margin' in loss_params else 1.0
        sparseLabel = loss_params['sparse_label'] if 'sparse_label' in loss_params else True
        if loss == 'softmax_cross_entropy':
            fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
            loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(from_logits=fromLogits, sparse_label=sparseLabel)
        elif loss == 'sigmoid_binary_cross_entropy':
            loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
        elif loss == 'cross_entropy':
            loss_function = CrossEntropyLoss(sparse_label=sparseLabel)
        elif loss == 'l2':
            loss_function = mx.gluon.loss.L2Loss()
        elif loss == 'l1':
            loss_function = mx.gluon.loss.L1Loss()
        elif loss == 'huber':
            rho = loss_params['rho'] if 'rho' in loss_params else 1
            loss_function = mx.gluon.loss.HuberLoss(rho=rho)
        elif loss == 'hinge':
            loss_function = mx.gluon.loss.HingeLoss(margin=margin)
        elif loss == 'squared_hinge':
            loss_function = mx.gluon.loss.SquaredHingeLoss(margin=margin)
        elif loss == 'logistic':
            labelFormat = loss_params['label_format'] if 'label_format' in loss_params else 'signed'
            loss_function = mx.gluon.loss.LogisticLoss(label_format=labelFormat)
        elif loss == 'kullback_leibler':
            fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else True
            loss_function = mx.gluon.loss.KLDivLoss(from_logits=fromLogits)
        elif loss == 'log_cosh':
            loss_function = LogCoshLoss()
        else:
            logging.error("Invalid loss parameter.")

        speed_period = 50
        tic = None

        for epoch in range(begin_epoch, begin_epoch + num_epoch):
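            # One epoch: forward/backward and an optimizer step per batch,
            # followed by a metric pass over the train and test splits below.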
            train_iter.reset()
            for batch_i, batch in enumerate(train_iter):
                <#list tc.architectureInputs as input_name>
                ${input_name} = batch.data[${input_name?index}].as_in_context(mx_context)
                </#list>
                <#list tc.architectureOutputs as output_name>
                ${output_name}label = batch.label[${output_name?index}].as_in_context(mx_context)
                </#list>

                with autograd.record():
<#include "pythonExecuteWithLoss.ftl">

                    loss = 0
                    for element in lossList:
                        loss = loss + element

                loss.backward()

                for trainer in trainers:
                    trainer.step(batch_size)

                if tic is None:
                    tic = time.time()
                else:
                    if batch_i % speed_period == 0:
                        try:
                            speed = speed_period * batch_size / (time.time() - tic)
                        except ZeroDivisionError:
                            speed = float("inf")

                        logging.info("Epoch[%d] Batch[%d] Speed: %.2f samples/sec" % (epoch, batch_i, speed))

                        tic = time.time()

            tic = None

            train_iter.reset()
            metric = mx.metric.create(eval_metric)
            for batch_i, batch in enumerate(train_iter):
                <#list tc.architectureInputs as input_name>
                ${input_name} = batch.data[${input_name?index}].as_in_context(mx_context)
                </#list>

                labels = [
<#list tc.architectureOutputs as output_name>
                    batch.label[${output_name?index}].as_in_context(mx_context)<#sep>,
</#list>

                ]

                outputs = []

                if True: <#-- Fix indentation -->
<#include "pythonExecute.ftl">

                predictions = []
                for output_name in outputs:
                    if mx.nd.shape_array(output_name).size > 1:
                        predictions.append(mx.nd.argmax(output_name, axis=1))
                    else:
                        # argmax has already been applied
                        predictions.append(output_name)

                # print([word[0] for word in predictions])
                # print(labels[0])

                metric.update(preds=predictions, labels=labels)
            train_metric_score = metric.get()[1]

            test_iter.reset()
            metric = mx.metric.create(eval_metric)
            for batch_i, batch in enumerate(test_iter):
                <#list tc.architectureInputs as input_name>
                ${input_name} = batch.data[${input_name?index}].as_in_context(mx_context)
                </#list>

                labels = [
<#list tc.architectureOutputs as output_name>
                    batch.label[${output_name?index}].as_in_context(mx_context)<#sep>,
</#list>

                ]

                outputs = []

                if True: <#-- Fix indentation -->
<#include "pythonExecute.ftl">

                predictions = []
                for output_name in outputs:
                    if mx.nd.shape_array(output_name).size > 1:
                        predictions.append(mx.nd.argmax(output_name, axis=1))
                    else:
                        # argmax has already been applied
                        predictions.append(output_name)

                metric.update(preds=predictions, labels=labels)
            test_metric_score = metric.get()[1]

            logging.info("Epoch[%d] Train: %f, Test: %f" % (epoch, train_metric_score, test_metric_score))


            if (epoch - begin_epoch) % checkpoint_period == 0:
                for i, network in self._networks.items():
                    network.save_parameters(self.parameter_path(i) + '-' + str(epoch).zfill(4) + '.params')

        for i, network in self._networks.items():
            network.save_parameters(self.parameter_path(i) + '-' + str(num_epoch + begin_epoch).zfill(4) + '.params')
            network.export(self.parameter_path(i) + '_newest', epoch=0)

    def parameter_path(self, index):
        return self._net_creator._model_dir_ + self._net_creator._model_prefix_ + '_' + str(index)
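
<#-- Usage sketch (not part of the generated code): the concrete class names
     below are hypothetical, since the real names are filled in from the
     template variables at generation time.

     data_loader = CNNDataLoader_myNet()        # assumed generated data loader
     net_constructor = CNNCreator_myNet()       # assumed generated net creator
     trainer = CNNSupervisedTrainer_myNet(data_loader, net_constructor)
     trainer.train(batch_size=64,
                   num_epoch=10,
                   context='cpu',
                   loss='softmax_cross_entropy',
                   optimizer='adam',
                   optimizer_params={'learning_rate': 0.001})
-->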