Commit b5706cb5 authored by Sascha Dewes's avatar Sascha Dewes
Browse files

bug fixes

parent 2a05ee34
Pipeline #407217 passed with stage
in 1 minute and 21 seconds
......@@ -161,6 +161,16 @@ public class GluonConfigurationData extends ConfigurationData {
? null : retrieveConfigurationEntryValueByKey(LOSS).toString();
}
public Map<String, Object> getInitializer() {
Map<String, Object> initializer = getMultiParamEntry("initializer", "method");
Map<String, Object> actor_initializer = getMultiParamEntry("actor_initializer", "method");
return (initializer != null) ? initializer : actor_initializer;
}
public Map<String, Object> getCriticInitializer() {
return getMultiParamEntry("critic_initializer", "method");
}
public Map<String, Object> getReplayMemory() {
return getMultiParamEntry(REPLAY_MEMORY, "method");
}
......@@ -266,60 +276,6 @@ public class GluonConfigurationData extends ConfigurationData {
return getRlRewardFunctionParameter().get().getOutputParameterName().orElse(null);
}
public String getInitializerName() {
if (getConfiguration().getInitializer() == null) {
return null;
}
return getConfiguration().getInitializer().getName();
}
public Map<String, String> getInitializerParams() {
Map<String, String> mapToStrings = new HashMap<>();
Map<String, InitializerParamSymbol> initializerParams = getConfiguration().getInitializer().getInitializerParamMap();
for (Map.Entry<String, InitializerParamSymbol> entry : initializerParams.entrySet()) {
String paramName = entry.getKey();
String valueAsString = entry.getValue().toString();
Class realClass = entry.getValue().getValue().getValue().getClass();
if (realClass == Boolean.class) {
valueAsString = (Boolean) entry.getValue().getValue().getValue() ? "True" : "False";
}
mapToStrings.put(paramName, valueAsString);
}
if (mapToStrings.isEmpty()) {
return null;
} else {
return mapToStrings;
}
}
public String getCriticInitializerName() {
if (!getConfiguration().getCriticInitializer().isPresent()) {
return null;
}
return getConfiguration().getCriticInitializer().get().getName();
}
public Map<String, String> getCriticInitializerParams() {
assert getConfiguration().getCriticInitializer().isPresent():
"Critic initializer params called although, not present";
Map<String, String> mapToStrings = new HashMap<>();
Map<String, InitializerParamSymbol> initializerParams = getConfiguration().getCriticInitializer().get().getInitializerParamMap();
for (Map.Entry<String, InitializerParamSymbol> entry : initializerParams.entrySet()) {
String paramName = entry.getKey();
String valueAsString = entry.getValue().toString();
Class realClass = entry.getValue().getValue().getValue().getClass();
if (realClass == Boolean.class) {
valueAsString = (Boolean) entry.getValue().getValue().getValue() ? "True" : "False";
}
mapToStrings.put(paramName, valueAsString);
}
if (mapToStrings.isEmpty()) {
return null;
} else {
return mapToStrings;
}
}
public String getCriticOptimizerName() {
if (!getConfiguration().getCriticOptimizer().isPresent()) {
return null;
......
......@@ -83,24 +83,24 @@ if __name__ == "__main__":
<#else>
context = mx.cpu()
</#if>
<#if (config.configuration.initializer)??>
<#if (config.initializer)??>
<#if config.initializer.method=="normal">
initializer_params = {
<#list config.initializerParams?keys as param>
'${param}': ${config.initializerParams[param]}<#sep>,
</#list>
'sigma': ${config.initializer.sigma}
}
initializer = mx.init.${config.initializerName?capitalize}(**initializer_params)
initializer = mx.init.Normal(**initializer_params)
</#if>
<#else>
initializer = mx.init.Normal()
</#if>
<#if config.rlAlgorithm=="ddpg" || config.rlAlgorithm=="td3">
<#if (config.configuration.criticInitializer)??>
<#if (config.criticInitializer)??>
<#if config.criticInitializer.method=="normal">
critic_initializer_params = {
<#list config.criticInitializerParams?keys as param>
'${param}': ${config.criticInitializerParams[param]}<#sep>,
</#list>
'sigma': ${config.criticInitializer.sigma}
}
critic_initializer = mx.init.${config.criticInitializerName?capitalize}(**critic_initializer_params)
critic_initializer = mx.init.Normal(**critic_initializer_params)
</#if>
<#else>
critic_initializer = mx.init.Normal()
</#if>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment