Commit ec4b1df8 authored by Evgeny Kusmenko
Browse files

Merge branch 'weight_initializer' into 'master'

Weight initializer

See merge request !36
parents 44881810 b5706cb5
Pipeline #407235 passed with stages
in 2 minutes and 10 seconds
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
<groupId>de.monticore.lang.monticar</groupId> <groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-gluon-generator</artifactId> <artifactId>cnnarch-gluon-generator</artifactId>
<version>0.4.5</version> <version>0.4.6-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= --> <!-- == PROJECT DEPENDENCIES ============================================= -->
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
<!-- .. SE-Libraries .................................................. --> <!-- .. SE-Libraries .................................................. -->
<CNNArch2X.version>0.4.5</CNNArch2X.version> <CNNArch2X.version>0.4.6-SNAPSHOT</CNNArch2X.version>
<EMADL2PythonWrapper.version>0.0.3-SNAPSHOT</EMADL2PythonWrapper.version> <EMADL2PythonWrapper.version>0.0.3-SNAPSHOT</EMADL2PythonWrapper.version>
<!-- .. Libraries .................................................. --> <!-- .. Libraries .................................................. -->
......
...@@ -160,6 +160,16 @@ public class GluonConfigurationData extends ConfigurationData { ...@@ -160,6 +160,16 @@ public class GluonConfigurationData extends ConfigurationData {
return !configurationContainsKey(LOSS) return !configurationContainsKey(LOSS)
? null : retrieveConfigurationEntryValueByKey(LOSS).toString(); ? null : retrieveConfigurationEntryValueByKey(LOSS).toString();
} }
/**
 * Returns the weight-initializer configuration entry.
 *
 * The "initializer" entry is preferred; when that lookup yields {@code null}
 * the result of the "actor_initializer" lookup is returned instead (whatever
 * it is, including {@code null}).
 *
 * @return the entry found under "initializer", otherwise the one found under
 *         "actor_initializer"
 */
public Map<String, Object> getInitializer() {
    // Both lookups are performed up front, in the same order as before.
    final Map<String, Object> generalEntry = getMultiParamEntry("initializer", "method");
    final Map<String, Object> actorEntry = getMultiParamEntry("actor_initializer", "method");
    if (generalEntry == null) {
        return actorEntry;
    }
    return generalEntry;
}
/**
 * Returns the configuration entry registered under "critic_initializer".
 *
 * @return whatever {@code getMultiParamEntry} yields for that key
 */
