Commit 61b436b3 authored by Nicola Gatto's avatar Nicola Gatto
Browse files

Fix parameter loading

parent 55580062
Pipeline #116751 passed with stages
in 5 minutes and 42 seconds
......@@ -8,7 +8,7 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-gluon-generator</artifactId>
<version>0.1.4</version>
<version>0.1.5</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
......
......@@ -56,7 +56,7 @@ class ${tc.fileNameWithoutEnding}:
return 0
else:
logging.info("Loading checkpoint: " + param_file)
self.net.load_parameters(param_file)
self.net.load_parameters(self._model_dir_ + param_file)
return lastEpoch
......@@ -223,10 +223,11 @@ class ${tc.fileNameWithoutEnding}:
logging.info("Epoch[%d] Train: %f, Test: %f" % (epoch, train_metric_score, test_metric_score))
if (epoch - begin_epoch) % checkpoint_period == 0:
self.net.export(self._model_dir_ + self._model_prefix_, epoch)
self.net.save_parameters(self._model_dir_ + self._model_prefix_ + '-' + str(epoch).zfill(4) + '.params')
self.net.export(self._model_dir_ + self._model_prefix_, num_epoch + begin_epoch)
self.net.export(self._model_dir_ + self._model_prefix_ + '_newest', 0)
self.net.save_parameters(self._model_dir_ + self._model_prefix_ + '-'
+ str(num_epoch + begin_epoch).zfill(4) + '.params')
self.net.export(self._model_dir_ + self._model_prefix_ + '_newest', epoch=0)
def construct(self, context, data_mean=None, data_std=None):
......@@ -234,3 +235,8 @@ class ${tc.fileNameWithoutEnding}:
self.net.collect_params().initialize(self.weight_initializer, ctx=context)
self.net.hybridize()
self.net(mx.nd.zeros((1,)+self._input_shapes_[0], ctx=context))
if not os.path.exists(self._model_dir_):
os.makedirs(self._model_dir_)
self.net.export(self._model_dir_ + self._model_prefix_, epoch=0)
......@@ -56,7 +56,7 @@ class CNNCreator_Alexnet:
return 0
else:
logging.info("Loading checkpoint: " + param_file)
self.net.load_parameters(param_file)
self.net.load_parameters(self._model_dir_ + param_file)
return lastEpoch
......@@ -223,10 +223,11 @@ class CNNCreator_Alexnet:
logging.info("Epoch[%d] Train: %f, Test: %f" % (epoch, train_metric_score, test_metric_score))
if (epoch - begin_epoch) % checkpoint_period == 0:
self.net.export(self._model_dir_ + self._model_prefix_, epoch)
self.net.save_parameters(self._model_dir_ + self._model_prefix_ + '-' + str(epoch).zfill(4) + '.params')
self.net.export(self._model_dir_ + self._model_prefix_, num_epoch + begin_epoch)
self.net.export(self._model_dir_ + self._model_prefix_ + '_newest', 0)
self.net.save_parameters(self._model_dir_ + self._model_prefix_ + '-'
+ str(num_epoch + begin_epoch).zfill(4) + '.params')
self.net.export(self._model_dir_ + self._model_prefix_ + '_newest', epoch=0)
def construct(self, context, data_mean=None, data_std=None):
......@@ -234,3 +235,8 @@ class CNNCreator_Alexnet:
self.net.collect_params().initialize(self.weight_initializer, ctx=context)
self.net.hybridize()
self.net(mx.nd.zeros((1,)+self._input_shapes_[0], ctx=context))
if not os.path.exists(self._model_dir_):
os.makedirs(self._model_dir_)
self.net.export(self._model_dir_ + self._model_prefix_, epoch=0)
......@@ -56,7 +56,7 @@ class CNNCreator_CifarClassifierNetwork:
return 0
else:
logging.info("Loading checkpoint: " + param_file)
self.net.load_parameters(param_file)
self.net.load_parameters(self._model_dir_ + param_file)
return lastEpoch
......@@ -223,10 +223,11 @@ class CNNCreator_CifarClassifierNetwork:
logging.info("Epoch[%d] Train: %f, Test: %f" % (epoch, train_metric_score, test_metric_score))
if (epoch - begin_epoch) % checkpoint_period == 0:
self.net.export(self._model_dir_ + self._model_prefix_, epoch)
self.net.save_parameters(self._model_dir_ + self._model_prefix_ + '-' + str(epoch).zfill(4) + '.params')
self.net.export(self._model_dir_ + self._model_prefix_, num_epoch + begin_epoch)
self.net.export(self._model_dir_ + self._model_prefix_ + '_newest', 0)
self.net.save_parameters(self._model_dir_ + self._model_prefix_ + '-'
+ str(num_epoch + begin_epoch).zfill(4) + '.params')
self.net.export(self._model_dir_ + self._model_prefix_ + '_newest', epoch=0)
def construct(self, context, data_mean=None, data_std=None):
......@@ -234,3 +235,8 @@ class CNNCreator_CifarClassifierNetwork:
self.net.collect_params().initialize(self.weight_initializer, ctx=context)
self.net.hybridize()
self.net(mx.nd.zeros((1,)+self._input_shapes_[0], ctx=context))
if not os.path.exists(self._model_dir_):
os.makedirs(self._model_dir_)
self.net.export(self._model_dir_ + self._model_prefix_, epoch=0)
......@@ -56,7 +56,7 @@ class CNNCreator_VGG16:
return 0
else:
logging.info("Loading checkpoint: " + param_file)
self.net.load_parameters(param_file)
self.net.load_parameters(self._model_dir_ + param_file)
return lastEpoch
......@@ -223,10 +223,11 @@ class CNNCreator_VGG16:
logging.info("Epoch[%d] Train: %f, Test: %f" % (epoch, train_metric_score, test_metric_score))
if (epoch - begin_epoch) % checkpoint_period == 0:
self.net.export(self._model_dir_ + self._model_prefix_, epoch)
self.net.save_parameters(self._model_dir_ + self._model_prefix_ + '-' + str(epoch).zfill(4) + '.params')
self.net.export(self._model_dir_ + self._model_prefix_, num_epoch + begin_epoch)
self.net.export(self._model_dir_ + self._model_prefix_ + '_newest', 0)
self.net.save_parameters(self._model_dir_ + self._model_prefix_ + '-'
+ str(num_epoch + begin_epoch).zfill(4) + '.params')
self.net.export(self._model_dir_ + self._model_prefix_ + '_newest', epoch=0)
def construct(self, context, data_mean=None, data_std=None):
......@@ -234,3 +235,8 @@ class CNNCreator_VGG16:
self.net.collect_params().initialize(self.weight_initializer, ctx=context)
self.net.hybridize()
self.net(mx.nd.zeros((1,)+self._input_shapes_[0], ctx=context))
if not os.path.exists(self._model_dir_):
os.makedirs(self._model_dir_)
self.net.export(self._model_dir_ + self._model_prefix_, epoch=0)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment