# CNNCreator_Alexnet.py
from caffe2.python import workspace, core, model_helper, brew, optimizer
from caffe2.python.predictor import mobile_exporter
from caffe2.proto import caffe2_pb2
import numpy as np
import math
import logging
import os
import sys
import lmdb

class CNNCreator_Alexnet:
    """Build, train, evaluate and export a two-branch AlexNet with Caffe2.

    Training and test data are read from the LMDB databases
    ``<_data_dir_>/train_lmdb`` and ``<_data_dir_>/test_lmdb``. The deploy
    network is written to ``_init_net_`` / ``_predict_net_`` (binary) plus
    ``.pbtxt`` text copies next to them.
    """

    module = None

    _current_dir_ = os.path.join('./')
    _data_dir_    = os.path.join(_current_dir_, 'data', 'Alexnet')
    _model_dir_   = os.path.join(_current_dir_, 'model', 'Alexnet')

    _init_net_    = os.path.join(_model_dir_, 'init_net.pb')
    _predict_net_ = os.path.join(_model_dir_, 'predict_net.pb')

    # Values accepted by add_training_operators for `opt_type` and `policy`.
    _optimizers_  = ('adam', 'sgd', 'rmsprop', 'adagrad')
    _lr_policies_ = ('step', 'fixed', 'inv')

    def get_total_num_iter(self, num_epoch, batch_size, dataset_size):
        """Return the number of batches needed for `num_epoch` passes over the data.

        A trailing partial batch counts as a full iteration (the result is
        rounded up).
        """
        # Force floating point division so Python 2 does not truncate.
        iterations = math.ceil(num_epoch * (float(dataset_size) / float(batch_size)))
        return int(iterations)

    def add_input(self, model, batch_size, db, db_type, device_opts):
        """Add an LMDB reader to `model` and return (data, label, dataset_size).

        The uint8 images are cast to float, multiplied by 1/256 and wrapped in
        StopGradient. `dataset_size` is the entry count reported by LMDB.
        The process exits if `db` is not a directory containing
        ``data.mdb`` and ``lock.mdb``.
        """
        with core.DeviceScope(device_opts):
            if not os.path.isdir(db):
                logging.error("Data loading failure. Directory '" + os.path.abspath(db) + "' does not exist.")
                sys.exit(1)
            elif not (os.path.isfile(os.path.join(db, 'data.mdb')) and os.path.isfile(os.path.join(db, 'lock.mdb'))):
                logging.error("Data loading failure. Directory '" + os.path.abspath(db) + "' does not contain lmdb files.")
                sys.exit(1)

            # load the data
            data_uint8, label = brew.db_input(
                model,
                blobs_out=["data_uint8", "label"],
                batch_size=batch_size,
                db=db,
                db_type=db_type,
            )
            # cast the data to float
            data = model.Cast(data_uint8, "data", to=core.DataType.FLOAT)

            # scale data from [0,255] down to [0, 255/256]
            data = model.Scale(data, data, scale=float(1./256))

            # don't need the gradient for the backward pass
            data = model.StopGradient(data, data)

            # This handle is only used to count the entries; close it instead
            # of leaking the environment.
            env = lmdb.open(db, readonly=True)
            try:
                dataset_size = int(env.stat()['entries'])
            finally:
                env.close()

            return data, label, dataset_size

    def create_model(self, model, data, device_opts, is_test):
        """Add the AlexNet forward pass to `model` and return the softmax blob.

        The layout follows the two-group AlexNet: after the first block the
        channels are split into two halves that run through parallel branches
        and are concatenated again; this happens twice.

        Args:
            model: ModelHelper the operators are added to.
            data: input blob (reference or name). The shape comments below
                assume per-sample NCHW input of [3,224,224].
            device_opts: DeviceOption the operators are placed on.
            is_test: forwarded to both dropout layers.

        Returns:
            The ``predictions`` blob (softmax over 10 outputs).
        """
        with core.DeviceScope(device_opts):
            # data, output shape: {[3,224,224]}
            # pad=2 yields the documented 55x55 (no padding gives 54x54, which
            # ends in 5x5 maps instead of the 6x6 that fc6_ expects).
            conv1_ = brew.conv(model, data, 'conv1_', dim_in=3, dim_out=96, kernel=11, stride=4, pad=2)
            # conv1_, output shape: {[96,55,55]}
            lrn1_ = brew.lrn(model, conv1_, 'lrn1_', size=5, alpha=0.0001, beta=0.75, bias=2.0)
            pool1_ = brew.max_pool(model, lrn1_, 'pool1_', kernel=3, stride=2)
            # pool1_, output shape: {[96,27,27]}
            relu1_ = brew.relu(model, pool1_, pool1_)

            # Split the 96 channels (axis 1 in NCHW) into two groups of 48.
            get2_1_, get2_2_ = model.net.Split(relu1_, ['get2_1_', 'get2_2_'], axis=1)

            # Branch 1. pad=2 keeps 27x27 for the 5x5 kernel.
            conv2_1_ = brew.conv(model, get2_1_, 'conv2_1_', dim_in=48, dim_out=128, kernel=5, stride=1, pad=2)
            # conv2_1_, output shape: {[128,27,27]}
            lrn2_1_ = brew.lrn(model, conv2_1_, 'lrn2_1_', size=5, alpha=0.0001, beta=0.75, bias=2.0)
            pool2_1_ = brew.max_pool(model, lrn2_1_, 'pool2_1_', kernel=3, stride=2)
            # pool2_1_, output shape: {[128,13,13]}
            relu2_1_ = brew.relu(model, pool2_1_, pool2_1_)

            # Branch 2, same structure as branch 1.
            conv2_2_ = brew.conv(model, get2_2_, 'conv2_2_', dim_in=48, dim_out=128, kernel=5, stride=1, pad=2)
            # conv2_2_, output shape: {[128,27,27]}
            lrn2_2_ = brew.lrn(model, conv2_2_, 'lrn2_2_', size=5, alpha=0.0001, beta=0.75, bias=2.0)
            pool2_2_ = brew.max_pool(model, lrn2_2_, 'pool2_2_', kernel=3, stride=2)
            # pool2_2_, output shape: {[128,13,13]}
            relu2_2_ = brew.relu(model, pool2_2_, pool2_2_)

            # Merge both branches back along the channel axis.
            concatenate3_ = brew.concat(model, [relu2_1_, relu2_2_], 'concatenate3_', axis=1)
            # concatenate3_, output shape: {[256,13,13]}
            conv3_ = brew.conv(model, concatenate3_, 'conv3_', dim_in=256, dim_out=384, kernel=3, stride=1, pad=1)
            # conv3_, output shape: {[384,13,13]}
            relu3_ = brew.relu(model, conv3_, conv3_)

            # Split the 384 channels into two groups of 192.
            get4_1_, get4_2_ = model.net.Split(relu3_, ['get4_1_', 'get4_2_'], axis=1)

            conv4_1_ = brew.conv(model, get4_1_, 'conv4_1_', dim_in=192, dim_out=192, kernel=3, stride=1, pad=1)
            # conv4_1_, output shape: {[192,13,13]}
            relu4_1_ = brew.relu(model, conv4_1_, conv4_1_)
            conv5_1_ = brew.conv(model, relu4_1_, 'conv5_1_', dim_in=192, dim_out=128, kernel=3, stride=1, pad=1)
            # conv5_1_, output shape: {[128,13,13]}
            pool5_1_ = brew.max_pool(model, conv5_1_, 'pool5_1_', kernel=3, stride=2)
            # pool5_1_, output shape: {[128,6,6]}
            relu5_1_ = brew.relu(model, pool5_1_, pool5_1_)

            conv4_2_ = brew.conv(model, get4_2_, 'conv4_2_', dim_in=192, dim_out=192, kernel=3, stride=1, pad=1)
            # conv4_2_, output shape: {[192,13,13]}
            relu4_2_ = brew.relu(model, conv4_2_, conv4_2_)
            conv5_2_ = brew.conv(model, relu4_2_, 'conv5_2_', dim_in=192, dim_out=128, kernel=3, stride=1, pad=1)
            # conv5_2_, output shape: {[128,13,13]}
            pool5_2_ = brew.max_pool(model, conv5_2_, 'pool5_2_', kernel=3, stride=2)
            # pool5_2_, output shape: {[128,6,6]}
            relu5_2_ = brew.relu(model, pool5_2_, pool5_2_)

            concatenate6_ = brew.concat(model, [relu5_1_, relu5_2_], 'concatenate6_', axis=1)
            # concatenate6_, output shape: {[256,6,6]}
            fc6_ = brew.fc(model, concatenate6_, 'fc6_', dim_in=256 * 6 * 6, dim_out=4096)
            # fc6_, output shape: {[4096,1,1]}
            relu6_ = brew.relu(model, fc6_, fc6_)
            # Honour is_test so the test/deploy nets do not drop activations.
            dropout6_ = brew.dropout(model, relu6_, 'dropout6_', ratio=0.5, is_test=is_test)

            fc7_ = brew.fc(model, dropout6_, 'fc7_', dim_in=4096, dim_out=4096)
            # fc7_, output shape: {[4096,1,1]}
            relu7_ = brew.relu(model, fc7_, fc7_)
            dropout7_ = brew.dropout(model, relu7_, 'dropout7_', ratio=0.5, is_test=is_test)

            fc8_ = brew.fc(model, dropout7_, 'fc8_', dim_in=4096, dim_out=10)
            # fc8_, output shape: {[10,1,1]}
            predictions = brew.softmax(model, fc8_, 'predictions')

            return predictions

    # this adds the loss and optimizer
    def add_training_operators(self, model, output, label, device_opts, opt_type, base_learning_rate, policy, stepsize, epsilon, beta1, beta2, gamma, momentum, weight_decay=None):
        """Add cross-entropy loss, gradients, optional weight decay and an optimizer.

        Args:
            opt_type: one of 'adam', 'sgd', 'rmsprop', 'adagrad'.
            policy: learning-rate policy, one of 'step', 'fixed', 'inv';
                `stepsize` is only passed for 'step'.
            gamma: learning-rate `gamma` for sgd (and for adam with 'step');
                passed as `decay` to rmsprop and adagrad.
            weight_decay: if not None, passed to brew.add_weight_decay before
                the optimizer is built.

        An unsupported `opt_type` or `policy` is logged and the process exits
        before `model` is modified.
        """
        if opt_type not in self._optimizers_:
            logging.error("Unsupported optimizer '" + str(opt_type) + "'. Expected one of: " + ", ".join(self._optimizers_) + ".")
            sys.exit(1)
        if policy not in self._lr_policies_:
            logging.error("Unsupported learning rate policy '" + str(policy) + "'. Expected one of: " + ", ".join(self._lr_policies_) + ".")
            sys.exit(1)

        with core.DeviceScope(device_opts):
            xent = model.LabelCrossEntropy([output, label], 'xent')
            loss = model.AveragedLoss(xent, "loss")

            model.AddGradientOperators([loss])

            # Weight decay rewrites the gradient blobs in place, so it must be
            # added before the optimizer ops that read those gradients.
            if weight_decay is not None:
                brew.add_weight_decay(model, weight_decay)

            lr_kwargs = {'base_learning_rate': base_learning_rate, 'policy': policy}
            if policy == 'step':
                lr_kwargs['stepsize'] = stepsize

            if opt_type == 'adam':
                if policy == 'step':
                    # The step policy needs a decay factor, as in the sgd branch.
                    lr_kwargs['gamma'] = gamma
                optimizer.build_adam(model, beta1=beta1, beta2=beta2, epsilon=epsilon, **lr_kwargs)
            elif opt_type == 'sgd':
                optimizer.build_sgd(model, gamma=gamma, momentum=momentum, **lr_kwargs)
            elif opt_type == 'rmsprop':
                # NOTE(review): gamma is used as rmsprop `decay`; with the 'step'
                # policy no learning-rate gamma is passed - confirm intent.
                optimizer.build_rms_prop(model, decay=gamma, momentum=momentum, epsilon=epsilon, **lr_kwargs)
            else:  # 'adagrad'
                # NOTE(review): same gamma/decay caveat as rmsprop.
                optimizer.build_adagrad(model, decay=gamma, epsilon=epsilon, **lr_kwargs)
            print(opt_type + " optimizer selected")

    def add_accuracy(self, model, output, label, device_opts, eval_metric):
        """Add an ``accuracy`` blob comparing `output` with `label` and return it.

        `eval_metric` is 'accuracy' (Accuracy op defaults) or
        'top_k_accuracy' (top_k=3). Any other value is logged and the
        process exits.
        """
        if eval_metric == 'accuracy':
            accuracy_kwargs = {}
        elif eval_metric == 'top_k_accuracy':
            accuracy_kwargs = {'top_k': 3}
        else:
            logging.error("Unsupported evaluation metric '" + str(eval_metric) + "'. Expected 'accuracy' or 'top_k_accuracy'.")
            sys.exit(1)

        with core.DeviceScope(device_opts):
            return brew.accuracy(model, [output, label], "accuracy", **accuracy_kwargs)

    def train(self, num_epoch=1000, batch_size=64, context='gpu', eval_metric='accuracy', opt_type='adam', base_learning_rate=0.001, weight_decay=0.001, policy='fixed', stepsize=1, epsilon=1E-8, beta1=0.9, beta2=0.999, gamma=0.999, momentum=0.9):
        """Train on the train LMDB, report test accuracy and save a deploy model.

        `context` selects 'cpu' or 'gpu' (CUDA device 0); any other value is
        logged and the process exits. The remaining arguments are forwarded to
        add_training_operators / add_accuracy.
        """
        if context == 'cpu':
            device_opts = core.DeviceOption(caffe2_pb2.CPU, 0)
            print("CPU mode selected")
        elif context == 'gpu':
            device_opts = core.DeviceOption(caffe2_pb2.CUDA, 0)
            print("GPU mode selected")
        else:
            logging.error("Unsupported context '" + str(context) + "'. Expected 'cpu' or 'gpu'.")
            sys.exit(1)

        workspace.ResetWorkspace(self._model_dir_)

        arg_scope = {"order": "NCHW"}
        # == Training model ==
        train_model = model_helper.ModelHelper(name="train_net", arg_scope=arg_scope)
        data, label, train_dataset_size = self.add_input(train_model, batch_size=batch_size, db=os.path.join(self._data_dir_, 'train_lmdb'), db_type='lmdb', device_opts=device_opts)
        predictions = self.create_model(train_model, data, device_opts=device_opts, is_test=False)
        self.add_training_operators(train_model, predictions, label, device_opts, opt_type, base_learning_rate, policy, stepsize, epsilon, beta1, beta2, gamma, momentum, weight_decay=weight_decay)
        self.add_accuracy(train_model, predictions, label, device_opts, eval_metric)

        # Initialize and create the training network
        workspace.RunNetOnce(train_model.param_init_net)
        workspace.CreateNet(train_model.net, overwrite=True)

        # Main Training Loop
        iterations = self.get_total_num_iter(num_epoch, batch_size, train_dataset_size)
        print("** Starting Training for " + str(num_epoch) + " epochs = " + str(iterations) + " iterations **")
        for i in range(iterations):
            workspace.RunNet(train_model.net)
            if i % 50 == 0:
                print('Iter ' + str(i) + ': ' + 'Loss ' + str(workspace.FetchBlob("loss")) + ' - ' + 'Accuracy ' + str(workspace.FetchBlob('accuracy')))
        print("Training done")

        print("== Running Test model ==")
        # == Testing model. ==
        test_model = model_helper.ModelHelper(name="test_net", arg_scope=arg_scope, init_params=False)
        data, label, test_dataset_size = self.add_input(test_model, batch_size=batch_size, db=os.path.join(self._data_dir_, 'test_lmdb'), db_type='lmdb', device_opts=device_opts)
        predictions = self.create_model(test_model, data, device_opts=device_opts, is_test=True)
        self.add_accuracy(test_model, predictions, label, device_opts, eval_metric)
        workspace.RunNetOnce(test_model.param_init_net)
        workspace.CreateNet(test_model.net, overwrite=True)

        # Main Testing Loop. Integer division keeps the batch count an int on
        # Python 3; run at least one batch so the mean below is defined.
        num_test_batches = max(1, test_dataset_size // batch_size)
        test_accuracy = np.zeros(num_test_batches)
        for i in range(num_test_batches):
            # Run a forward pass of the net on the current batch
            workspace.RunNet(test_model.net)
            # Collect the batch accuracy from the workspace
            test_accuracy[i] = workspace.FetchBlob('accuracy')

        print('Test_accuracy: {:.4f}'.format(test_accuracy.mean()))

        # == Deployment model. ==
        # We simply need the main AddModel part.
        deploy_model = model_helper.ModelHelper(name="deploy_net", arg_scope=arg_scope, init_params=False)
        self.create_model(deploy_model, "data", device_opts, is_test=True)

        print("Saving deploy model")
        self.save_net(self._init_net_, self._predict_net_, deploy_model)

    def save_net(self, init_net_path, predict_net_path, model):
        """Export `model` and write its init/predict nets as .pb and .pbtxt.

        ``mobile_exporter.Export`` receives the workspace, ``model.net`` and
        ``model.params`` and returns the init and predict NetDefs. The binary
        files go to the given paths. Each text copy replaces that path's
        extension with ``.pbtxt``. Missing parent directories are created.
        """
        init_net, predict_net = mobile_exporter.Export(
            workspace,
            model.net,
            model.params
        )

        # Create each target directory, tolerating ones that already exist
        # (Python 2 has no makedirs(exist_ok=True)).
        for directory in set([os.path.dirname(init_net_path), os.path.dirname(predict_net_path)]):
            if not directory:
                continue
            try:
                os.makedirs(directory)
            except OSError:
                if not os.path.isdir(directory):
                    raise

        print("Save the model to init_net.pb and predict_net.pb")
        # Write the exporter's predict net so .pb and .pbtxt hold the same NetDef.
        with open(predict_net_path, 'wb') as f:
            f.write(predict_net.SerializeToString())
        with open(init_net_path, 'wb') as f:
            f.write(init_net.SerializeToString())

        print("Save the model to init_net.pbtxt and predict_net.pbtxt")
        # splitext (not str.replace) so a path without '.pb' never overwrites
        # the binary file and directory names containing '.pb' stay intact.
        with open(os.path.splitext(init_net_path)[0] + '.pbtxt', 'w') as f:
            f.write(str(init_net))
        with open(os.path.splitext(predict_net_path)[0] + '.pbtxt', 'w') as f:
            f.write(str(predict_net))
        print("== Saved init_net and predict_net ==")

    def load_net(self, init_net_path, predict_net_path, device_opts):
        """Run the init net and create the predict net in the current workspace.

        Both NetDefs are parsed from binary protobuf files and assigned
        `device_opts` before use. Exits the process if either file is missing.
        """
        if not os.path.isfile(init_net_path):
            logging.error("Network loading failure. File '" + os.path.abspath(init_net_path) + "' does not exist.")
            sys.exit(1)
        elif not os.path.isfile(predict_net_path):
            logging.error("Network loading failure. File '" + os.path.abspath(predict_net_path) + "' does not exist.")
            sys.exit(1)

        init_def = caffe2_pb2.NetDef()
        with open(init_net_path, 'rb') as f:
            init_def.ParseFromString(f.read())
        init_def.device_option.CopyFrom(device_opts)
        workspace.RunNetOnce(init_def.SerializeToString())

        net_def = caffe2_pb2.NetDef()
        with open(predict_net_path, 'rb') as f:
            net_def.ParseFromString(f.read())
        net_def.device_option.CopyFrom(device_opts)
        workspace.CreateNet(net_def.SerializeToString(), overwrite=True)

        print("== Loaded init_net and predict_net ==")