public Map<String, Object> getCriticInitializer() {
    final Map<String, Object> criticEntry = getMultiParamEntry("critic_initializer", "method");
    return criticEntry;
}
public Map<String, Object> getReplayMemory() { public Map<String, Object> getReplayMemory() {
return getMultiParamEntry(REPLAY_MEMORY, "method"); return getMultiParamEntry(REPLAY_MEMORY, "method");
......
...@@ -182,6 +182,9 @@ class ${tc.fileNameWithoutEnding}: ...@@ -182,6 +182,9 @@ class ${tc.fileNameWithoutEnding}:
for i, network in self.networks.items(): for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0) network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
def setWeightInitializer(self, initializer):
    """Remember ``initializer`` on this creator as ``weight_initializer``.

    The value is stored as given; nothing is validated or applied here.
    """
    self.weight_initializer = initializer
def getInputs(self): def getInputs(self):
inputs = {} inputs = {}
<#list tc.architecture.streams as stream> <#list tc.architecture.streams as stream>
......
...@@ -83,13 +83,38 @@ if __name__ == "__main__": ...@@ -83,13 +83,38 @@ if __name__ == "__main__":
<#else> <#else>
context = mx.cpu() context = mx.cpu()
</#if> </#if>
<#-- Weight initializer handed to the actor / Q-network creator further down.
     `initializer` must always be defined in the generated script, so a missing
     configuration AND a configured method other than "normal" both fall back
     to MXNet's default Normal initializer. 'sigma' is only emitted when it is
     configured; otherwise Normal() keeps its own default. -->
<#if (config.initializer)?? && config.initializer.method=="normal">
    initializer_params = {
<#if (config.initializer.sigma)??>
        'sigma': ${config.initializer.sigma}
</#if>
    }
    initializer = mx.init.Normal(**initializer_params)
<#else>
    initializer = mx.init.Normal()
</#if>
<#if config.rlAlgorithm=="ddpg" || config.rlAlgorithm=="td3">
<#-- Weight initializer handed to the critic creator further down.
     `critic_initializer` must always be defined for ddpg/td3, so a missing
     configuration AND a configured method other than "normal" both fall back
     to MXNet's default Normal initializer. 'sigma' is only emitted when it is
     configured; otherwise Normal() keeps its own default. -->
<#if (config.criticInitializer)?? && config.criticInitializer.method=="normal">
    critic_initializer_params = {
<#if (config.criticInitializer.sigma)??>
        'sigma': ${config.criticInitializer.sigma}
</#if>
    }
    critic_initializer = mx.init.Normal(**critic_initializer_params)
<#else>
    critic_initializer = mx.init.Normal()
</#if>
</#if>
<#if config.rlAlgorithm == "dqn"> <#if config.rlAlgorithm == "dqn">
qnet_creator = CNNCreator_${config.instanceName}.CNNCreator_${config.instanceName}() qnet_creator = CNNCreator_${config.instanceName}.CNNCreator_${config.instanceName}()
qnet_creator.setWeightInitializer(initializer)
qnet_creator.construct(context) qnet_creator.construct(context)
<#else> <#elseif config.rlAlgorithm=="ddpg" || config.rlAlgorithm=="td3">
actor_creator = CNNCreator_${config.instanceName}.CNNCreator_${config.instanceName}() actor_creator = CNNCreator_${config.instanceName}.CNNCreator_${config.instanceName}()
actor_creator.setWeightInitializer(initializer)
actor_creator.construct(context) actor_creator.construct(context)
critic_creator = CNNCreator_${criticInstanceName}() critic_creator = CNNCreator_${criticInstanceName}()
critic_creator.setWeightInitializer(critic_initializer)
critic_creator.construct(context) critic_creator.construct(context)
</#if> </#if>
......
...@@ -170,6 +170,9 @@ class CNNCreator_Alexnet: ...@@ -170,6 +170,9 @@ class CNNCreator_Alexnet:
for i, network in self.networks.items(): for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0) network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
def setWeightInitializer(self, initializer):
    """Remember ``initializer`` on this creator as ``weight_initializer``.

    The value is stored as given; nothing is validated or applied here.
    """
    self.weight_initializer = initializer
def getInputs(self): def getInputs(self):
inputs = {} inputs = {}
input_dimensions = (3,224,224,) input_dimensions = (3,224,224,)
......
...@@ -170,6 +170,9 @@ class CNNCreator_CifarClassifierNetwork: ...@@ -170,6 +170,9 @@ class CNNCreator_CifarClassifierNetwork:
for i, network in self.networks.items(): for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0) network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
def setWeightInitializer(self, initializer):
    """Remember ``initializer`` on this creator as ``weight_initializer``.

    The value is stored as given; nothing is validated or applied here.
    """
    self.weight_initializer = initializer
def getInputs(self): def getInputs(self):
inputs = {} inputs = {}
input_dimensions = (3,32,32,) input_dimensions = (3,32,32,)
......
...@@ -171,6 +171,9 @@ class CNNCreator_EpisodicMemoryNetwork: ...@@ -171,6 +171,9 @@ class CNNCreator_EpisodicMemoryNetwork:
for i, network in self.networks.items(): for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0) network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
def setWeightInitializer(self, initializer):
    """Remember ``initializer`` on this creator as ``weight_initializer``.

    The value is stored as given; nothing is validated or applied here.
    """
    self.weight_initializer = initializer
def getInputs(self): def getInputs(self):
inputs = {} inputs = {}
input_dimensions = (128,) input_dimensions = (128,)
......
...@@ -170,6 +170,9 @@ class CNNCreator_LoadNetworkTest: ...@@ -170,6 +170,9 @@ class CNNCreator_LoadNetworkTest:
for i, network in self.networks.items(): for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0) network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
def setWeightInitializer(self, initializer):
    """Remember ``initializer`` on this creator as ``weight_initializer``.

    The value is stored as given; nothing is validated or applied here.
    """
    self.weight_initializer = initializer
