From 3166e5713f890097256d2302f73e337330819c2f Mon Sep 17 00:00:00 2001
From: Thomas Michael Timmermanns <thomas.timmermanns@rwth-aachen.de>
Date: Thu, 31 May 2018 16:50:32 +0200
Subject: [PATCH] Added CNNTrain parameter 'context'

---
 src/main/resources/templates/CNNCreator.ftl      | 14 ++++++++++----
 .../resources/target_code/CNNCreator_Alexnet.py  | 16 +++++++++++-----
 .../CNNCreator_CifarClassifierNetwork.py         | 16 +++++++++++-----
 .../resources/target_code/CNNCreator_VGG16.py    | 16 +++++++++++-----
 4 files changed, 43 insertions(+), 19 deletions(-)

diff --git a/src/main/resources/templates/CNNCreator.ftl b/src/main/resources/templates/CNNCreator.ftl
index ec4ce1d8..834fbf8a 100644
--- a/src/main/resources/templates/CNNCreator.ftl
+++ b/src/main/resources/templates/CNNCreator.ftl
@@ -110,9 +110,15 @@ class ${tc.fileNameWithoutEnding}:
               optimizer='adam',
               optimizer_params=(('learning_rate', 0.001),),
               load_checkpoint=True,
-              context=mx.gpu(),
+              context='gpu',
               checkpoint_period=5,
               normalize=True):
+        if context == 'gpu':
+            mx_context = mx.gpu()
+        elif context == 'cpu':
+            mx_context = mx.cpu()
+        else:
+            logging.error("Context argument is '" + context + "'. Only 'cpu' and 'gpu' are valid arguments.")
 
         if 'weight_decay' in optimizer_params:
             optimizer_params['wd'] = optimizer_params['weight_decay']
@@ -133,13 +139,13 @@ class ${tc.fileNameWithoutEnding}:
         train_iter, test_iter, data_mean, data_std = self.load_data(batch_size)
         if self.module == None:
             if normalize:
-                self.construct(context, data_mean, data_std)
+                self.construct(mx_context, data_mean, data_std)
             else:
-                self.construct(context)
+                self.construct(mx_context)
 
         begin_epoch = 0
         if load_checkpoint:
-            begin_epoch = self.load(context)
+            begin_epoch = self.load(mx_context)
         else:
             if os.path.isdir(self._model_dir_):
                 shutil.rmtree(self._model_dir_)
diff --git a/src/test/resources/target_code/CNNCreator_Alexnet.py b/src/test/resources/target_code/CNNCreator_Alexnet.py
index 5fca8603..678b44c2 100644
--- a/src/test/resources/target_code/CNNCreator_Alexnet.py
+++ b/src/test/resources/target_code/CNNCreator_Alexnet.py
@@ -110,9 +110,15 @@ class CNNCreator_Alexnet:
               optimizer='adam',
               optimizer_params=(('learning_rate', 0.001),),
               load_checkpoint=True,
-              context=mx.gpu(),
+              context='gpu',
               checkpoint_period=5,
               normalize=True):
+        if context == 'gpu':
+            mx_context = mx.gpu()
+        elif context == 'cpu':
+            mx_context = mx.cpu()
+        else:
+            logging.error("Context argument is '" + context + "'. Only 'cpu' and 'gpu' are valid arguments.")
 
         if 'weight_decay' in optimizer_params:
             optimizer_params['wd'] = optimizer_params['weight_decay']
@@ -133,13 +139,13 @@ class CNNCreator_Alexnet:
         train_iter, test_iter, data_mean, data_std = self.load_data(batch_size)
         if self.module == None:
             if normalize:
-                self.construct(context, data_mean, data_std)
+                self.construct(mx_context, data_mean, data_std)
             else:
-                self.construct(context)
+                self.construct(mx_context)
 
         begin_epoch = 0
         if load_checkpoint:
-            begin_epoch = self.load(context)
+            begin_epoch = self.load(mx_context)
         else:
             if os.path.isdir(self._model_dir_):
                 shutil.rmtree(self._model_dir_)
