Commit feee580b authored by Julian Treiber's avatar Julian Treiber

updated tests after rebase

parent 4cefb787
Pipeline #267642 failed with stage
in 59 seconds
......@@ -65,10 +65,7 @@ if __name__ == "__main__":
preprocessing=${config.preprocessor?string("True","False")},
</#if>
<#if (config.evalMetric)??>
eval_metric='${config.evalMetric.name}',
eval_metric_params={
<#if (config.evalMetric.exclude)??>
'exclude': [<#list config.evalMetric.exclude as value>${value}<#sep>, </#list>],
eval_metric='${config.evalMetric}',
</#if>
<#if (config.configuration.optimizer)??>
optimizer='${config.optimizerName}',
......@@ -150,5 +147,3 @@ if __name__ == "__main__":
<#if (config.printImages)??>
print_images=${config.printImages?string("True","False")},
</#if>)
import mxnet as mx
import logging
import os
import shutil
from CNNNet_Discriminator import Net_0
......@@ -11,6 +12,7 @@ class CNNCreator_Discriminator:
def __init__(self):
self.weight_initializer = mx.init.Normal()
self.networks = {}
self._weights_dir_ = None
def load(self, context):
earliestLastEpoch = None
......@@ -47,6 +49,29 @@ class CNNCreator_Discriminator:
return earliestLastEpoch
def load_pretrained_weights(self, context):
if os.path.isdir(self._model_dir_):
shutil.rmtree(self._model_dir_)
if self._weights_dir_ is not None:
for i, network in self.networks.items():
# param_file = self._model_prefix_ + "_" + str(i) + "_newest-0000.params"
param_file = None
if os.path.isdir(self._weights_dir_):
lastEpoch = 0
for file in os.listdir(self._weights_dir_):
if ".params" in file and self._model_prefix_ + "_" + str(i) in file:
epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","")
epoch = int(epochStr)
if epoch > lastEpoch:
lastEpoch = epoch
param_file = file
logging.info("Loading pretrained weights: " + self._weights_dir_ + param_file)
network.load_parameters(self._weights_dir_ + param_file, allow_missing=True, ignore_extra=True)
else:
logging.info("No pretrained weights available at: " + self._weights_dir_ + param_file)
def construct(self, context, data_mean=None, data_std=None):
self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std)
self.networks[0].collect_params().initialize(self.weight_initializer, ctx=context)
......
import mxnet as mx
import logging
import os
import shutil
from CNNNet_InfoDiscriminator import Net_0
......@@ -11,6 +12,7 @@ class CNNCreator_InfoDiscriminator:
def __init__(self):
self.weight_initializer = mx.init.Normal()
self.networks = {}
self._weights_dir_ = None
def load(self, context):
earliestLastEpoch = None
......@@ -47,6 +49,29 @@ class CNNCreator_InfoDiscriminator:
return earliestLastEpoch
def load_pretrained_weights(self, context):
if os.path.isdir(self._model_dir_):
shutil.rmtree(self._model_dir_)
if self._weights_dir_ is not None:
for i, network in self.networks.items():
# param_file = self._model_prefix_ + "_" + str(i) + "_newest-0000.params"
param_file = None
if os.path.isdir(self._weights_dir_):
lastEpoch = 0
for file in os.listdir(self._weights_dir_):
if ".params" in file and self._model_prefix_ + "_" + str(i) in file:
epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","")
epoch = int(epochStr)
if epoch > lastEpoch:
lastEpoch = epoch
param_file = file
logging.info("Loading pretrained weights: " + self._weights_dir_ + param_file)
network.load_parameters(self._weights_dir_ + param_file, allow_missing=True, ignore_extra=True)
else:
logging.info("No pretrained weights available at: " + self._weights_dir_ + param_file)
def construct(self, context, data_mean=None, data_std=None):
self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std)
self.networks[0].collect_params().initialize(self.weight_initializer, ctx=context)
......
import mxnet as mx
import logging
import os
import shutil
from CNNNet_InfoQNetwork import Net_0
......@@ -11,6 +12,7 @@ class CNNCreator_InfoQNetwork:
def __init__(self):
self.weight_initializer = mx.init.Normal()
self.networks = {}
self._weights_dir_ = None
def load(self, context):
earliestLastEpoch = None
......@@ -47,6 +49,29 @@ class CNNCreator_InfoQNetwork:
return earliestLastEpoch
def load_pretrained_weights(self, context):
if os.path.isdir(self._model_dir_):
shutil.rmtree(self._model_dir_)
if self._weights_dir_ is not None:
for i, network in self.networks.items():
# param_file = self._model_prefix_ + "_" + str(i) + "_newest-0000.params"
param_file = None
if os.path.isdir(self._weights_dir_):
lastEpoch = 0
for file in os.listdir(self._weights_dir_):
if ".params" in file and self._model_prefix_ + "_" + str(i) in file:
epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","")
epoch = int(epochStr)
if epoch > lastEpoch:
lastEpoch = epoch
param_file = file
logging.info("Loading pretrained weights: " + self._weights_dir_ + param_file)
network.load_parameters(self._weights_dir_ + param_file, allow_missing=True, ignore_extra=True)
else:
logging.info("No pretrained weights available at: " + self._weights_dir_ + param_file)
def construct(self, context, data_mean=None, data_std=None):
self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std)
self.networks[0].collect_params().initialize(self.weight_initializer, ctx=context)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment