finetune_pretrained_alexnet.py 4.12 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116
import mxnet as mx

import custom_functions

# Device context for training (CPU here).
CONTEXT = mx.cpu()
# Checkpoint prefix of the pretrained CaffeNet model to load (epoch 0).
MODEL_PATH = "model/alexnet_pretrained/caffenet"
# Checkpoint prefix under which finetuned snapshots are written.
MODEL_PATH_FINETUNED = "model/alexnet_finetuned/caffenet"
# Root directory of the raw TORCS recordIO data.
DATA_DIR = "/media/sveta/4991e634-dd81-4cb9-bf46-2fa9c7159263/TORCS_raw/"
# Name suffix appended to the finetuned checkpoint prefix.
MODEL_PREFIX = "dpnet"

batch_size = 64   # samples per mini-batch (also used for the bind shape below)
num_epoch = 1     # number of epochs to train in this run
begin_epoch = 0   # epoch counter to start from (0 = fresh finetuning run)

# Load the pretrained CaffeNet checkpoint (epoch 0): symbol graph plus
# trained argument and auxiliary parameters.
symbol, arg_params, aux_params = mx.model.load_checkpoint(MODEL_PATH, 0)

# Truncate the network after the flatten layer; the fully-connected head
# is rebuilt from scratch below for the regression task.
last_layer_to_stay = "flatten_0"
all_layers = symbol.get_internals()
net = all_layers[last_layer_to_stay + '_output']
# Keep only the convolutional-trunk weights: drop every 'fc*' parameter so
# the new head is freshly initialized (was dict({...}) — a redundant
# second dict construction around a dict comprehension).
new_args = {k: v for k, v in arg_params.items() if 'fc' not in k}

def _fc_relu_dropout(data, fc_name, num_hidden, relu_name, drop_name):
    """One head block: FullyConnected (with bias) -> ReLU -> Dropout(p=0.5)."""
    fc = mx.symbol.FullyConnected(data=data,
                                  num_hidden=num_hidden,
                                  no_bias=False,
                                  name=fc_name)
    act = mx.symbol.Activation(data=fc, act_type='relu', name=relu_name)
    return mx.symbol.Dropout(data=act, p=0.5, name=drop_name)


# Rebuild the fully-connected head on top of the pretrained trunk.
# Layer names match the original graph exactly (checkpoint compatibility).
head = net
for fc_name, hidden, relu_name, drop_name in (
        ("fc5_", 4096, "relu6_", "dropout6_"),
        ("fc6_", 4096, "relu7_", "dropout7_"),
        ("fc7_", 256, "relu8_", "dropout8_")):
    head = _fc_relu_dropout(head, fc_name, hidden, relu_name, drop_name)

# Final linear layer: 14 regression outputs, no bias, no activation/dropout.
fc8_ = mx.symbol.FullyConnected(data=head,
                                num_hidden=14,
                                no_bias=True,
                                name="fc8_")

# L2 regression loss over the 14 outputs; label defaults to
# 'predictions_label', matching the Module's label_names below.
predictions = mx.symbol.LinearRegressionOutput(data=fc8_,
                                               name="predictions")


optimizer = 'sgd'
# Raw hyper-parameters; the non-MXNet keys ('learning_rate_decay',
# 'step_size', optional 'weight_decay'/'learning_rate_minimum') are
# translated below into the names the MXNet optimizer understands.
optimizer_params = {
    'learning_rate': 0.01,
    'learning_rate_decay': 0.9,
    'step_size': 8000}

# 'weight_decay' -> MXNet's 'wd' (not set above; branch kept so the config
# dict can carry the friendlier key).
if 'weight_decay' in optimizer_params:
    optimizer_params['wd'] = optimizer_params.pop('weight_decay')
# 'learning_rate_decay' + 'step_size' -> a FactorScheduler that multiplies
# the learning rate by the decay factor every `step_size` updates, floored
# at min_learning_rate.
if 'learning_rate_decay' in optimizer_params:
    # Optional floor for the decayed learning rate; defaults to 1e-8.
    min_learning_rate = optimizer_params.pop('learning_rate_minimum', 1e-08)
    optimizer_params['lr_scheduler'] = mx.lr_scheduler.FactorScheduler(
        optimizer_params.pop('step_size'),
        factor=optimizer_params.pop('learning_rate_decay'),
        stop_factor_lr=min_learning_rate)

# Build train/test iterators over the TORCS recordIO data; the helper also
# returns per-channel mean/std (unused in this script).
train_iter, test_iter, data_mean, data_std = custom_functions.load_data_rec(DATA_DIR, batch_size)

module = mx.mod.Module(symbol=predictions,
                       data_names=['data'],
                       label_names=['predictions_label'],
                       context=CONTEXT)
# Bind with the configured batch size (was hard-coded to 64, silently
# diverging from batch_size if that constant changed); inputs are
# 3-channel 210x280 images.
module.bind(data_shapes=[('data', (batch_size, 3, 210, 280))], force_rebind=True)
# allow_missing=True: the freshly defined 'fc*' head has no pretrained
# weights in new_args and will be initialized by fit().
module.set_params(arg_params=new_args,
                  aux_params=aux_params,
                  allow_missing=True)

# All checkpoints share one prefix.
# NOTE(review): MODEL_PATH_FINETUNED already ends in "caffenet", so this
# produces ".../caffenetdpnet-*" with no separator — confirm intended.
checkpoint_prefix = MODEL_PATH_FINETUNED + MODEL_PREFIX

# Train: MSE metric for the regression head; a checkpoint is written after
# every epoch, with progress logged per batch.
module.fit(
    train_data=train_iter,
    eval_data=test_iter,
    eval_metric='mse',
    optimizer=optimizer,
    optimizer_params=optimizer_params,
    batch_end_callback=mx.callback.Speedometer(batch_size),
    epoch_end_callback=mx.callback.do_checkpoint(prefix=checkpoint_prefix, period=1),
    begin_epoch=begin_epoch,
    num_epoch=num_epoch + begin_epoch,
    allow_missing=True,
    force_rebind=True,
    force_init=True)

# Final snapshot at the last epoch, plus a fixed "newest" alias always
# saved as epoch 0 so consumers can load it without knowing the epoch count.
module.save_checkpoint(checkpoint_prefix, num_epoch + begin_epoch)
module.save_checkpoint(checkpoint_prefix + '_newest', 0)