@@ -417,4 +423,4 @@ class CNNCreator_Alexnet:
         self.module = mx.mod.Module(symbol=mx.symbol.Group([predictions]),
                                          data_names=self._input_names_,
                                          label_names=self._output_names_,
-                                         context=context)
\ No newline at end of file
+                                         context=context)
diff --git a/src/test/resources/target_code/CNNCreator_CifarClassifierNetwork.py b/src/test/resources/target_code/CNNCreator_CifarClassifierNetwork.py
index 974f294b..6dd81feb 100644
--- a/src/test/resources/target_code/CNNCreator_CifarClassifierNetwork.py
+++ b/src/test/resources/target_code/CNNCreator_CifarClassifierNetwork.py
@@ -110,9 +110,15 @@ class CNNCreator_CifarClassifierNetwork:
               optimizer='adam',
               optimizer_params=(('learning_rate', 0.001),),
               load_checkpoint=True,
-              context=mx.gpu(),
+              context='gpu',
               checkpoint_period=5,
               normalize=True):
+        if context == 'gpu':
+            mx_context = mx.gpu()
+        elif context == 'cpu':
+            mx_context = mx.cpu()
+        else:
+            logging.error("Context argument is '" + context + "'. Only 'cpu' and 'gpu' are valid arguments.")
 
         if 'weight_decay' in optimizer_params:
             optimizer_params['wd'] = optimizer_params['weight_decay']
@@ -133,13 +139,13 @@ class CNNCreator_CifarClassifierNetwork:
         train_iter, test_iter, data_mean, data_std = self.load_data(batch_size)
         if self.module == None:
             if normalize:
-                self.construct(context, data_mean, data_std)
+                self.construct(mx_context, data_mean, data_std)
             else:
-                self.construct(context)
+                self.construct(mx_context)
 
         begin_epoch = 0
         if load_checkpoint:
-            begin_epoch = self.load(context)
+            begin_epoch = self.load(mx_context)
         else:
             if os.path.isdir(self._model_dir_):
                 shutil.rmtree(self._model_dir_)
@@ -655,4 +661,4 @@ class CNNCreator_CifarClassifierNetwork:
         self.module = mx.mod.Module(symbol=mx.symbol.Group([softmax]),
                                          data_names=self._input_names_,
                                          label_names=self._output_names_,
-                                         context=context)
\ No newline at end of file
+                                         context=context)
diff --git a/src/test/resources/target_code/CNNCreator_VGG16.py b/src/test/resources/target_code/CNNCreator_VGG16.py
index 540307c6..ba39f3a2 100644
--- a/src/test/resources/target_code/CNNCreator_VGG16.py
+++ b/src/test/resources/target_code/CNNCreator_VGG16.py
@@ -110,9 +110,15 @@ class CNNCreator_VGG16:
               optimizer='adam',
               optimizer_params=(('learning_rate', 0.001),),
               load_checkpoint=True,
-              context=mx.gpu(),
+              context='gpu',
               checkpoint_period=5,
               normalize=True):
+        if context == 'gpu':
+            mx_context = mx.gpu()
+        elif context == 'cpu':
+            mx_context = mx.cpu()
+        else:
+            logging.error("Context argument is '" + context + "'. Only 'cpu' and 'gpu' are valid arguments.")
 
         if 'weight_decay' in optimizer_params:
             optimizer_params['wd'] = optimizer_params['weight_decay']
@@ -133,13 +139,13 @@ class CNNCreator_VGG16:
         train_iter, test_iter, data_mean, data_std = self.load_data(batch_size)
         if self.module == None:
             if normalize:
-                self.construct(context, data_mean, data_std)
+                self.construct(mx_context, data_mean, data_std)
             else:
-                self.construct(context)
+                self.construct(mx_context)
 
         begin_epoch = 0
         if load_checkpoint:
-            begin_epoch = self.load(context)
+            begin_epoch = self.load(mx_context)
         else:
             if os.path.isdir(self._model_dir_):
                 shutil.rmtree(self._model_dir_)
@@ -453,4 +459,4 @@ class CNNCreator_VGG16:
         self.module = mx.mod.Module(symbol=mx.symbol.Group([predictions]),
                                          data_names=self._input_names_,
                                          label_names=self._output_names_,
-                                         context=context)
\ No newline at end of file
+                                         context=context)
-- 
GitLab