Added CNNTrain parameter 'context'

parent 69ec2e86
Pipeline #52302 failed with stage
in 46 seconds
......@@ -110,9 +110,15 @@ class ${tc.fileNameWithoutEnding}:
optimizer='adam',
optimizer_params=(('learning_rate', 0.001),),
load_checkpoint=True,
context=mx.gpu(),
context='gpu',
checkpoint_period=5,
normalize=True):
if context == 'gpu':
mx_context = mx.gpu()
elif context == 'cpu':
mx_context = mx.cpu()
else:
logging.error("Context argument is '" + context + "'. Only 'cpu' and 'gpu are valid arguments'.")
if 'weight_decay' in optimizer_params:
optimizer_params['wd'] = optimizer_params['weight_decay']
......@@ -133,13 +139,13 @@ class ${tc.fileNameWithoutEnding}:
train_iter, test_iter, data_mean, data_std = self.load_data(batch_size)
if self.module == None:
if normalize:
self.construct(context, data_mean, data_std)
self.construct(mx_context, data_mean, data_std)
else:
self.construct(context)
self.construct(mx_context)
begin_epoch = 0
if load_checkpoint:
begin_epoch = self.load(context)
begin_epoch = self.load(mx_context)
else:
if os.path.isdir(self._model_dir_):
shutil.rmtree(self._model_dir_)
......
......@@ -110,9 +110,15 @@ class CNNCreator_Alexnet:
optimizer='adam',
optimizer_params=(('learning_rate', 0.001),),
load_checkpoint=True,
context=mx.gpu(),
context='gpu',
checkpoint_period=5,
normalize=True):
if context == 'gpu':
mx_context = mx.gpu()
elif context == 'cpu':
mx_context = mx.cpu()
else:
logging.error("Context argument is '" + context + "'. Only 'cpu' and 'gpu are valid arguments'.")
if 'weight_decay' in optimizer_params:
optimizer_params['wd'] = optimizer_params['weight_decay']
......@@ -133,13 +139,13 @@ class CNNCreator_Alexnet:
train_iter, test_iter, data_mean, data_std = self.load_data(batch_size)
if self.module == None:
if normalize:
self.construct(context, data_mean, data_std)
self.construct(mx_context, data_mean, data_std)
else:
self.construct(context)
self.construct(mx_context)
begin_epoch = 0
if load_checkpoint:
begin_epoch = self.load(context)
begin_epoch = self.load(mx_context)
else:
if os.path.isdir(self._model_dir_):
shutil.rmtree(self._model_dir_)
......@@ -417,4 +423,4 @@ class CNNCreator_Alexnet:
self.module = mx.mod.Module(symbol=mx.symbol.Group([predictions]),
data_names=self._input_names_,
label_names=self._output_names_,
context=context)
\ No newline at end of file
context=context)
......@@ -110,9 +110,15 @@ class CNNCreator_CifarClassifierNetwork:
optimizer='adam',
optimizer_params=(('learning_rate', 0.001),),
load_checkpoint=True,
context=mx.gpu(),
context='gpu',
checkpoint_period=5,
normalize=True):
if context == 'gpu':
mx_context = mx.gpu()
elif context == 'cpu':
mx_context = mx.cpu()
else:
logging.error("Context argument is '" + context + "'. Only 'cpu' and 'gpu are valid arguments'.")
if 'weight_decay' in optimizer_params:
optimizer_params['wd'] = optimizer_params['weight_decay']
......@@ -133,13 +139,13 @@ class CNNCreator_CifarClassifierNetwork:
train_iter, test_iter, data_mean, data_std = self.load_data(batch_size)
if self.module == None:
if normalize:
self.construct(context, data_mean, data_std)
self.construct(mx_context, data_mean, data_std)
else:
self.construct(context)
self.construct(mx_context)
begin_epoch = 0
if load_checkpoint:
begin_epoch = self.load(context)
begin_epoch = self.load(mx_context)
else:
if os.path.isdir(self._model_dir_):
shutil.rmtree(self._model_dir_)
......@@ -655,4 +661,4 @@ class CNNCreator_CifarClassifierNetwork:
self.module = mx.mod.Module(symbol=mx.symbol.Group([softmax]),
data_names=self._input_names_,
label_names=self._output_names_,
context=context)
\ No newline at end of file
context=context)
......@@ -110,9 +110,15 @@ class CNNCreator_VGG16:
optimizer='adam',
optimizer_params=(('learning_rate', 0.001),),
load_checkpoint=True,
context=mx.gpu(),
context='gpu',
checkpoint_period=5,
normalize=True):
if context == 'gpu':
mx_context = mx.gpu()
elif context == 'cpu':
mx_context = mx.cpu()
else:
logging.error("Context argument is '" + context + "'. Only 'cpu' and 'gpu are valid arguments'.")
if 'weight_decay' in optimizer_params:
optimizer_params['wd'] = optimizer_params['weight_decay']
......@@ -133,13 +139,13 @@ class CNNCreator_VGG16:
train_iter, test_iter, data_mean, data_std = self.load_data(batch_size)
if self.module == None:
if normalize:
self.construct(context, data_mean, data_std)
self.construct(mx_context, data_mean, data_std)
else:
self.construct(context)
self.construct(mx_context)
begin_epoch = 0
if load_checkpoint:
begin_epoch = self.load(context)
begin_epoch = self.load(mx_context)
else:
if os.path.isdir(self._model_dir_):
shutil.rmtree(self._model_dir_)
......@@ -453,4 +459,4 @@ class CNNCreator_VGG16:
self.module = mx.mod.Module(symbol=mx.symbol.Group([predictions]),
data_names=self._input_names_,
label_names=self._output_names_,
context=context)
\ No newline at end of file
context=context)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment