Commit effe3b70 authored by Sascha Dewes
Browse files

added weight initializer options

parent 6e3b0bef
...@@ -266,6 +266,60 @@ public class GluonConfigurationData extends ConfigurationData { ...@@ -266,6 +266,60 @@ public class GluonConfigurationData extends ConfigurationData {
return getRlRewardFunctionParameter().get().getOutputParameterName().orElse(null); return getRlRewardFunctionParameter().get().getOutputParameterName().orElse(null);
} }
/**
 * Returns the name of the configured weight initializer.
 *
 * @return the initializer's name, or {@code null} when the configuration has no initializer
 */
public String getInitializerName() {
    return getConfiguration().getInitializer() != null
            ? getConfiguration().getInitializer().getName()
            : null;
}
/**
 * Collects the parameters of the configured weight initializer as strings.
 *
 * <p>Each parameter is rendered via its {@code toString()}, except boolean values, which are
 * rendered as {@code "True"} / {@code "False"}.
 *
 * @return parameter name to rendered value, or {@code null} when no initializer is configured
 *         or the initializer has no parameters
 */
public Map<String, String> getInitializerParams() {
    // Mirror getInitializerName(): a missing initializer must not cause a NullPointerException.
    if (getConfiguration().getInitializer() == null) {
        return null;
    }

    Map<String, String> mapToStrings = new HashMap<>();
    Map<String, InitializerParamSymbol> initializerParams =
            getConfiguration().getInitializer().getInitializerParamMap();

    for (Map.Entry<String, InitializerParamSymbol> entry : initializerParams.entrySet()) {
        InitializerParamSymbol param = entry.getValue();
        String valueAsString = param.toString();
        Object rawValue = param.getValue().getValue();
        if (rawValue instanceof Boolean) {
            // Booleans are spelled with a leading capital instead of Java's "true"/"false".
            valueAsString = (Boolean) rawValue ? "True" : "False";
        }
        mapToStrings.put(entry.getKey(), valueAsString);
    }

    // Existing contract: an empty parameter set is reported as null, not as an empty map.
    return mapToStrings.isEmpty() ? null : mapToStrings;
}
/**
 * Returns the name of the configured critic weight initializer.
 *
 * @return the critic initializer's name, or {@code null} when none is present
 */
public String getCriticInitializerName() {
    return getConfiguration().getCriticInitializer()
            .map(criticInitializer -> criticInitializer.getName())
            .orElse(null);
}
/**
 * Collects the parameters of the configured critic weight initializer as strings.
 *
 * <p>Each parameter is rendered via its {@code toString()}, except boolean values, which are
 * rendered as {@code "True"} / {@code "False"}.
 *
 * <p>Precondition: a critic initializer must be present; check
 * {@link #getCriticInitializerName()} for {@code null} first.
 *
 * @return parameter name to rendered value, or {@code null} when the critic initializer has
 *         no parameters
 */