def getInputs(self): def getInputs(self):
inputs = {} inputs = {}
input_dimensions = (128,) input_dimensions = (128,)
......
...@@ -170,6 +170,9 @@ class CNNCreator_VGG16: ...@@ -170,6 +170,9 @@ class CNNCreator_VGG16:
for i, network in self.networks.items(): for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0) network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
def setWeightInitializer(self, initializer):
    """Remember ``initializer`` on this creator as ``weight_initializer``.

    The value is stored as given; nothing is validated or applied here.
    """
    self.weight_initializer = initializer
def getInputs(self): def getInputs(self):
inputs = {} inputs = {}
input_dimensions = (3,224,224,) input_dimensions = (3,224,224,)
......
...@@ -52,7 +52,9 @@ if __name__ == "__main__": ...@@ -52,7 +52,9 @@ if __name__ == "__main__":
env = reinforcement_learning.environment.RosEnvironment(**env_params) env = reinforcement_learning.environment.RosEnvironment(**env_params)
context = mx.cpu() context = mx.cpu()
initializer = mx.init.Normal()
qnet_creator = CNNCreator_reinforcementConfig1.CNNCreator_reinforcementConfig1() qnet_creator = CNNCreator_reinforcementConfig1.CNNCreator_reinforcementConfig1()
qnet_creator.setWeightInitializer(initializer)
qnet_creator.construct(context) qnet_creator.construct(context)
agent_params = { agent_params = {
......
...@@ -46,7 +46,9 @@ if __name__ == "__main__": ...@@ -46,7 +46,9 @@ if __name__ == "__main__":
env = reinforcement_learning.environment.GymEnvironment('CartPole-v1') env = reinforcement_learning.environment.GymEnvironment('CartPole-v1')
context = mx.cpu() context = mx.cpu()
initializer = mx.init.Normal()
qnet_creator = CNNCreator_reinforcementConfig2.CNNCreator_reinforcementConfig2() qnet_creator = CNNCreator_reinforcementConfig2.CNNCreator_reinforcementConfig2()
qnet_creator.setWeightInitializer(initializer)
qnet_creator.construct(context) qnet_creator.construct(context)
agent_params = { agent_params = {
......
...@@ -53,7 +53,9 @@ if __name__ == "__main__": ...@@ -53,7 +53,9 @@ if __name__ == "__main__":
env = reinforcement_learning.environment.RosEnvironment(**env_params) env = reinforcement_learning.environment.RosEnvironment(**env_params)
context = mx.cpu() context = mx.cpu()
initializer = mx.init.Normal()
qnet_creator = CNNCreator_reinforcementConfig3.CNNCreator_reinforcementConfig3() qnet_creator = CNNCreator_reinforcementConfig3.CNNCreator_reinforcementConfig3()
qnet_creator.setWeightInitializer(initializer)
qnet_creator.construct(context) qnet_creator.construct(context)
agent_params = { agent_params = {
......
...@@ -47,9 +47,13 @@ if __name__ == "__main__": ...@@ -47,9 +47,13 @@ if __name__ == "__main__":
env = reinforcement_learning.environment.GymEnvironment('CartPole-v0') env = reinforcement_learning.environment.GymEnvironment('CartPole-v0')
context = mx.cpu() context = mx.cpu()
initializer = mx.init.Normal()
critic_initializer = mx.init.Normal()
actor_creator = CNNCreator_actorNetwork.CNNCreator_actorNetwork() actor_creator = CNNCreator_actorNetwork.CNNCreator_actorNetwork()
actor_creator.setWeightInitializer(initializer)
actor_creator.construct(context) actor_creator.construct(context)
critic_creator = CNNCreator_CriticNetwork() critic_creator = CNNCreator_CriticNetwork()
critic_creator.setWeightInitializer(critic_initializer)
critic_creator.construct(context) critic_creator.construct(context)
agent_params = { agent_params = {
......
...@@ -170,6 +170,9 @@ class CNNCreator_CriticNetwork: ...@@ -170,6 +170,9 @@ class CNNCreator_CriticNetwork:
for i, network in self.networks.items(): for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0) network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
def setWeightInitializer(self, initializer):
    """Remember ``initializer`` on this creator as ``weight_initializer``.

    The value is stored as given; nothing is validated or applied here.
    """
    self.weight_initializer = initializer
def getInputs(self): def getInputs(self):
inputs = {} inputs = {}
input_dimensions = (8,) input_dimensions = (8,)
......
...@@ -170,6 +170,9 @@ class CNNCreator_Discriminator: ...@@ -170,6 +170,9 @@ class CNNCreator_Discriminator:
for i, network in self.networks.items(): for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0) network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
def setWeightInitializer(self, initializer):
    """Remember ``initializer`` on this creator as ``weight_initializer``.

    The value is stored as given; nothing is validated or applied here.
    """
    self.weight_initializer = initializer
def getInputs(self): def getInputs(self):
inputs = {} inputs = {}
input_dimensions = (1,64,64,) input_dimensions = (1,64,64,)
......
...@@ -170,6 +170,9 @@ class CNNCreator_InfoDiscriminator: ...@@ -170,6 +170,9 @@ class CNNCreator_InfoDiscriminator:
for i, network in self.networks.items(): for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0) network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
def setWeightInitializer(self, initializer):
    """Remember ``initializer`` on this creator as ``weight_initializer``.

    The value is stored as given; nothing is validated or applied here.
    """
    self.weight_initializer = initializer
def getInputs(self): def getInputs(self):
inputs = {} inputs = {}
input_dimensions = (1,64,64,) input_dimensions = (1,64,64,)
......
...@@ -170,6 +170,9 @@ class CNNCreator_InfoQNetwork: ...@@ -170,6 +170,9 @@ class CNNCreator_InfoQNetwork:
for i, network in self.networks.items(): for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0) network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
def setWeightInitializer(self, initializer):
    """Remember ``initializer`` on this creator as ``weight_initializer``.

    The value is stored as given; nothing is validated or applied here.
    """
    self.weight_initializer = initializer
def getInputs(self): def getInputs(self):
inputs = {} inputs = {}
input_dimensions = (1024,) input_dimensions = (1024,)
......
...@@ -54,9 +54,13 @@ if __name__ == "__main__": ...@@ -54,9 +54,13 @@ if __name__ == "__main__":
env = reinforcement_learning.environment.RosEnvironment(**env_params) env = reinforcement_learning.environment.RosEnvironment(**env_params)
context = mx.cpu() context = mx.cpu()
initializer = mx.init.Normal()
critic_initializer = mx.init.Normal()
actor_creator = CNNCreator_rosActorNetwork.CNNCreator_rosActorNetwork() actor_creator = CNNCreator_rosActorNetwork.CNNCreator_rosActorNetwork()
actor_creator.setWeightInitializer(initializer)
actor_creator.construct(context) actor_creator.construct(context)
critic_creator = CNNCreator_RosCriticNetwork() critic_creator = CNNCreator_RosCriticNetwork()
critic_creator.setWeightInitializer(critic_initializer)
critic_creator.construct(context) critic_creator.construct(context)
agent_params = { agent_params = {
......
...@@ -170,6 +170,9 @@ class CNNCreator_RosCriticNetwork: ...@@ -170,6 +170,9 @@ class CNNCreator_RosCriticNetwork:
for i, network in self.networks.items(): for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0) network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
def setWeightInitializer(self, initializer):
    """Remember ``initializer`` on this creator as ``weight_initializer``.

    The value is stored as given; nothing is validated or applied here.
    """
    self.weight_initializer = initializer
def getInputs(self): def getInputs(self):
inputs = {} inputs = {}
input_dimensions = (8,) input_dimensions = (8,)
......
...@@ -47,9 +47,13 @@ if __name__ == "__main__": ...@@ -47,9 +47,13 @@ if __name__ == "__main__":
env = reinforcement_learning.environment.GymEnvironment('CartPole-v1') env = reinforcement_learning.environment.GymEnvironment('CartPole-v1')
context = mx.cpu() context = mx.cpu()
initializer = mx.init.Normal()
critic_initializer = mx.init.Normal()
actor_creator = CNNCreator_tD3Config.CNNCreator_tD3Config() actor_creator = CNNCreator_tD3Config.CNNCreator_tD3Config()
actor_creator.setWeightInitializer(initializer)
actor_creator.construct(context) actor_creator.construct(context)
critic_creator = CNNCreator_CriticNetwork() critic_creator = CNNCreator_CriticNetwork()
critic_creator.setWeightInitializer(critic_initializer)
critic_creator.construct(context) critic_creator.construct(context)
agent_params = { agent_params = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment