Commit af4a9cb2 authored by Evgeny Kusmenko's avatar Evgeny Kusmenko
Browse files

Merge branch 'fix-parameter-loading' into 'master'

Fix parameter loading

See merge request !12
parents 55580062 61b436b3
Pipeline #116877 passed with stages
in 6 minutes and 25 seconds
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
<groupId>de.monticore.lang.monticar</groupId> <groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-gluon-generator</artifactId> <artifactId>cnnarch-gluon-generator</artifactId>
<version>0.1.4</version> <version>0.1.5</version>
<!-- == PROJECT DEPENDENCIES ============================================= --> <!-- == PROJECT DEPENDENCIES ============================================= -->
......
...@@ -56,7 +56,7 @@ class ${tc.fileNameWithoutEnding}: ...@@ -56,7 +56,7 @@ class ${tc.fileNameWithoutEnding}:
return 0 return 0
else: else:
logging.info("Loading checkpoint: " + param_file) logging.info("Loading checkpoint: " + param_file)
self.net.load_parameters(param_file) self.net.load_parameters(self._model_dir_ + param_file)
return lastEpoch return lastEpoch
...@@ -223,10 +223,11 @@ class ${tc.fileNameWithoutEnding}: ...@@ -223,10 +223,11 @@ class ${tc.fileNameWithoutEnding}:
logging.info("Epoch[%d] Train: %f, Test: %f" % (epoch, train_metric_score, test_metric_score)) logging.info("Epoch[%d] Train: %f, Test: %f" % (epoch, train_metric_score, test_metric_score))
if (epoch - begin_epoch) % checkpoint_period == 0: if (epoch - begin_epoch) % checkpoint_period == 0:
self.net.export(self._model_dir_ + self._model_prefix_, epoch) self.net.save_parameters(self._model_dir_ + self._model_prefix_ + '-' + str(epoch).zfill(4) + '.params')
self.net.export(self._model_dir_ + self._model_prefix_, num_epoch + begin_epoch) self.net.save_parameters(self._model_dir_ + self._model_prefix_ + '-'
self.net.export(self._model_dir_ + self._model_prefix_ + '_newest', 0) + str(num_epoch + begin_epoch).zfill(4) + '.params')
self.net.export(self._model_dir_ + self._model_prefix_ + '_newest', epoch=0)
def construct(self, context, data_mean=None, data_std=None): def construct(self, context, data_mean=None, data_std=None):
...@@ -234,3 +235,8 @@ class ${tc.fileNameWithoutEnding}: ...@@ -234,3 +235,8 @@ class ${tc.fileNameWithoutEnding}:
self.net.collect_params().initialize(self.weight_initializer, ctx=context) self.net.collect_params().initialize(self.weight_initializer, ctx=context)
self.net.hybridize() self.net.hybridize()
self.net(mx.nd.zeros((1,)+self._input_shapes_[0], ctx=context)) self.net(mx.nd.zeros((1,)+self._input_shapes_[0], ctx=context))
if not os.path.exists(self._model_dir_):
os.makedirs(self._model_dir_)
self.net.export(self._model_dir_ + self._model_prefix_, epoch=0)
...@@ -56,7 +56,7 @@ class CNNCreator_Alexnet: ...@@ -56,7 +56,7 @@ class CNNCreator_Alexnet:
return 0 return 0
else: else:
logging.info("Loading checkpoint: " + param_file) logging.info("Loading checkpoint: " + param_file)
self.net.load_parameters(param_file) self.net.load_parameters(self._model_dir_ + param_file)
return lastEpoch return lastEpoch
...@@ -223,10 +223,11 @@ class CNNCreator_Alexnet: ...@@ -223,10 +223,11 @@ class CNNCreator_Alexnet:
logging.info("Epoch[%d] Train: %f, Test: %f" % (epoch, train_metric_score, test_metric_score)) logging.info("Epoch[%d] Train: %f, Test: %f" % (epoch, train_metric_score, test_metric_score))
if (epoch - begin_epoch) % checkpoint_period == 0: if (epoch - begin_epoch) % checkpoint_period == 0:
self.net.export(self._model_dir_ + self._model_prefix_, epoch) self.net.save_parameters(self._model_dir_ + self._model_prefix_ + '-' + str(epoch).zfill(4) + '.params')
self.net.export(self._model_dir_ + self._model_prefix_, num_epoch + begin_epoch) self.net.save_parameters(self._model_dir_ + self._model_prefix_ + '-'
self.net.export(self._model_dir_ + self._model_prefix_ + '_newest', 0) + str(num_epoch + begin_epoch).zfill(4) + '.params')
self.net.export(self._model_dir_ + self._model_prefix_ + '_newest', epoch=0)
def construct(self, context, data_mean=None, data_std=None): def construct(self, context, data_mean=None, data_std=None):
...@@ -234,3 +235,8 @@ class CNNCreator_Alexnet: ...@@ -234,3 +235,8 @@ class CNNCreator_Alexnet:
self.net.collect_params().initialize(self.weight_initializer, ctx=context) self.net.collect_params().initialize(self.weight_initializer, ctx=context)
self.net.hybridize() self.net.hybridize()
self.net(mx.nd.zeros((1,)+self._input_shapes_[0], ctx=context)) self.net(mx.nd.zeros((1,)+self._input_shapes_[0], ctx=context))
if not os.path.exists(self._model_dir_):
os.makedirs(self._model_dir_)
self.net.export(self._model_dir_ + self._model_prefix_, epoch=0)
...@@ -56,7 +56,7 @@ class CNNCreator_CifarClassifierNetwork: ...@@ -56,7 +56,7 @@ class CNNCreator_CifarClassifierNetwork:
return 0 return 0
else: else:
logging.info("Loading checkpoint: " + param_file) logging.info("Loading checkpoint: " + param_file)
self.net.load_parameters(param_file) self.net.load_parameters(self._model_dir_ + param_file)
return lastEpoch return lastEpoch
...@@ -223,10 +223,11 @@ class CNNCreator_CifarClassifierNetwork: ...@@ -223,10 +223,11 @@ class CNNCreator_CifarClassifierNetwork:
logging.info("Epoch[%d] Train: %f, Test: %f" % (epoch, train_metric_score, test_metric_score)) logging.info("Epoch[%d] Train: %f, Test: %f" % (epoch, train_metric_score, test_metric_score))
if (epoch - begin_epoch) % checkpoint_period == 0: if (epoch - begin_epoch) % checkpoint_period == 0:
self.net.export(self._model_dir_ + self._model_prefix_, epoch) self.net.save_parameters(self._model_dir_ + self._model_prefix_ + '-' + str(epoch).zfill(4) + '.params')
self.net.export(self._model_dir_ + self._model_prefix_, num_epoch + begin_epoch) self.net.save_parameters(self._model_dir_ + self._model_prefix_ + '-'
self.net.export(self._model_dir_ + self._model_prefix_ + '_newest', 0) + str(num_epoch + begin_epoch).zfill(4) + '.params')
self.net.export(self._model_dir_ + self._model_prefix_ + '_newest', epoch=0)
def construct(self, context, data_mean=None, data_std=None): def construct(self, context, data_mean=None, data_std=None):
...@@ -234,3 +235,8 @@ class CNNCreator_CifarClassifierNetwork: ...@@ -234,3 +235,8 @@ class CNNCreator_CifarClassifierNetwork:
self.net.collect_params().initialize(self.weight_initializer, ctx=context) self.net.collect_params().initialize(self.weight_initializer, ctx=context)
self.net.hybridize() self.net.hybridize()
self.net(mx.nd.zeros((1,)+self._input_shapes_[0], ctx=context)) self.net(mx.nd.zeros((1,)+self._input_shapes_[0], ctx=context))
if not os.path.exists(self._model_dir_):
os.makedirs(self._model_dir_)
self.net.export(self._model_dir_ + self._model_prefix_, epoch=0)
...@@ -56,7 +56,7 @@ class CNNCreator_VGG16: ...@@ -56,7 +56,7 @@ class CNNCreator_VGG16:
return 0 return 0
else: else:
logging.info("Loading checkpoint: " + param_file) logging.info("Loading checkpoint: " + param_file)
self.net.load_parameters(param_file) self.net.load_parameters(self._model_dir_ + param_file)
return lastEpoch return lastEpoch
...@@ -223,10 +223,11 @@ class CNNCreator_VGG16: ...@@ -223,10 +223,11 @@ class CNNCreator_VGG16:
logging.info("Epoch[%d] Train: %f, Test: %f" % (epoch, train_metric_score, test_metric_score)) logging.info("Epoch[%d] Train: %f, Test: %f" % (epoch, train_metric_score, test_metric_score))
if (epoch - begin_epoch) % checkpoint_period == 0: if (epoch - begin_epoch) % checkpoint_period == 0:
self.net.export(self._model_dir_ + self._model_prefix_, epoch) self.net.save_parameters(self._model_dir_ + self._model_prefix_ + '-' + str(epoch).zfill(4) + '.params')
self.net.export(self._model_dir_ + self._model_prefix_, num_epoch + begin_epoch) self.net.save_parameters(self._model_dir_ + self._model_prefix_ + '-'
self.net.export(self._model_dir_ + self._model_prefix_ + '_newest', 0) + str(num_epoch + begin_epoch).zfill(4) + '.params')
self.net.export(self._model_dir_ + self._model_prefix_ + '_newest', epoch=0)
def construct(self, context, data_mean=None, data_std=None): def construct(self, context, data_mean=None, data_std=None):
...@@ -234,3 +235,8 @@ class CNNCreator_VGG16: ...@@ -234,3 +235,8 @@ class CNNCreator_VGG16:
self.net.collect_params().initialize(self.weight_initializer, ctx=context) self.net.collect_params().initialize(self.weight_initializer, ctx=context)
self.net.hybridize() self.net.hybridize()
self.net(mx.nd.zeros((1,)+self._input_shapes_[0], ctx=context)) self.net(mx.nd.zeros((1,)+self._input_shapes_[0], ctx=context))
if not os.path.exists(self._model_dir_):
os.makedirs(self._model_dir_)
self.net.export(self._model_dir_ + self._model_prefix_, epoch=0)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment