Commit 1812de35 authored by Sebastian Nickels

Implemented multiple inputs

parent bbb95cbf
@@ -32,8 +32,6 @@ import java.util.Map;
 public class CNNArch2Gluon extends CNNArch2MxNet {
 
     public CNNArch2Gluon() {
-        super();
-
         architectureSupportChecker = new CNNArch2GluonArchitectureSupportChecker();
         layerSupportChecker = new CNNArch2GluonLayerSupportChecker();
     }
@@ -62,6 +60,9 @@ public class CNNArch2Gluon extends CNNArch2MxNet {
         temp = archTc.process("CNNCreator", Target.PYTHON);
         fileContentMap.put(temp.getKey(), temp.getValue());
 
+        temp = archTc.process("CNNSupervisedTrainer", Target.PYTHON);
+        fileContentMap.put(temp.getKey(), temp.getValue());
+
         temp = archTc.process("execute", Target.CPP);
         fileContentMap.put(temp.getKey().replace(".h", ""), temp.getValue());
...
@@ -11,9 +11,9 @@ public class CNNArch2GluonArchitectureSupportChecker extends ArchitectureSupport
         return true;
     }*/
 
-    /*protected boolean checkMultipleInputs(ArchitectureSymbol architecture) {
+    protected boolean checkMultipleInputs(ArchitectureSymbol architecture) {
         return true;
-    }*/
+    }
 
     /*protected boolean checkMultipleOutputs(ArchitectureSymbol architecture) {
         return true;
...
@@ -93,9 +93,6 @@ public class CNNTrain2Gluon extends CNNTrain2MxNet {
         if (configData.isSupervisedLearning()) {
             String cnnTrainTemplateContent = templateConfiguration.processTemplate(ftlContext, "CNNTrainer.ftl");
             fileContentMap.put("CNNTrainer_" + getInstanceName() + ".py", cnnTrainTemplateContent);
-
-            String cnnSupervisedTrainerContent = templateConfiguration.processTemplate(ftlContext, "CNNSupervisedTrainer.ftl");
-            fileContentMap.put("supervised_trainer.py", cnnSupervisedTrainerContent);
         } else if (configData.isReinforcementLearning()) {
             final String trainerName = "CNNTrainer_" + getInstanceName();
             ftlContext.put("trainerName", trainerName);
...
@@ -43,12 +43,11 @@ class ${tc.fileNameWithoutEnding}:
             self.net.load_parameters(self._model_dir_ + param_file)
         return lastEpoch
 
     def construct(self, context, data_mean=None, data_std=None):
         self.net = Net(data_mean=data_mean, data_std=data_std)
         self.net.collect_params().initialize(self.weight_initializer, ctx=context)
         self.net.hybridize()
-        self.net(mx.nd.zeros((1,)+self._input_shapes_[0], ctx=context))
+        self.net(<#list tc.architecture.inputs as input>mx.nd.zeros((1,) + self._input_shapes_[${input?index}], ctx=context)<#sep>, </#list>)
         if not os.path.exists(self._model_dir_):
             os.makedirs(self._model_dir_)
...
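The rendered construct() now has to feed one dummy batch per network input so that shape inference covers every input. A minimal, runnable sketch of that idea, with the input shapes assumed purely for illustration:

import mxnet as mx

# Build one zero-filled dummy batch per network input, as the rendered
# construct() does. The shapes below are assumptions for illustration only.
input_shapes = [(3, 224, 224), (10,)]
ctx = mx.cpu()
dummy_inputs = [mx.nd.zeros((1,) + shape, ctx=ctx) for shape in input_shapes]
print([x.shape for x in dummy_inputs])  # [(1, 3, 224, 224), (1, 10)]
# A multi-input net would then be warmed up with: net(*dummy_inputs)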
@@ -3,10 +3,11 @@ import h5py
 import mxnet as mx
 import logging
 import sys
+from mxnet import nd
 
-class ${tc.fullArchitectureName}DataLoader:
+class ${tc.fileNameWithoutEnding}:
     _input_names_ = [${tc.join(tc.architectureInputs, ",", "'", "'")}]
-    _output_names_ = [${tc.join(tc.architectureOutputs, ",", "'", "_label'")}]
+    _output_names_ = [${tc.join(tc.architectureOutputs, ",", "'", "'")}]
 
     def __init__(self):
         self._data_dir = "${tc.dataPath}/"
@@ -14,21 +15,38 @@ class ${tc.fullArchitectureName}DataLoader:
     def load_data(self, batch_size):
         train_h5, test_h5 = self.load_h5_files()
 
-        data_mean = train_h5[self._input_names_[0]][:].mean(axis=0)
-        data_std = train_h5[self._input_names_[0]][:].std(axis=0) + 1e-5
-
-        train_iter = mx.io.NDArrayIter(train_h5[self._input_names_[0]],
-                                       train_h5[self._output_names_[0]],
-                                       batch_size=batch_size,
-                                       data_name=self._input_names_[0],
-                                       label_name=self._output_names_[0])
+        train_data = {}
+        data_mean = {}
+        data_std = {}
+
+        for input_name in self._input_names_:
+            train_data[input_name] = train_h5[input_name]
+            data_mean[input_name] = nd.array(train_h5[input_name][:].mean(axis=0))
+            data_std[input_name] = nd.array(train_h5[input_name][:].std(axis=0) + 1e-5)
+
+        train_label = {}
+        for output_name in self._output_names_:
+            train_label[output_name] = train_h5[output_name]
+
+        train_iter = mx.io.NDArrayIter(data=train_data,
+                                       label=train_label,
+                                       batch_size=batch_size)
+
         test_iter = None
         if test_h5 != None:
-            test_iter = mx.io.NDArrayIter(test_h5[self._input_names_[0]],
-                                          test_h5[self._output_names_[0]],
-                                          batch_size=batch_size,
-                                          data_name=self._input_names_[0],
-                                          label_name=self._output_names_[0])
+            test_data = {}
+            for input_name in self._input_names_:
+                test_data[input_name] = test_h5[input_name]
+
+            test_label = {}
+            for output_name in self._output_names_:
+                test_label[output_name] = test_h5[output_name]
+
+            test_iter = mx.io.NDArrayIter(data=test_data,
+                                          label=test_label,
+                                          batch_size=batch_size)
+
         return train_iter, test_iter, data_mean, data_std
 
     def load_h5_files(self):
@@ -36,21 +54,39 @@ class ${tc.fullArchitectureName}DataLoader:
         test_h5 = None
         train_path = self._data_dir + "train.h5"
         test_path = self._data_dir + "test.h5"
         if os.path.isfile(train_path):
             train_h5 = h5py.File(train_path, 'r')
-            if not (self._input_names_[0] in train_h5 and self._output_names_[0] in train_h5):
-                logging.error("The HDF5 file '" + os.path.abspath(train_path) + "' has to contain the datasets: "
-                              + "'" + self._input_names_[0] + "', '" + self._output_names_[0] + "'")
-                sys.exit(1)
-            test_iter = None
+
+            for input_name in self._input_names_:
+                if not input_name in train_h5:
+                    logging.error("The HDF5 file '" + os.path.abspath(train_path) + "' has to contain the dataset "
+                                  + "'" + input_name + "'")
+                    sys.exit(1)
+
+            for output_name in self._output_names_:
+                if not output_name in train_h5:
+                    logging.error("The HDF5 file '" + os.path.abspath(train_path) + "' has to contain the dataset "
+                                  + "'" + output_name + "'")
+                    sys.exit(1)
+
             if os.path.isfile(test_path):
                 test_h5 = h5py.File(test_path, 'r')
-                if not (self._input_names_[0] in test_h5 and self._output_names_[0] in test_h5):
-                    logging.error("The HDF5 file '" + os.path.abspath(test_path) + "' has to contain the datasets: "
-                                  + "'" + self._input_names_[0] + "', '" + self._output_names_[0] + "'")
-                    sys.exit(1)
+
+                for input_name in self._input_names_:
+                    if not input_name in test_h5:
+                        logging.error("The HDF5 file '" + os.path.abspath(test_path) + "' has to contain the dataset "
+                                      + "'" + input_name + "'")
+                        sys.exit(1)
+
+                for output_name in self._output_names_:
+                    if not output_name in test_h5:
+                        logging.error("The HDF5 file '" + os.path.abspath(test_path) + "' has to contain the dataset "
+                                      + "'" + output_name + "'")
+                        sys.exit(1)
             else:
                 logging.warning("Couldn't load test set. File '" + os.path.abspath(test_path) + "' does not exist.")
+
             return train_h5, test_h5
         else:
             logging.error("Data loading failure. File '" + os.path.abspath(train_path) + "' does not exist.")
...
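The loader now hands mx.io.NDArrayIter one dict per side instead of a single array, which is how MXNet batches several named inputs and labels in one iterator. A minimal sketch of that usage, with all names, shapes and sample counts assumed for illustration:

import numpy as np
import mxnet as mx

# Two named inputs and one named label, batched by a single iterator.
train_data = {
    'data_0': np.random.rand(100, 3, 32, 32).astype('float32'),
    'data_1': np.random.rand(100, 10).astype('float32'),
}
train_label = {'softmax_label': np.random.randint(0, 10, (100,)).astype('float32')}

train_iter = mx.io.NDArrayIter(data=train_data,
                               label=train_label,
                               batch_size=20)

for batch in train_iter:
    # batch.data holds one NDArray per input, batch.label one per label
    print([d.shape for d in batch.data], batch.label[0].shape)
    break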
@@ -74,5 +74,5 @@ class Net(gluon.HybridBlock):
         with self.name_scope():
 ${tc.include(tc.architecture.streams[0], "ARCHITECTURE_DEFINITION")}
 
-    def hybrid_forward(self, F, x):
+    def hybrid_forward(self, F, ${tc.join(tc.architectureInputs, ", ")}):
 ${tc.include(tc.architecture.streams[0], "FORWARD_FUNCTION")}
\ No newline at end of file
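With hybrid_forward taking one argument per architecture input, the generated Net behaves like any Gluon HybridBlock with several inputs. A self-contained sketch of that pattern; the toy layer choice and the input names are assumptions for illustration, not the generator's output:

import mxnet as mx
from mxnet import gluon

class TwoInputNet(gluon.HybridBlock):
    # Assumed toy architecture: flatten and concatenate both inputs,
    # then apply a single dense layer.
    def __init__(self, **kwargs):
        super(TwoInputNet, self).__init__(**kwargs)
        with self.name_scope():
            self.fc = gluon.nn.Dense(10)

    def hybrid_forward(self, F, data_0, data_1):
        combined = F.concat(F.flatten(data_0), F.flatten(data_1), dim=1)
        return self.fc(combined)

if __name__ == "__main__":
    ctx = mx.cpu()
    net = TwoInputNet()
    net.collect_params().initialize(mx.init.Normal(), ctx=ctx)
    net.hybridize()
    out = net(mx.nd.zeros((1, 3, 32, 32), ctx=ctx), mx.nd.zeros((1, 10), ctx=ctx))
    print(out.shape)  # (1, 10)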
@@ -6,7 +6,7 @@ import os
 import shutil
 from mxnet import gluon, autograd, nd
 
-class CNNSupervisedTrainer(object):
+class ${tc.fileNameWithoutEnding}:
     def __init__(self, data_loader, net_constructor, net=None):
         self._data_loader = data_loader
         self._net_creator = net_constructor
@@ -48,7 +48,7 @@ class CNNSupervisedTrainer(object):
         if self._net is None:
             if normalize:
                 self._net_creator.construct(
-                    context=mx_context, data_mean=nd.array(data_mean), data_std=nd.array(data_std))
+                    context=mx_context, data_mean=data_mean, data_std=data_std)
             else:
                 self._net_creator.construct(context=mx_context)
 
@@ -75,7 +75,7 @@ class CNNSupervisedTrainer(object):
             loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
         elif self._net.last_layer == 'linear':
             loss_function = mx.gluon.loss.L2Loss()
-        else: # TODO: Change default?
+        else:
             loss_function = mx.gluon.loss.L2Loss()
             logging.warning("Invalid last_layer, defaulting to L2 loss")
 
@@ -85,10 +85,13 @@ class CNNSupervisedTrainer(object):
         for epoch in range(begin_epoch, begin_epoch + num_epoch):
             train_iter.reset()
             for batch_i, batch in enumerate(train_iter):
-                data = batch.data[0].as_in_context(mx_context)
+                <#list tc.architectureInputs as input_name>
+                ${input_name} = batch.data[${input_name?index}].as_in_context(mx_context)
+                </#list>
                 label = batch.label[0].as_in_context(mx_context)
+
                 with autograd.record():
-                    output = self._net(data)
+                    output = self._net(${tc.join(tc.architectureInputs, ",")})
                     loss = loss_function(output, label)
 
                 loss.backward()
@@ -112,9 +115,12 @@ class CNNSupervisedTrainer(object):
             train_iter.reset()
             metric = mx.metric.create(eval_metric)
             for batch_i, batch in enumerate(train_iter):
-                data = batch.data[0].as_in_context(mx_context)
+                <#list tc.architectureInputs as input_name>
+                ${input_name} = batch.data[${input_name?index}].as_in_context(mx_context)
+                </#list>
                 label = batch.label[0].as_in_context(mx_context)
-                output = self._net(data)
+
+                output = self._net(${tc.join(tc.architectureInputs, ",")})
                 predictions = mx.nd.argmax(output, axis=1)
                 metric.update(preds=predictions, labels=label)
             train_metric_score = metric.get()[1]
@@ -122,9 +128,12 @@ class CNNSupervisedTrainer(object):
             test_iter.reset()
             metric = mx.metric.create(eval_metric)
             for batch_i, batch in enumerate(test_iter):
-                data = batch.data[0].as_in_context(mx_context)
+                <#list tc.architectureInputs as input_name>
+                ${input_name} = batch.data[${input_name?index}].as_in_context(mx_context)
+                </#list>
                 label = batch.label[0].as_in_context(mx_context)
-                output = self._net(data)
+
+                output = self._net(${tc.join(tc.architectureInputs, ",")})
                 predictions = mx.nd.argmax(output, axis=1)
                 metric.update(preds=predictions, labels=label)
             test_metric_score = metric.get()[1]
...
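For a hypothetical architecture with two inputs named data_0 and data_1 (names assumed for illustration), the batch loop in the template above would render roughly to the following excerpt of the generated trainer:

for batch_i, batch in enumerate(train_iter):
    data_0 = batch.data[0].as_in_context(mx_context)
    data_1 = batch.data[1].as_in_context(mx_context)
    label = batch.label[0].as_in_context(mx_context)

    with autograd.record():
        output = self._net(data_0, data_1)
        loss = loss_function(output, label)

    loss.backward()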
 import logging
 import mxnet as mx
-import supervised_trainer
 
 <#list configurations as config>
 import CNNCreator_${config.instanceName}
 import CNNDataLoader_${config.instanceName}
+import CNNSupervisedTrainer_${config.instanceName}
 </#list>
 
 if __name__ == "__main__":
@@ -14,9 +14,11 @@ if __name__ == "__main__":
 <#list configurations as config>
     ${config.instanceName}_creator = CNNCreator_${config.instanceName}.CNNCreator_${config.instanceName}()
-    ${config.instanceName}_loader = CNNDataLoader_${config.instanceName}.${config.instanceName}DataLoader()
-    ${config.instanceName}_trainer = supervised_trainer.CNNSupervisedTrainer(${config.instanceName}_loader,
-                                                                             ${config.instanceName}_creator)
+    ${config.instanceName}_loader = CNNDataLoader_${config.instanceName}.CNNDataLoader_${config.instanceName}()
+    ${config.instanceName}_trainer = CNNSupervisedTrainer_${config.instanceName}.CNNSupervisedTrainer_${config.instanceName}(
+        ${config.instanceName}_loader,
+        ${config.instanceName}_creator
+    )
 
     ${config.instanceName}_trainer.train(
 <#if (config.batchSize)??>
...
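For a configuration instance hypothetically named mnistClassifier, the template above would render roughly to:

import CNNCreator_mnistClassifier
import CNNDataLoader_mnistClassifier
import CNNSupervisedTrainer_mnistClassifier

# One creator, one data loader and one dedicated trainer per instance;
# the instance name 'mnistClassifier' is assumed for illustration.
mnistClassifier_creator = CNNCreator_mnistClassifier.CNNCreator_mnistClassifier()
mnistClassifier_loader = CNNDataLoader_mnistClassifier.CNNDataLoader_mnistClassifier()
mnistClassifier_trainer = CNNSupervisedTrainer_mnistClassifier.CNNSupervisedTrainer_mnistClassifier(
    mnistClassifier_loader,
    mnistClassifier_creator
)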
 <#assign mode = definition_mode.toString()>
 <#if mode == "ARCHITECTURE_DEFINITION">
-        if not data_mean is None:
-            assert(not data_std is None)
-            self.input_normalization = ZScoreNormalization(data_mean=data_mean, data_std=data_std)
+        if data_mean:
+            assert(data_std)
+            self.input_normalization_${element.name} = ZScoreNormalization(data_mean=data_mean['${element.name}'],
+                                                                           data_std=data_std['${element.name}'])
         else:
-            self.input_normalization = NoNormalization()
+            self.input_normalization_${element.name} = NoNormalization()
 </#if>
 <#if mode == "FORWARD_FUNCTION">
-        ${element.name} = self.input_normalization(x)
+        ${element.name} = self.input_normalization_${element.name}(${element.name})
 </#if>
\ No newline at end of file
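For an input hypothetically named image, the two template branches above would render roughly to:

# ARCHITECTURE_DEFINITION branch: one normalization object per input
if data_mean:
    assert(data_std)
    self.input_normalization_image = ZScoreNormalization(data_mean=data_mean['image'],
                                                         data_std=data_std['image'])
else:
    self.input_normalization_image = NoNormalization()

# FORWARD_FUNCTION branch: normalize the matching forward argument
image = self.input_normalization_image(image)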
@@ -59,12 +59,13 @@ public class GenerationTest extends AbstractSymtabTest {
                 Paths.get("./target/generated-sources-cnnarch"),
                 Paths.get("./src/test/resources/target_code"),
                 Arrays.asList(
                         "CNNCreator_CifarClassifierNetwork.py",
                         "CNNNet_CifarClassifierNetwork.py",
                         "CNNDataLoader_CifarClassifierNetwork.py",
+                        "CNNSupervisedTrainer_CifarClassifierNetwork.py",
                         "CNNPredictor_CifarClassifierNetwork.h",
                         "execute_CifarClassifierNetwork",
                         "CNNBufferFile.h"));
     }
 
     @Test
@@ -81,6 +82,7 @@ public class GenerationTest extends AbstractSymtabTest {
                         "CNNCreator_Alexnet.py",
                         "CNNNet_Alexnet.py",
                         "CNNDataLoader_Alexnet.py",
+                        "CNNSupervisedTrainer_Alexnet.py",
                         "CNNPredictor_Alexnet.h",
                         "execute_Alexnet"));
     }
@@ -99,6 +101,7 @@ public class GenerationTest extends AbstractSymtabTest {
                         "CNNCreator_VGG16.py",
                         "CNNNet_VGG16.py",
                         "CNNDataLoader_VGG16.py",
+                        "CNNSupervisedTrainer_VGG16.py",
                         "CNNPredictor_VGG16.h",
                         "execute_VGG16"));
     }
@@ -108,7 +111,7 @@ public class GenerationTest extends AbstractSymtabTest {
         Log.getFindings().clear();
         String[] args = {"-m", "src/test/resources/architectures", "-r", "ThreeInputCNN_M14"};
         CNNArch2GluonCli.main(args);
-        assertTrue(Log.getFindings().size() == 2);
+        assertTrue(Log.getFindings().isEmpty());
     }
 
     @Test
@@ -146,9 +149,7 @@ public class GenerationTest extends AbstractSymtabTest {
         checkFilesAreEqual(
                 Paths.get("./target/generated-sources-cnnarch"),
                 Paths.get("./src/test/resources/target_code"),
-                Arrays.asList(
-                        "CNNTrainer_fullConfig.py",
-                        "supervised_trainer.py"));
+                Arrays.asList("CNNTrainer_fullConfig.py"));
     }
 
     @Test
@@ -163,9 +164,7 @@ public class GenerationTest extends AbstractSymtabTest {
         checkFilesAreEqual(
                 Paths.get("./target/generated-sources-cnnarch"),
                 Paths.get("./src/test/resources/target_code"),
-                Arrays.asList(
-                        "CNNTrainer_simpleConfig.py",
-                        "supervised_trainer.py"));
+                Arrays.asList("CNNTrainer_simpleConfig.py"));
     }
 
     @Test
@@ -179,9 +178,7 @@ public class GenerationTest extends AbstractSymtabTest {
         checkFilesAreEqual(
                 Paths.get("./target/generated-sources-cnnarch"),
                 Paths.get("./src/test/resources/target_code"),
-                Arrays.asList(
-                        "CNNTrainer_emptyConfig.py",
-                        "supervised_trainer.py"));
+                Arrays.asList("CNNTrainer_emptyConfig.py"));
     }
 
     @Test
@@ -222,14 +219,12 @@ public class GenerationTest extends AbstractSymtabTest {
         checkFilesAreEqual(
                 Paths.get("./target/generated-sources-cnnarch"),
                 Paths.get("./src/test/resources/target_code"),
-                Arrays.asList(
-                        "CMakeLists.txt"));
+                Arrays.asList("CMakeLists.txt"));
 
         checkFilesAreEqual(
                 Paths.get("./target/generated-sources-cnnarch/cmake"),
                 Paths.get("./src/test/resources/target_code/cmake"),
-                Arrays.asList(
-                        "FindArmadillo.cmake"));
+                Arrays.asList("FindArmadillo.cmake"));
     }
 }
@@ -43,12 +43,11 @@ class CNNCreator_Alexnet:
             self.net.load_parameters(self._model_dir_ + param_file)
         return lastEpoch
 
     def construct(self, context, data_mean=None, data_std=None):
         self.net = Net(data_mean=data_mean, data_std=data_std)
         self.net.collect_params().initialize(self.weight_initializer, ctx=context)
         self.net.hybridize()
-        self.net(mx.nd.zeros((1,)+self._input_shapes_[0], ctx=context))
+        self.net(mx.nd.zeros((1,) + self._input_shapes_[0], ctx=context))
         if not os.path.exists(self._model_dir_):
             os.makedirs(self._model_dir_)
...
@@ -43,12 +43,11 @@ class CNNCreator_CifarClassifierNetwork:
             self.net.load_parameters(self._model_dir_ + param_file)
         return lastEpoch
 
     def construct(self, context, data_mean=None, data_std=None):
         self.net = Net(data_mean=data_mean, data_std=data_std)
         self.net.collect_params().initialize(self.weight_initializer, ctx=context)
         self.net.hybridize()
-        self.net(mx.nd.zeros((1,)+self._input_shapes_[0], ctx=context))
+        self.net(mx.nd.zeros((1,) + self._input_shapes_[0], ctx=context))
         if not os.path.exists(self._model_dir_):
             os.makedirs(self._model_dir_)
...
@@ -43,12 +43,11 @@ class CNNCreator_VGG16:
             self.net.load_parameters(self._model_dir_ + param_file)
         return lastEpoch
 
     def construct(self, context, data_mean=None, data_std=None):
         self.net = Net(data_mean=data_mean, data_std=data_std)
         self.net.collect_params().initialize(self.weight_initializer, ctx=context)