public Map<String, String> getCriticInitializerParams() {
    assert getConfiguration().getCriticInitializer().isPresent():
            "Critic initializer params called although, not present";

    Map<String, String> mapToStrings = new HashMap<>();
    // getCriticInitializer() returns an Optional, so it has to be unwrapped before use.
    Map<String, InitializerParamSymbol> initializerParams =
            getConfiguration().getCriticInitializer().get().getInitializerParamMap();

    for (Map.Entry<String, InitializerParamSymbol> entry : initializerParams.entrySet()) {
        InitializerParamSymbol param = entry.getValue();
        String valueAsString = param.toString();
        Object rawValue = param.getValue().getValue();
        if (rawValue instanceof Boolean) {
            // Booleans are spelled with a leading capital instead of Java's "true"/"false".
            valueAsString = (Boolean) rawValue ? "True" : "False";
        }
        mapToStrings.put(entry.getKey(), valueAsString);
    }

    // Existing contract: an empty parameter set is reported as null, not as an empty map.
    return mapToStrings.isEmpty() ? null : mapToStrings;
}
public String getCriticOptimizerName() { public String getCriticOptimizerName() {
if (!getConfiguration().getCriticOptimizer().isPresent()) { if (!getConfiguration().getCriticOptimizer().isPresent()) {
return null; return null;
......
...@@ -182,6 +182,9 @@ class ${tc.fileNameWithoutEnding}: ...@@ -182,6 +182,9 @@ class ${tc.fileNameWithoutEnding}:
for i, network in self.networks.items(): for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0) network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
# Setter: stores the given initializer object on this creator as `weight_initializer`.
# The trainer template in this change calls it on each creator right before construct(context).
def setWeightInitializer(self, initializer):
self.weight_initializer = initializer
def getInputs(self): def getInputs(self):
inputs = {} inputs = {}
<#list tc.architecture.streams as stream> <#list tc.architecture.streams as stream>
......
...@@ -83,13 +83,36 @@ if __name__ == "__main__": ...@@ -83,13 +83,36 @@ if __name__ == "__main__":
<#else> <#else>
context = mx.cpu() context = mx.cpu()
</#if> </#if>
<#-- Actor/Q-network initializer. The name accessor returns null when no initializer is
     configured, so it serves as the presence check; mx.init.Normal() is the fallback. -->
<#if (config.initializerName)??>
initializer_params = {
<#-- The params accessor returns null for an empty parameter set; emit an empty dict then. -->
<#if (config.initializerParams)??>
<#list config.initializerParams?keys as param>
'${param}': ${config.initializerParams[param]}<#sep>,
</#list>
</#if>
}
initializer = mx.init.${config.initializerName?capitalize}(**initializer_params)
<#else>
initializer = mx.init.Normal()
</#if>
<#-- Critic initializer. configuration.criticInitializer is an Optional and therefore never
     null, so test the name accessor (null when absent) to keep the fallback reachable. -->
<#if (config.criticInitializerName)??>
critic_initializer_params = {
<#if (config.criticInitializerParams)??>
<#list config.criticInitializerParams?keys as param>
'${param}': ${config.criticInitializerParams[param]}<#sep>,
</#list>
</#if>
}
critic_initializer = mx.init.${config.criticInitializerName?capitalize}(**critic_initializer_params)
<#else>
critic_initializer = mx.init.Normal()
</#if>
<#if config.rlAlgorithm == "dqn"> <#if config.rlAlgorithm == "dqn">
qnet_creator = CNNCreator_${config.instanceName}.CNNCreator_${config.instanceName}() qnet_creator = CNNCreator_${config.instanceName}.CNNCreator_${config.instanceName}()
qnet_creator.setWeightInitializer(initializer)
qnet_creator.construct(context) qnet_creator.construct(context)
<#else> <#else>
actor_creator = CNNCreator_${config.instanceName}.CNNCreator_${config.instanceName}() actor_creator = CNNCreator_${config.instanceName}.CNNCreator_${config.instanceName}()
actor_creator.setWeightInitializer(initializer)
actor_creator.construct(context) actor_creator.construct(context)
critic_creator = CNNCreator_${criticInstanceName}() critic_creator = CNNCreator_${criticInstanceName}()
critic_creator.setWeightInitializer(critic_initializer)
critic_creator.construct(context) critic_creator.construct(context)
</#if> </#if>
......
...@@ -170,6 +170,9 @@ class CNNCreator_Alexnet: ...@@ -170,6 +170,9 @@ class CNNCreator_Alexnet:
for i, network in self.networks.items(): for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0) network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
# Setter: stores the given initializer object on this creator as `weight_initializer`.
def setWeightInitializer(self, initializer):
self.weight_initializer = initializer
def getInputs(self): def getInputs(self):
inputs = {} inputs = {}
input_dimensions = (3,224,224,) input_dimensions = (3,224,224,)
......
...@@ -170,6 +170,9 @@ class CNNCreator_CifarClassifierNetwork: ...@@ -170,6 +170,9 @@ class CNNCreator_CifarClassifierNetwork:
for i, network in self.networks.items(): for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0) network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
# Setter: stores the given initializer object on this creator as `weight_initializer`.
def setWeightInitializer(self, initializer):
self.weight_initializer = initializer
def getInputs(self): def getInputs(self):
inputs = {} inputs = {}
input_dimensions = (3,32,32,) input_dimensions = (3,32,32,)
......
...@@ -171,6 +171,9 @@ class CNNCreator_EpisodicMemoryNetwork: ...@@ -171,6 +171,9 @@ class CNNCreator_EpisodicMemoryNetwork:
for i, network in self.networks.items(): for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0) network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
# Setter: stores the given initializer object on this creator as `weight_initializer`.
def setWeightInitializer(self, initializer):
self.weight_initializer = initializer
def getInputs(self): def getInputs(self):
inputs = {} inputs = {}
input_dimensions = (128,) input_dimensions = (128,)
......
...@@ -170,6 +170,9 @@ class CNNCreator_LoadNetworkTest: ...@@ -170,6 +170,9 @@ class CNNCreator_LoadNetworkTest:
for i, network in self.networks.items(): for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0) network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
# Setter: stores the given initializer object on this creator as `weight_initializer`.
def setWeightInitializer(self, initializer):
self.weight_initializer = initializer
def getInputs(self): def getInputs(self):
inputs = {} inputs = {}
input_dimensions = (128,) input_dimensions = (128,)
......
...@@ -170,6 +170,9 @@ class CNNCreator_VGG16: ...@@ -170,6 +170,9 @@ class CNNCreator_VGG16:
for i, network in self.networks.items(): for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0) network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
# Setter: stores the given initializer object on this creator as `weight_initializer`.
def setWeightInitializer(self, initializer):
self.weight_initializer = initializer
def getInputs(self): def getInputs(self):
inputs = {} inputs = {}
input_dimensions = (3,224,224,) input_dimensions = (3,224,224,)
......
...@@ -52,7 +52,10 @@ if __name__ == "__main__": ...@@ -52,7 +52,10 @@ if __name__ == "__main__":
env = reinforcement_learning.environment.RosEnvironment(**env_params) env = reinforcement_learning.environment.RosEnvironment(**env_params)
context = mx.cpu() context = mx.cpu()
initializer = mx.init.Normal()
critic_initializer = mx.init.Normal()
qnet_creator = CNNCreator_reinforcementConfig1.CNNCreator_reinforcementConfig1() qnet_creator = CNNCreator_reinforcementConfig1.CNNCreator_reinforcementConfig1()
qnet_creator.setWeightInitializer(initializer)
qnet_creator.construct(context) qnet_creator.construct(context)
agent_params = { agent_params = {
......
...@@ -46,7 +46,10 @@ if __name__ == "__main__": ...@@ -46,7 +46,10 @@ if __name__ == "__main__":
env = reinforcement_learning.environment.GymEnvironment('CartPole-v1') env = reinforcement_learning.environment.GymEnvironment('CartPole-v1')
context = mx.cpu() context = mx.cpu()
initializer = mx.init.Normal()
critic_initializer = mx.init.Normal()
qnet_creator = CNNCreator_reinforcementConfig2.CNNCreator_reinforcementConfig2() qnet_creator = CNNCreator_reinforcementConfig2.CNNCreator_reinforcementConfig2()
qnet_creator.setWeightInitializer(initializer)
qnet_creator.construct(context) qnet_creator.construct(context)
agent_params = { agent_params = {
......
...@@ -53,7 +53,10 @@ if __name__ == "__main__": ...@@ -53,7 +53,10 @@ if __name__ == "__main__":
env = reinforcement_learning.environment.RosEnvironment(**env_params) env = reinforcement_learning.environment.RosEnvironment(**env_params)
context = mx.cpu() context = mx.cpu()
initializer = mx.init.Normal()
critic_initializer = mx.init.Normal()
qnet_creator = CNNCreator_reinforcementConfig3.CNNCreator_reinforcementConfig3() qnet_creator = CNNCreator_reinforcementConfig3.CNNCreator_reinforcementConfig3()
qnet_creator.setWeightInitializer(initializer)
qnet_creator.construct(context) qnet_creator.construct(context)
agent_params = { agent_params = {
......
...@@ -47,9 +47,13 @@ if __name__ == "__main__": ...@@ -47,9 +47,13 @@ if __name__ == "__main__":
env = reinforcement_learning.environment.GymEnvironment('CartPole-v0') env = reinforcement_learning.environment.GymEnvironment('CartPole-v0')
context = mx.cpu() context = mx.cpu()
initializer = mx.init.Normal()
critic_initializer = mx.init.Normal()
actor_creator = CNNCreator_actorNetwork.CNNCreator_actorNetwork() actor_creator = CNNCreator_actorNetwork.CNNCreator_actorNetwork()
actor_creator.setWeightInitializer(initializer)
actor_creator.construct(context) actor_creator.construct(context)
critic_creator = CNNCreator_CriticNetwork() critic_creator = CNNCreator_CriticNetwork()
critic_creator.setWeightInitializer(critic_initializer)
critic_creator.construct(context) critic_creator.construct(context)
agent_params = { agent_params = {
......
...@@ -170,6 +170,9 @@ class CNNCreator_CriticNetwork: ...@@ -170,6 +170,9 @@ class CNNCreator_CriticNetwork:
for i, network in self.networks.items(): for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0) network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
# Setter: stores the given initializer object on this creator as `weight_initializer`.
def setWeightInitializer(self, initializer):
self.weight_initializer = initializer
def getInputs(self): def getInputs(self):
inputs = {} inputs = {}
input_dimensions = (8,) input_dimensions = (8,)
......
...@@ -170,6 +170,9 @@ class CNNCreator_Discriminator: ...@@ -170,6 +170,9 @@ class CNNCreator_Discriminator:
for i, network in self.networks.items(): for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0) network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
# Setter: stores the given initializer object on this creator as `weight_initializer`.
def setWeightInitializer(self, initializer):
self.weight_initializer = initializer
def getInputs(self): def getInputs(self):
inputs = {} inputs = {}
input_dimensions = (1,64,64,) input_dimensions = (1,64,64,)
......
...@@ -170,6 +170,9 @@ class CNNCreator_InfoDiscriminator: ...@@ -170,6 +170,9 @@ class CNNCreator_InfoDiscriminator:
for i, network in self.networks.items(): for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0) network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
# Setter: stores the given initializer object on this creator as `weight_initializer`.
def setWeightInitializer(self, initializer):
self.weight_initializer = initializer
def getInputs(self): def getInputs(self):
inputs = {} inputs = {}
input_dimensions = (1,64,64,) input_dimensions = (1,64,64,)
......
...@@ -170,6 +170,9 @@ class CNNCreator_InfoQNetwork: ...@@ -170,6 +170,9 @@ class CNNCreator_InfoQNetwork:
for i, network in self.networks.items(): for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0) network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
# Setter: stores the given initializer object on this creator as `weight_initializer`.
def setWeightInitializer(self, initializer):
self.weight_initializer = initializer
def getInputs(self): def getInputs(self):
inputs = {} inputs = {}
input_dimensions = (1024,) input_dimensions = (1024,)
......
...@@ -54,9 +54,13 @@ if __name__ == "__main__": ...@@ -54,9 +54,13 @@ if __name__ == "__main__":
env = reinforcement_learning.environment.RosEnvironment(**env_params) env = reinforcement_learning.environment.RosEnvironment(**env_params)
context = mx.cpu() context = mx.cpu()
initializer = mx.init.Normal()
critic_initializer = mx.init.Normal()
actor_creator = CNNCreator_rosActorNetwork.CNNCreator_rosActorNetwork() actor_creator = CNNCreator_rosActorNetwork.CNNCreator_rosActorNetwork()
actor_creator.setWeightInitializer(initializer)
actor_creator.construct(context) actor_creator.construct(context)
critic_creator = CNNCreator_RosCriticNetwork() critic_creator = CNNCreator_RosCriticNetwork()
critic_creator.setWeightInitializer(critic_initializer)
critic_creator.construct(context) critic_creator.construct(context)
agent_params = { agent_params = {
......
...@@ -170,6 +170,9 @@ class CNNCreator_RosCriticNetwork: ...@@ -170,6 +170,9 @@ class CNNCreator_RosCriticNetwork:
for i, network in self.networks.items(): for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0) network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
# Setter: stores the given initializer object on this creator as `weight_initializer`.
def setWeightInitializer(self, initializer):
self.weight_initializer = initializer
def getInputs(self): def getInputs(self):
inputs = {} inputs = {}
input_dimensions = (8,) input_dimensions = (8,)
......
...@@ -47,9 +47,13 @@ if __name__ == "__main__": ...@@ -47,9 +47,13 @@ if __name__ == "__main__":
env = reinforcement_learning.environment.GymEnvironment('CartPole-v1') env = reinforcement_learning.environment.GymEnvironment('CartPole-v1')
context = mx.cpu() context = mx.cpu()
initializer = mx.init.Normal()
critic_initializer = mx.init.Normal()
actor_creator = CNNCreator_tD3Config.CNNCreator_tD3Config() actor_creator = CNNCreator_tD3Config.CNNCreator_tD3Config()
actor_creator.setWeightInitializer(initializer)
actor_creator.construct(context) actor_creator.construct(context)
critic_creator = CNNCreator_CriticNetwork() critic_creator = CNNCreator_CriticNetwork()
critic_creator.setWeightInitializer(critic_initializer)
critic_creator.construct(context) critic_creator.construct(context)
agent_params = { agent_params = {
......
...@@ -170,6 +170,9 @@ class CNNCreator_CriticNetwork: ...@@ -170,6 +170,9 @@ class CNNCreator_CriticNetwork:
for i, network in self.networks.items(): for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0) network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
# Setter: stores the given initializer object on this creator as `weight_initializer`.
def setWeightInitializer(self, initializer):
self.weight_initializer = initializer
def getInputs(self): def getInputs(self):
inputs = {} inputs = {}
input_dimensions = (8,) input_dimensions = (8,)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment