Commit effe3b70 authored by Sascha Dewes's avatar Sascha Dewes
Browse files

added weight initializer options

parent 6e3b0bef
......@@ -265,6 +265,60 @@ public class GluonConfigurationData extends ConfigurationData {
}
return getRlRewardFunctionParameter().get().getOutputParameterName().orElse(null);
}
public String getInitializerName() {
if (getConfiguration().getInitializer() == null) {
return null;
}
return getConfiguration().getInitializer().getName();
}
public Map<String, String> getInitializerParams() {
Map<String, String> mapToStrings = new HashMap<>();
Map<String, InitializerParamSymbol> initializerParams = getConfiguration().getInitializer().getInitializerParamMap();
for (Map.Entry<String, InitializerParamSymbol> entry : initializerParams.entrySet()) {
String paramName = entry.getKey();
String valueAsString = entry.getValue().toString();
Class realClass = entry.getValue().getValue().getValue().getClass();
if (realClass == Boolean.class) {
valueAsString = (Boolean) entry.getValue().getValue().getValue() ? "True" : "False";
}
mapToStrings.put(paramName, valueAsString);
}
if (mapToStrings.isEmpty()) {
return null;
} else {
return mapToStrings;
}
}
public String getCriticInitializerName() {
if (!getConfiguration().getCriticInitializer().isPresent()) {
return null;
}
return getConfiguration().getCriticInitializer().get().getName();
}
public Map<String, String> getCriticInitializerParams() {
assert getConfiguration().getCriticInitializer().isPresent():
"Critic initializer params called although, not present";
Map<String, String> mapToStrings = new HashMap<>();
Map<String, InitializerParamSymbol> initializerParams = getConfiguration().getCriticInitializer().getInitializerParamMap();
for (Map.Entry<String, InitializerParamSymbol> entry : initializerParams.entrySet()) {
String paramName = entry.getKey();
String valueAsString = entry.getValue().toString();
Class realClass = entry.getValue().getValue().getValue().getClass();
if (realClass == Boolean.class) {
valueAsString = (Boolean) entry.getValue().getValue().getValue() ? "True" : "False";
}
mapToStrings.put(paramName, valueAsString);
}
if (mapToStrings.isEmpty()) {
return null;
} else {
return mapToStrings;
}
}
public String getCriticOptimizerName() {
if (!getConfiguration().getCriticOptimizer().isPresent()) {
......
......@@ -182,6 +182,9 @@ class ${tc.fileNameWithoutEnding}:
for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
def setWeightInitializer(self, initializer):
self.weight_initializer = initializer
def getInputs(self):
inputs = {}
<#list tc.architecture.streams as stream>
......
......@@ -83,13 +83,36 @@ if __name__ == "__main__":
<#else>
context = mx.cpu()
</#if>
<#if (config.configuration.initializer)??>
initializer_params = {
<#list config.initializerParams?keys as param>
'${param}': ${config.initializerParams[param]}<#sep>,
</#list>
}
initializer = mx.init.${config.initializerName?capitalize}(**initializer_params)
<#else>
initializer = mx.init.Normal()
</#if>
<#if (config.configuration.criticInitializer)??>
critic_initializer_params = {
<#list config.criticInitializerParams?keys as param>
'${param}': ${config.criticInitializerParams[param]}<#sep>,
</#list>
}
critic_initializer = mx.init.${config.criticInitializerName?capitalize}(**critic_initializer_params)
<#else>
critic_initializer = mx.init.Normal()
</#if>
<#if config.rlAlgorithm == "dqn">
qnet_creator = CNNCreator_${config.instanceName}.CNNCreator_${config.instanceName}()
qnet_creator.setWeightInitializer(initializer)
qnet_creator.construct(context)
<#else>
actor_creator = CNNCreator_${config.instanceName}.CNNCreator_${config.instanceName}()
actor_creator.setWeightInitializer(initializer)
actor_creator.construct(context)
critic_creator = CNNCreator_${criticInstanceName}()
critic_creator.setWeightInitializer(critic_initializer)
critic_creator.construct(context)
</#if>
......
......@@ -170,6 +170,9 @@ class CNNCreator_Alexnet:
for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
def setWeightInitializer(self, initializer):
self.weight_initializer = initializer
def getInputs(self):
inputs = {}
input_dimensions = (3,224,224,)
......
......@@ -170,6 +170,9 @@ class CNNCreator_CifarClassifierNetwork:
for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
def setWeightInitializer(self, initializer):
self.weight_initializer = initializer
def getInputs(self):
inputs = {}
input_dimensions = (3,32,32,)
......
......@@ -171,6 +171,9 @@ class CNNCreator_EpisodicMemoryNetwork:
for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
def setWeightInitializer(self, initializer):
self.weight_initializer = initializer
def getInputs(self):
inputs = {}
input_dimensions = (128,)
......
......@@ -170,6 +170,9 @@ class CNNCreator_LoadNetworkTest:
for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
def setWeightInitializer(self, initializer):
self.weight_initializer = initializer
def getInputs(self):
inputs = {}
input_dimensions = (128,)
......
......@@ -170,6 +170,9 @@ class CNNCreator_VGG16:
for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
def setWeightInitializer(self, initializer):
self.weight_initializer = initializer
def getInputs(self):
inputs = {}
input_dimensions = (3,224,224,)
......
......@@ -52,7 +52,10 @@ if __name__ == "__main__":
env = reinforcement_learning.environment.RosEnvironment(**env_params)
context = mx.cpu()
initializer = mx.init.Normal()
critic_initializer = mx.init.Normal()
qnet_creator = CNNCreator_reinforcementConfig1.CNNCreator_reinforcementConfig1()
qnet_creator.setWeightInitializer(initializer)
qnet_creator.construct(context)
agent_params = {
......
......@@ -46,7 +46,10 @@ if __name__ == "__main__":
env = reinforcement_learning.environment.GymEnvironment('CartPole-v1')
context = mx.cpu()
initializer = mx.init.Normal()
critic_initializer = mx.init.Normal()
qnet_creator = CNNCreator_reinforcementConfig2.CNNCreator_reinforcementConfig2()
qnet_creator.setWeightInitializer(initializer)
qnet_creator.construct(context)
agent_params = {
......
......@@ -53,7 +53,10 @@ if __name__ == "__main__":
env = reinforcement_learning.environment.RosEnvironment(**env_params)
context = mx.cpu()
initializer = mx.init.Normal()
critic_initializer = mx.init.Normal()
qnet_creator = CNNCreator_reinforcementConfig3.CNNCreator_reinforcementConfig3()
qnet_creator.setWeightInitializer(initializer)
qnet_creator.construct(context)
agent_params = {
......
......@@ -47,9 +47,13 @@ if __name__ == "__main__":
env = reinforcement_learning.environment.GymEnvironment('CartPole-v0')
context = mx.cpu()
initializer = mx.init.Normal()
critic_initializer = mx.init.Normal()
actor_creator = CNNCreator_actorNetwork.CNNCreator_actorNetwork()
actor_creator.setWeightInitializer(initializer)
actor_creator.construct(context)
critic_creator = CNNCreator_CriticNetwork()
critic_creator.setWeightInitializer(critic_initializer)
critic_creator.construct(context)
agent_params = {
......
......@@ -170,6 +170,9 @@ class CNNCreator_CriticNetwork:
for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
def setWeightInitializer(self, initializer):
self.weight_initializer = initializer
def getInputs(self):
inputs = {}
input_dimensions = (8,)
......
......@@ -170,6 +170,9 @@ class CNNCreator_Discriminator:
for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
def setWeightInitializer(self, initializer):
self.weight_initializer = initializer
def getInputs(self):
inputs = {}
input_dimensions = (1,64,64,)
......
......@@ -170,6 +170,9 @@ class CNNCreator_InfoDiscriminator:
for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
def setWeightInitializer(self, initializer):
self.weight_initializer = initializer
def getInputs(self):
inputs = {}
input_dimensions = (1,64,64,)
......
......@@ -170,6 +170,9 @@ class CNNCreator_InfoQNetwork:
for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
def setWeightInitializer(self, initializer):
self.weight_initializer = initializer
def getInputs(self):
inputs = {}
input_dimensions = (1024,)
......
......@@ -54,9 +54,13 @@ if __name__ == "__main__":
env = reinforcement_learning.environment.RosEnvironment(**env_params)
context = mx.cpu()
initializer = mx.init.Normal()
critic_initializer = mx.init.Normal()
actor_creator = CNNCreator_rosActorNetwork.CNNCreator_rosActorNetwork()
actor_creator.setWeightInitializer(initializer)
actor_creator.construct(context)
critic_creator = CNNCreator_RosCriticNetwork()
critic_creator.setWeightInitializer(critic_initializer)
critic_creator.construct(context)
agent_params = {
......
......@@ -170,6 +170,9 @@ class CNNCreator_RosCriticNetwork:
for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
def setWeightInitializer(self, initializer):
self.weight_initializer = initializer
def getInputs(self):
inputs = {}
input_dimensions = (8,)
......
......@@ -47,9 +47,13 @@ if __name__ == "__main__":
env = reinforcement_learning.environment.GymEnvironment('CartPole-v1')
context = mx.cpu()
initializer = mx.init.Normal()
critic_initializer = mx.init.Normal()
actor_creator = CNNCreator_tD3Config.CNNCreator_tD3Config()
actor_creator.setWeightInitializer(initializer)
actor_creator.construct(context)
critic_creator = CNNCreator_CriticNetwork()
critic_creator.setWeightInitializer(critic_initializer)
critic_creator.construct(context)
agent_params = {
......
......@@ -170,6 +170,9 @@ class CNNCreator_CriticNetwork:
for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
def setWeightInitializer(self, initializer):
self.weight_initializer = initializer
def getInputs(self):
inputs = {}
input_dimensions = (8,)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment