Commit 629d11dd authored by Sascha Dewes's avatar Sascha Dewes
Browse files

updated target code for weight initializer feature

parent bd51c601
Pipeline #408735 passed with stage
in 7 seconds
......@@ -22,7 +22,7 @@
<cnnarch-generator.version>0.4.5</cnnarch-generator.version>
<cnnarch-mxnet-generator.version>0.4.5</cnnarch-mxnet-generator.version>
<cnnarch-caffe2-generator.version>0.4.5</cnnarch-caffe2-generator.version>
<cnnarch-gluon-generator.version>0.4.5</cnnarch-gluon-generator.version>
<cnnarch-gluon-generator.version>0.4.6-SNAPSHOT</cnnarch-gluon-generator.version>
<cnnarch-tensorflow-generator.version>0.4.5</cnnarch-tensorflow-generator.version>
<!-- .. Libraries .................................................. -->
......
......@@ -170,6 +170,9 @@ class CNNCreator_mnist_mnistClassifier_net:
for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
def setWeightInitializer(self, initializer):
self.weight_initializer = initializer
def getInputs(self):
inputs = {}
input_dimensions = (1,28,28,)
......
......@@ -170,6 +170,9 @@ class CNNCreator_defaultGAN_defaultGANDiscriminator:
for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
def setWeightInitializer(self, initializer):
self.weight_initializer = initializer
def getInputs(self):
inputs = {}
input_dimensions = (1,64,64,)
......
......@@ -170,6 +170,9 @@ class CNNCreator_infoGAN_infoGANDiscriminator:
for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
def setWeightInitializer(self, initializer):
self.weight_initializer = initializer
def getInputs(self):
inputs = {}
input_dimensions = (1,28,28,)
......
......@@ -170,6 +170,9 @@ class CNNCreator_cartpole_master_dqn:
for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
def setWeightInitializer(self, initializer):
self.weight_initializer = initializer
def getInputs(self):
inputs = {}
input_dimensions = (4,)
......
......@@ -170,6 +170,9 @@ class CNNCreator_mountaincar_master_actor:
for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
def setWeightInitializer(self, initializer):
self.weight_initializer = initializer
def getInputs(self):
inputs = {}
input_dimensions = (2,)
......
......@@ -170,6 +170,9 @@ class CNNCreator_torcs_agent_torcsAgent_dqn:
for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
def setWeightInitializer(self, initializer):
self.weight_initializer = initializer
def getInputs(self):
inputs = {}
input_dimensions = (5,)
......
......@@ -170,6 +170,9 @@ class CNNCreator_torcs_agent_torcsAgent_actor:
for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
def setWeightInitializer(self, initializer):
self.weight_initializer = initializer
def getInputs(self):
inputs = {}
input_dimensions = (29,)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment