Added CNNTrain parameter 'context'

parent 1cc5893e
Pipeline #52303 failed with stage
in 10 seconds
......@@ -16,7 +16,7 @@
<!-- .. SE-Libraries .................................................. -->
<emadl.version>0.2.1-SNAPSHOT</emadl.version>
<CNNTrain.version>0.2.0-SNAPSHOT</CNNTrain.version>
<CNNTrain.version>0.2.1-SNAPSHOT</CNNTrain.version>
<cnnarch-generator.version>0.2.1-SNAPSHOT</cnnarch-generator.version>
<embedded-montiarc-math-generator>0.0.9-SNAPSHOT</embedded-montiarc-math-generator>
......
......@@ -2,10 +2,12 @@ configuration CifarNetwork{
num_epoch:10
batch_size:64
normalize:true
context:gpu
load_checkpoint:false
optimizer:adam{
learning_rate:0.01
learning_rate_decay:0.8
step_size:1000
weight_decay:0.0001
}
}
......@@ -110,9 +110,15 @@ class CNNCreator_cifar10_cifar10Classifier_net:
optimizer='adam',
optimizer_params=(('learning_rate', 0.001),),
load_checkpoint=True,
context=mx.gpu(),
context='gpu',
checkpoint_period=5,
normalize=True):
if context == 'gpu':
mx_context = mx.gpu()
elif context == 'cpu':
mx_context = mx.cpu()
else:
logging.error("Context argument is '" + context + "'. Only 'cpu' and 'gpu are valid arguments'.")
if 'weight_decay' in optimizer_params:
optimizer_params['wd'] = optimizer_params['weight_decay']
......@@ -133,13 +139,13 @@ class CNNCreator_cifar10_cifar10Classifier_net:
train_iter, test_iter, data_mean, data_std = self.load_data(batch_size)
if self.module == None:
if normalize:
self.construct(context, data_mean, data_std)
self.construct(mx_context, data_mean, data_std)
else:
self.construct(context)
self.construct(mx_context)
begin_epoch = 0
if load_checkpoint:
begin_epoch = self.load(context)
begin_epoch = self.load(mx_context)
else:
if os.path.isdir(self._model_dir_):
shutil.rmtree(self._model_dir_)
......
......@@ -13,9 +13,11 @@ if __name__ == "__main__":
batch_size = 64,
num_epoch = 10,
load_checkpoint = False,
context = 'gpu',
normalize = True,
optimizer = 'adam',
optimizer_params = {
'weight_decay': 1.0E-4,
'learning_rate': 0.01,
'learning_rate_decay': 0.8,
'step_size': 1000}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment