finetune_pretrained_alexnet.py 4.12 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import mxnet as mx

import custom_functions

# Training device; switch to mx.gpu() to finetune on a GPU.
CONTEXT = mx.cpu()
# Checkpoint prefix of the pretrained CaffeNet model (epoch 0 is loaded below).
MODEL_PATH = "model/alexnet_pretrained/caffenet"
# Checkpoint prefix under which the finetuned model is written.
MODEL_PATH_FINETUNED = "model/alexnet_finetuned/caffenet"
# Root directory of the raw TORCS data consumed by custom_functions.load_data_rec.
DATA_DIR = "/media/sveta/4991e634-dd81-4cb9-bf46-2fa9c7159263/TORCS_raw/"
# Suffix appended to MODEL_PATH_FINETUNED when saving checkpoints.
# NOTE(review): there is no separator between the two, producing the prefix
# ".../caffenetdpnet" — looks intentional, but confirm.
MODEL_PREFIX = "dpnet"

batch_size = 64  # samples per mini-batch
num_epoch = 1    # number of epochs to run in this session
begin_epoch = 0  # epoch index to resume training from

# Load the pretrained CaffeNet checkpoint (epoch 0) and truncate the graph
# right after the flatten layer so a fresh regression head can be attached.
symbol, arg_params, aux_params = mx.model.load_checkpoint(MODEL_PATH, 0)

last_layer_to_stay = "flatten_0"
all_layers = symbol.get_internals()
net = all_layers[last_layer_to_stay + '_output']
# Keep only the convolutional weights: the pretrained fully-connected ('fc')
# layers are replaced below, so their parameters must not be carried over.
# (Plain dict comprehension over .items() — the original wrapped it in a
# redundant dict() call and re-looked-up each value by key.)
new_args = {k: v for k, v in arg_params.items() if 'fc' not in k}

# Regression head replacing CaffeNet's original classifier: three
# FullyConnected -> ReLU -> Dropout stages, then a 14-unit linear output
# trained with an L2 (linear regression) loss.
_HEAD_STAGES = [
    # (fc layer name, width, relu layer name, dropout layer name)
    ("fc5_", 4096, "relu6_", "dropout6_"),
    ("fc6_", 4096, "relu7_", "dropout7_"),
    ("fc7_", 256, "relu8_", "dropout8_"),
]
_head = net
for _fc_name, _width, _relu_name, _drop_name in _HEAD_STAGES:
    _head = mx.symbol.FullyConnected(data=_head,
                                     num_hidden=_width,
                                     no_bias=False,
                                     name=_fc_name)
    _head = mx.symbol.Activation(data=_head,
                                 act_type='relu',
                                 name=_relu_name)
    _head = mx.symbol.Dropout(data=_head,
                              p=0.5,
                              name=_drop_name)

# Final projection to the 14 regression targets (no bias, matching the
# original architecture), followed by the regression loss output.
fc8_ = mx.symbol.FullyConnected(data=_head,
                                num_hidden=14,
                                no_bias=True,
                                name="fc8_")

predictions = mx.symbol.LinearRegressionOutput(data=fc8_,
                                               name="predictions")


optimizer = 'sgd'
optimizer_params = {
    'learning_rate': 0.01,
    'learning_rate_decay': 0.9,
    'step_size': 8000}

# Translate the convenience keys above into the names MXNet's optimizer
# actually accepts, consuming them with pop() as we go.
if 'weight_decay' in optimizer_params:
    optimizer_params['wd'] = optimizer_params.pop('weight_decay')
if 'learning_rate_decay' in optimizer_params:
    # Optional floor for the decayed learning rate; defaults to 1e-08.
    min_learning_rate = optimizer_params.pop('learning_rate_minimum', 1e-08)
    # Decay the LR by 'learning_rate_decay' every 'step_size' updates,
    # never dropping below the floor.
    optimizer_params['lr_scheduler'] = mx.lr_scheduler.FactorScheduler(
        optimizer_params.pop('step_size'),
        factor=optimizer_params.pop('learning_rate_decay'),
        stop_factor_lr=min_learning_rate)

# Build train/test iterators over the TORCS RecordIO data; data_mean/data_std
# are returned by the helper but not used here — presumably normalization is
# already applied inside the iterators (verify against custom_functions).
train_iter, test_iter, data_mean, data_std = custom_functions.load_data_rec(DATA_DIR, batch_size)

module = mx.mod.Module(symbol=predictions,
                       data_names=['data'],
                       label_names=['predictions_label'],
                       context=CONTEXT)
# Bind with the configured batch size. The original hard-coded 64 here, which
# silently desynchronized from batch_size whenever that constant changed.
module.bind(data_shapes=[('data', (batch_size, 3, 210, 280))], force_rebind=True)
# Seed the conv layers from the pretrained weights; allow_missing lets the
# freshly added fc5_..fc8_ head start from random initialization.
module.set_params(arg_params=new_args,
                  aux_params=aux_params,
                  allow_missing=True)

module.fit(
    allow_missing=True,
    force_rebind=True,
    force_init=True,
    train_data=train_iter,
    eval_data=test_iter,
    eval_metric='mse',
    optimizer=optimizer,
    optimizer_params=optimizer_params,
    batch_end_callback=mx.callback.Speedometer(batch_size),
    epoch_end_callback=mx.callback.do_checkpoint(prefix=MODEL_PATH_FINETUNED + MODEL_PREFIX, period=1),
    begin_epoch=begin_epoch,
    num_epoch=num_epoch + begin_epoch)
# Save the final state twice: under its true epoch number, and under a stable
# "_newest" alias (epoch 0) so consumers can load the latest model by name.
module.save_checkpoint(MODEL_PATH_FINETUNED + MODEL_PREFIX, num_epoch + begin_epoch)
module.save_checkpoint(MODEL_PATH_FINETUNED + MODEL_PREFIX + '_newest', 0)