From 3166e5713f890097256d2302f73e337330819c2f Mon Sep 17 00:00:00 2001 From: Thomas Michael Timmermanns <thomas.timmermanns@rwth-aachen.de> Date: Thu, 31 May 2018 16:50:32 +0200 Subject: [PATCH] Added CNNTrain parameter 'context' --- src/main/resources/templates/CNNCreator.ftl | 14 ++++++++++---- .../resources/target_code/CNNCreator_Alexnet.py | 16 +++++++++++----- .../CNNCreator_CifarClassifierNetwork.py | 16 +++++++++++----- .../resources/target_code/CNNCreator_VGG16.py | 16 +++++++++++----- 4 files changed, 43 insertions(+), 19 deletions(-) diff --git a/src/main/resources/templates/CNNCreator.ftl b/src/main/resources/templates/CNNCreator.ftl index ec4ce1d8..834fbf8a 100644 --- a/src/main/resources/templates/CNNCreator.ftl +++ b/src/main/resources/templates/CNNCreator.ftl @@ -110,9 +110,15 @@ class ${tc.fileNameWithoutEnding}: optimizer='adam', optimizer_params=(('learning_rate', 0.001),), load_checkpoint=True, - context=mx.gpu(), + context='gpu', checkpoint_period=5, normalize=True): + if context == 'gpu': + mx_context = mx.gpu() + elif context == 'cpu': + mx_context = mx.cpu() + else: + logging.error("Context argument is '" + context + "'. Only 'cpu' and 'gpu are valid arguments'.") if 'weight_decay' in optimizer_params: optimizer_params['wd'] = optimizer_params['weight_decay'] @@ -133,13 +139,13 @@ class ${tc.fileNameWithoutEnding}: train_iter, test_iter, data_mean, data_std = self.load_data(batch_size) if self.module == None: if normalize: - self.construct(context, data_mean, data_std) + self.construct(mx_context, data_mean, data_std) else: - self.construct(context) + self.construct(mx_context) begin_epoch = 0 if load_checkpoint: - begin_epoch = self.load(context) + begin_epoch = self.load(mx_context) else: if os.path.isdir(self._model_dir_): shutil.rmtree(self._model_dir_) diff --git a/src/test/resources/target_code/CNNCreator_Alexnet.py b/src/test/resources/target_code/CNNCreator_Alexnet.py index 5fca8603..678b44c2 100644 --- a/src/test/resources/target_code/CNNCreator_Alexnet.py +++ b/src/test/resources/target_code/CNNCreator_Alexnet.py @@ -110,9 +110,15 @@ class CNNCreator_Alexnet: optimizer='adam', optimizer_params=(('learning_rate', 0.001),), load_checkpoint=True, - context=mx.gpu(), + context='gpu', checkpoint_period=5, normalize=True): + if context == 'gpu': + mx_context = mx.gpu() + elif context == 'cpu': + mx_context = mx.cpu() + else: + logging.error("Context argument is '" + context + "'. Only 'cpu' and 'gpu are valid arguments'.") if 'weight_decay' in optimizer_params: optimizer_params['wd'] = optimizer_params['weight_decay'] @@ -133,13 +139,13 @@ class CNNCreator_Alexnet: train_iter, test_iter, data_mean, data_std = self.load_data(batch_size) if self.module == None: if normalize: - self.construct(context, data_mean, data_std) + self.construct(mx_context, data_mean, data_std) else: - self.construct(context) + self.construct(mx_context) begin_epoch = 0 if load_checkpoint: - begin_epoch = self.load(context) + begin_epoch = self.load(mx_context) else: if os.path.isdir(self._model_dir_): shutil.rmtree(self._model_dir_) @@ -417,4 +423,4 @@ class CNNCreator_Alexnet: self.module = mx.mod.Module(symbol=mx.symbol.Group([predictions]), data_names=self._input_names_, label_names=self._output_names_, - context=context) \ No newline at end of file + context=context) diff --git a/src/test/resources/target_code/CNNCreator_CifarClassifierNetwork.py b/src/test/resources/target_code/CNNCreator_CifarClassifierNetwork.py index 974f294b..6dd81feb 100644 --- a/src/test/resources/target_code/CNNCreator_CifarClassifierNetwork.py +++ b/src/test/resources/target_code/CNNCreator_CifarClassifierNetwork.py @@ -110,9 +110,15 @@ class CNNCreator_CifarClassifierNetwork: optimizer='adam', optimizer_params=(('learning_rate', 0.001),), load_checkpoint=True, - context=mx.gpu(), + context='gpu', checkpoint_period=5, normalize=True): + if context == 'gpu': + mx_context = mx.gpu() + elif context == 'cpu': + mx_context = mx.cpu() + else: + logging.error("Context argument is '" + context + "'. Only 'cpu' and 'gpu are valid arguments'.") if 'weight_decay' in optimizer_params: optimizer_params['wd'] = optimizer_params['weight_decay'] @@ -133,13 +139,13 @@ class CNNCreator_CifarClassifierNetwork: train_iter, test_iter, data_mean, data_std = self.load_data(batch_size) if self.module == None: if normalize: - self.construct(context, data_mean, data_std) + self.construct(mx_context, data_mean, data_std) else: - self.construct(context) + self.construct(mx_context) begin_epoch = 0 if load_checkpoint: - begin_epoch = self.load(context) + begin_epoch = self.load(mx_context) else: if os.path.isdir(self._model_dir_): shutil.rmtree(self._model_dir_) @@ -655,4 +661,4 @@ class CNNCreator_CifarClassifierNetwork: self.module = mx.mod.Module(symbol=mx.symbol.Group([softmax]), data_names=self._input_names_, label_names=self._output_names_, - context=context) \ No newline at end of file + context=context) diff --git a/src/test/resources/target_code/CNNCreator_VGG16.py b/src/test/resources/target_code/CNNCreator_VGG16.py index 540307c6..ba39f3a2 100644 --- a/src/test/resources/target_code/CNNCreator_VGG16.py +++ b/src/test/resources/target_code/CNNCreator_VGG16.py @@ -110,9 +110,15 @@ class CNNCreator_VGG16: optimizer='adam', optimizer_params=(('learning_rate', 0.001),), load_checkpoint=True, - context=mx.gpu(), + context='gpu', checkpoint_period=5, normalize=True): + if context == 'gpu': + mx_context = mx.gpu() + elif context == 'cpu': + mx_context = mx.cpu() + else: + logging.error("Context argument is '" + context + "'. Only 'cpu' and 'gpu are valid arguments'.") if 'weight_decay' in optimizer_params: optimizer_params['wd'] = optimizer_params['weight_decay'] @@ -133,13 +139,13 @@ class CNNCreator_VGG16: train_iter, test_iter, data_mean, data_std = self.load_data(batch_size) if self.module == None: if normalize: - self.construct(context, data_mean, data_std) + self.construct(mx_context, data_mean, data_std) else: - self.construct(context) + self.construct(mx_context) begin_epoch = 0 if load_checkpoint: - begin_epoch = self.load(context) + begin_epoch = self.load(mx_context) else: if os.path.isdir(self._model_dir_): shutil.rmtree(self._model_dir_) @@ -453,4 +459,4 @@ class CNNCreator_VGG16: self.module = mx.mod.Module(symbol=mx.symbol.Group([predictions]), data_names=self._input_names_, label_names=self._output_names_, - context=context) \ No newline at end of file + context=context) -- GitLab