diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/ReinforcementConfigurationData.java b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/ReinforcementConfigurationData.java index 344a4b073a279f0bbd594ed066a8c691efa59902..770af5e48d225f386e15f81830a4497b85561cc7 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/ReinforcementConfigurationData.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/ReinforcementConfigurationData.java @@ -193,30 +193,28 @@ public class ReinforcementConfigurationData extends ConfigurationData { assert isReinforcementLearning(): "Strategy parameter only for reinforcement learning but called in a " + " non reinforcement learning context"; Map<String, Object> strategyParams = getMultiParamEntry(AST_ENTRY_STRATEGY, "method"); - if (strategyParams.get("method").equals(STRATEGY_ORNSTEIN_UHLENBECK)) { - assert getConfiguration().getTrainedArchitecture().isPresent(): "Architecture not present," + - " but reinforcement training"; - TrainedArchitecture trainedArchitecture = getConfiguration().getTrainedArchitecture().get(); - final String actionPortName = getOutputNameOfTrainedArchitecture(); - Range actionRange = trainedArchitecture.getRanges().get(actionPortName); - - if (actionRange.isLowerLimitInfinity() && actionRange.isUpperLimitInfinity()) { - strategyParams.put("action_low", null); - strategyParams.put("action_high", null); - } else if(!actionRange.isLowerLimitInfinity() && actionRange.isUpperLimitInfinity()) { - assert actionRange.getLowerLimit().isPresent(); - strategyParams.put("action_low", actionRange.getLowerLimit().get()); - strategyParams.put("action_high", null); - } else if (actionRange.isLowerLimitInfinity() && !actionRange.isUpperLimitInfinity()) { - assert actionRange.getUpperLimit().isPresent(); - strategyParams.put("action_low", null); - strategyParams.put("action_high", actionRange.getUpperLimit().get()); - } else { - assert actionRange.getLowerLimit().isPresent(); - assert actionRange.getUpperLimit().isPresent(); - strategyParams.put("action_low", actionRange.getLowerLimit().get()); - strategyParams.put("action_high", actionRange.getUpperLimit().get()); - } + assert getConfiguration().getTrainedArchitecture().isPresent(): "Architecture not present," + + " but reinforcement training"; + TrainedArchitecture trainedArchitecture = getConfiguration().getTrainedArchitecture().get(); + final String actionPortName = getOutputNameOfTrainedArchitecture(); + Range actionRange = trainedArchitecture.getRanges().get(actionPortName); + + if (actionRange.isLowerLimitInfinity() && actionRange.isUpperLimitInfinity()) { + strategyParams.put("action_low", null); + strategyParams.put("action_high", null); + } else if(!actionRange.isLowerLimitInfinity() && actionRange.isUpperLimitInfinity()) { + assert actionRange.getLowerLimit().isPresent(); + strategyParams.put("action_low", actionRange.getLowerLimit().get()); + strategyParams.put("action_high", null); + } else if (actionRange.isLowerLimitInfinity() && !actionRange.isUpperLimitInfinity()) { + assert actionRange.getUpperLimit().isPresent(); + strategyParams.put("action_low", null); + strategyParams.put("action_high", actionRange.getUpperLimit().get()); + } else { + assert actionRange.getLowerLimit().isPresent(); + assert actionRange.getUpperLimit().isPresent(); + strategyParams.put("action_low", actionRange.getLowerLimit().get()); + strategyParams.put("action_high", actionRange.getUpperLimit().get()); } return strategyParams; } diff --git a/src/main/resources/templates/gluon/reinforcement/params/StrategyParams.ftl b/src/main/resources/templates/gluon/reinforcement/params/StrategyParams.ftl index 7792bee473f03c0f98c4955712a5aa359e169e16..e5ef1ee588d2a4255761121965429022a76c96d0 100644 --- a/src/main/resources/templates/gluon/reinforcement/params/StrategyParams.ftl +++ b/src/main/resources/templates/gluon/reinforcement/params/StrategyParams.ftl @@ -19,7 +19,7 @@ <#if (config.strategy.epsilon_decay_per_step)??> 'epsilon_decay_per_step': ${config.strategy.epsilon_decay_per_step?string('True', 'False')}, </#if> -<#if (config.strategy.method)?? && (config.strategy.method=="ornstein_uhlenbeck")> +<#if (config.strategy.method)?? && (config.strategy.method=="ornstein_uhlenbeck" || config.strategy.method=="gaussian")> <#if (config.strategy.action_low)?? > 'action_low': ${config.strategy.action_low}, <#else> @@ -30,6 +30,8 @@ <#else> 'action_high' : np.infty, </#if> +</#if> +<#if (config.strategy.method)?? && (config.strategy.method=="ornstein_uhlenbeck")> <#if (config.strategy.mu)??> 'mu': [<#list config.strategy.mu as m>${m}<#if m?has_next>, </#if></#list>], </#if>