From c2de294e66f44ae7b4c428d8c2ab11e2200ba2ea Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Tue, 28 May 2019 19:20:10 +0200 Subject: [PATCH 01/16] Update version --- pom.xml | 2 +- .../_cocos/CheckLearningParameterCombination.java | 3 ++- .../cnntrain/_symboltable/CNNTrainSymbolTableCreator.java | 8 ++++++++ src/test/resources/valid_tests/DdpgConfig.cnnt | 1 + 4 files changed, 12 insertions(+), 2 deletions(-) diff --git a/pom.xml b/pom.xml index 037c85e..b91f77c 100644 --- a/pom.xml +++ b/pom.xml @@ -30,7 +30,7 @@ de.monticore.lang.monticar cnn-train - 0.3.1-SNAPSHOT + 0.3.2-SNAPSHOT diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckLearningParameterCombination.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckLearningParameterCombination.java index 90ea12a..afaa3a1 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckLearningParameterCombination.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckLearningParameterCombination.java @@ -104,7 +104,8 @@ public class CheckLearningParameterCombination implements CNNTrainASTEntryCoCo { ASTRosEnvironmentMetaTopicEntry.class, ASTRosEnvironmentResetTopicEntry.class, ASTRosEnvironmentTerminalStateTopicEntry.class, - ASTRosEnvironmentGreetingTopicEntry.class + ASTRosEnvironmentGreetingTopicEntry.class, + ASTSoftTargetUpdateRateEntry.class ); private Set allEntries; diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java index 2c45115..fd0eb7d 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java @@ -469,6 +469,14 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { addToScopeAndLinkWithNode(symbol, node); } + @Override + public void visit(ASTSoftTargetUpdateRateEntry node) { + EntrySymbol entry = new EntrySymbol(node.getName()); + entry.setValue(getValueSymbolForDouble(node.getValue())); + addToScopeAndLinkWithNode(entry, node); + configuration.getEntryMap().put(node.getName(), entry); + } + private void processMultiParamConfigVisit(ASTMultiParamConfigEntry node, Object value) { EntrySymbol entry = new EntrySymbol(node.getName()); MultiParamValueSymbol valueSymbol = new MultiParamValueSymbol(); diff --git a/src/test/resources/valid_tests/DdpgConfig.cnnt b/src/test/resources/valid_tests/DdpgConfig.cnnt index 45f3b28..3553f2d 100644 --- a/src/test/resources/valid_tests/DdpgConfig.cnnt +++ b/src/test/resources/valid_tests/DdpgConfig.cnnt @@ -3,4 +3,5 @@ configuration DdpgConfig { rl_algorithm : ddpg-algorithm critic : path.to.component environment : gym { name:"CartPole-v1" } + soft_target_update_rate: 0.001 } \ No newline at end of file -- GitLab From 2a8c59ac1a2bcd5a3ff3af6a38b202fe2b7cd0d5 Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Tue, 28 May 2019 19:22:52 +0200 Subject: [PATCH 02/16] Addd soft target update parameter --- .../de/monticore/lang/monticar/CNNTrain.mc4 | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 index 7409708..3ea8c37 100644 --- a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 +++ b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 @@ -104,13 +104,9 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number NumMaxStepsEntry implements ConfigEntry = name:"num_max_steps" ":" value:IntegerValue; TargetScoreEntry implements ConfigEntry = name:"target_score" ":" value:NumberValue; TrainingIntervalEntry implements ConfigEntry = name:"training_interval" ":" value:IntegerValue; - UseFixTargetNetworkEntry implements ConfigEntry = name:"use_fix_target_network" ":" value:BooleanValue; - TargetNetworkUpdateIntervalEntry implements ConfigEntry = name:"target_network_update_interval" ":" value:IntegerValue; SnapshotIntervalEntry implements ConfigEntry = name:"snapshot_interval" ":" value:IntegerValue; AgentNameEntry implements ConfigEntry = name:"agent_name" ":" value:StringValue; - UseDoubleDQNEntry implements ConfigEntry = name:"use_double_dqn" ":" value:BooleanValue; RewardFunctionEntry implements ConfigEntry = name:"reward_function" ":" value:ComponentNameValue; - CriticNetworkEntry implements ConfigEntry = name:"critic" ":" value:ComponentNameValue; ComponentNameValue implements ConfigValue = Name ("."Name)*; @@ -166,4 +162,14 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number RosEnvironmentGreetingTopicEntry implements RosEnvironmentEntry = name:"greeting_topic" ":" value:StringValue; RosEnvironmentMetaTopicEntry implements RosEnvironmentEntry = name:"meta_topic" ":" value:StringValue; RosEnvironmentTerminalStateTopicEntry implements RosEnvironmentEntry = name:"terminal_state_topic" ":" value:StringValue; + + // DQN exclusive parameters + UseFixTargetNetworkEntry implements ConfigEntry = name:"use_fix_target_network" ":" value:BooleanValue; + TargetNetworkUpdateIntervalEntry implements ConfigEntry = name:"target_network_update_interval" ":" value:IntegerValue; + UseDoubleDQNEntry implements ConfigEntry = name:"use_double_dqn" ":" value:BooleanValue; + + + // DDPG exclusive parameters + CriticNetworkEntry implements ConfigEntry = name:"critic" ":" value:ComponentNameValue; + SoftTargetUpdateRateEntry implements ConfigEntry = name:"soft_target_update_rate" ":" value:NumberValue; } \ No newline at end of file -- GitLab From 8451a41f1bb91833105b3830e70abb0a405cae21 Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Tue, 28 May 2019 22:02:57 +0200 Subject: [PATCH 03/16] Refactor allowed parameter combination --- .../_cocos/ASTConfigurationUtils.java | 18 ++- .../CheckLearningParameterCombination.java | 102 +++---------- .../_cocos/ParameterAlgorithmMapping.java | 141 ++++++++++++++++++ 3 files changed, 173 insertions(+), 88 deletions(-) create mode 100644 src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ASTConfigurationUtils.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ASTConfigurationUtils.java index a0d4ed9..1fb1f6a 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ASTConfigurationUtils.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ASTConfigurationUtils.java @@ -20,9 +20,7 @@ */ package de.monticore.lang.monticar.cnntrain._cocos; -import de.monticore.lang.monticar.cnntrain._ast.ASTConfiguration; -import de.monticore.lang.monticar.cnntrain._ast.ASTEnvironmentEntry; -import de.monticore.lang.monticar.cnntrain._ast.ASTLearningMethodEntry; +import de.monticore.lang.monticar.cnntrain._ast.*; class ASTConfigurationUtils { static boolean isReinforcementLearning(final ASTConfiguration configuration) { @@ -34,4 +32,18 @@ class ASTConfigurationUtils { static boolean hasEnvironment(final ASTConfiguration configuration) { return configuration.getEntriesList().stream().anyMatch(e -> e instanceof ASTEnvironmentEntry); } + + static boolean isDdpgAlgorithm(final ASTConfiguration configuration) { + return isReinforcementLearning(configuration) + && configuration.getEntriesList().stream().anyMatch( + e -> (e instanceof ASTRLAlgorithmEntry) && ((ASTRLAlgorithmEntry)e).getValue().isPresentDdpg()); + } + + static boolean isDqnAlgorithm(final ASTConfiguration configuration) { + return isReinforcementLearning(configuration) && !isDdpgAlgorithm(configuration); + } + + static boolean hasEntry(final ASTConfiguration configuration, final Class entryClazz) { + return configuration.getEntriesList().stream().anyMatch(entryClazz::isInstance); + } } diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckLearningParameterCombination.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckLearningParameterCombination.java index afaa3a1..63c11b4 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckLearningParameterCombination.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckLearningParameterCombination.java @@ -34,79 +34,7 @@ import java.util.Set; * */ public class CheckLearningParameterCombination implements CNNTrainASTEntryCoCo { - private final static List ALLOWED_SUPERVISED_LEARNING = Lists.newArrayList( - ASTTrainContextEntry.class, - ASTBatchSizeEntry.class, - ASTOptimizerEntry.class, - ASTLearningRateEntry.class, - ASTLoadCheckpointEntry.class, - ASTEvalMetricEntry.class, - ASTLossEntry.class, - ASTNormalizeEntry.class, - ASTMinimumLearningRateEntry.class, - ASTLRDecayEntry.class, - ASTWeightDecayEntry.class, - ASTLRPolicyEntry.class, - ASTStepSizeEntry.class, - ASTRescaleGradEntry.class, - ASTClipGradEntry.class, - ASTGamma1Entry.class, - ASTGamma2Entry.class, - ASTEpsilonEntry.class, - ASTCenteredEntry.class, - ASTClipWeightsEntry.class, - ASTBeta1Entry.class, - ASTBeta2Entry.class, - ASTNumEpochEntry.class - ); - private final static List ALLOWED_REINFORCEMENT_LEARNING = Lists.newArrayList( - ASTTrainContextEntry.class, - ASTRLAlgorithmEntry.class, - ASTCriticNetworkEntry.class, - ASTOptimizerEntry.class, - ASTRewardFunctionEntry.class, - ASTMinimumLearningRateEntry.class, - ASTLRDecayEntry.class, - ASTWeightDecayEntry.class, - ASTLRPolicyEntry.class, - ASTGamma1Entry.class, - ASTGamma2Entry.class, - ASTEpsilonEntry.class, - ASTClipGradEntry.class, - ASTRescaleGradEntry.class, - ASTStepSizeEntry.class, - ASTCenteredEntry.class, - ASTClipWeightsEntry.class, - ASTLearningRateEntry.class, - ASTDiscountFactorEntry.class, - ASTNumMaxStepsEntry.class, - ASTTargetScoreEntry.class, - ASTTrainingIntervalEntry.class, - ASTUseFixTargetNetworkEntry.class, - ASTTargetNetworkUpdateIntervalEntry.class, - ASTSnapshotIntervalEntry.class, - ASTAgentNameEntry.class, - ASTGymEnvironmentNameEntry.class, - ASTEnvironmentEntry.class, - ASTUseDoubleDQNEntry.class, - ASTLossEntry.class, - ASTReplayMemoryEntry.class, - ASTMemorySizeEntry.class, - ASTSampleSizeEntry.class, - ASTActionSelectionEntry.class, - ASTGreedyEpsilonEntry.class, - ASTMinEpsilonEntry.class, - ASTEpsilonDecayEntry.class, - ASTEpsilonDecayMethodEntry.class, - ASTNumEpisodesEntry.class, - ASTRosEnvironmentActionTopicEntry.class, - ASTRosEnvironmentStateTopicEntry.class, - ASTRosEnvironmentMetaTopicEntry.class, - ASTRosEnvironmentResetTopicEntry.class, - ASTRosEnvironmentTerminalStateTopicEntry.class, - ASTRosEnvironmentGreetingTopicEntry.class, - ASTSoftTargetUpdateRateEntry.class - ); + private final ParameterAlgorithmMapping parameterAlgorithmMapping; private Set allEntries; @@ -114,12 +42,13 @@ public class CheckLearningParameterCombination implements CNNTrainASTEntryCoCo { private LearningMethod learningMethod; public CheckLearningParameterCombination() { - this.allEntries = new HashSet<>(); - this.learningMethodKnown = false; + allEntries = new HashSet<>(); + learningMethodKnown = false; + parameterAlgorithmMapping = new ParameterAlgorithmMapping(); } private Boolean isLearningMethodKnown() { - return this.learningMethodKnown; + return learningMethodKnown; } @Override @@ -133,18 +62,20 @@ public class CheckLearningParameterCombination implements CNNTrainASTEntryCoCo { private void evaluateEntry(ASTEntry node) { allEntries.add(node); - final Boolean supervisedLearningParameter = ALLOWED_SUPERVISED_LEARNING.contains(node.getClass()); - final Boolean reinforcementLearningParameter = ALLOWED_REINFORCEMENT_LEARNING.contains(node.getClass()); + final boolean supervisedLearningParameter + = parameterAlgorithmMapping.isSupervisedLearningParameter(node.getClass()); + final boolean reinforcementLearningParameter + = parameterAlgorithmMapping.isReinforcementLearningParameter(node.getClass()); - assert (supervisedLearningParameter || reinforcementLearningParameter) : + assert (supervisedLearningParameter || reinforcementLearningParameter) : "Parameter " + node.getName() + " is not checkable, because it is unknown to Condition"; - if (supervisedLearningParameter && reinforcementLearningParameter) { - return; - } else if (supervisedLearningParameter && !reinforcementLearningParameter) { + + if (supervisedLearningParameter && !reinforcementLearningParameter) { setLearningMethodOrLogErrorIfActualLearningMethodIsNotSupervised(node); - } else if (!supervisedLearningParameter && reinforcementLearningParameter) { + } else if(!supervisedLearningParameter) { setLearningMethodOrLogErrorIfActualLearningMethodIsNotReinforcement(node); } + } private void setLearningMethodOrLogErrorIfActualLearningMethodIsNotReinforcement(ASTEntry node) { @@ -204,11 +135,12 @@ public class CheckLearningParameterCombination implements CNNTrainASTEntryCoCo { private List getAllowedParametersByLearningMethod(final LearningMethod learningMethod) { if (learningMethod.equals(LearningMethod.REINFORCEMENT)) { - return ALLOWED_REINFORCEMENT_LEARNING; + return parameterAlgorithmMapping.getAllReinforcementParameters(); } - return ALLOWED_SUPERVISED_LEARNING; + return parameterAlgorithmMapping.getAllSupervisedParameters(); } + private void setLearningMethod(final LearningMethod learningMethod) { if (learningMethod.equals(LearningMethod.REINFORCEMENT)) { setLearningMethodToReinforcement(); diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java new file mode 100644 index 0000000..cee3d1d --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java @@ -0,0 +1,141 @@ +/** + * + * ****************************************************************************** + * MontiCAR Modeling Family, www.se-rwth.de + * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, + * All rights reserved. + * + * This project is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 3.0 of the License, or (at your option) any later version. + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this project. If not, see . + * ******************************************************************************* + */ +package de.monticore.lang.monticar.cnntrain._cocos; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; +import de.monticore.lang.monticar.cnntrain._ast.*; + +import java.util.List; + +public class ParameterAlgorithmMapping { + private static final List GENERAL_PARAMETERS = Lists.newArrayList( + ASTTrainContextEntry.class, + ASTOptimizerEntry.class, + ASTLearningRateEntry.class, + ASTMinimumLearningRateEntry.class, + ASTLRDecayEntry.class, + ASTWeightDecayEntry.class, + ASTLRPolicyEntry.class, + ASTStepSizeEntry.class, + ASTRescaleGradEntry.class, + ASTClipGradEntry.class, + ASTGamma1Entry.class, + ASTGamma2Entry.class, + ASTEpsilonEntry.class, + ASTCenteredEntry.class, + ASTClipWeightsEntry.class, + ASTBeta1Entry.class, + ASTBeta2Entry.class + ); + + private static final List EXCLUSIVE_SUPERVISED_PARAMETERS = Lists.newArrayList( + ASTBatchSizeEntry.class, + ASTLoadCheckpointEntry.class, + ASTEvalMetricEntry.class, + ASTNormalizeEntry.class, + ASTNumEpochEntry.class, + ASTLossEntry.class + ); + + private static final List GENERAL_REINFORCEMENT_PARAMETERS = Lists.newArrayList( + ASTRLAlgorithmEntry.class, + ASTRewardFunctionEntry.class, + ASTDiscountFactorEntry.class, + ASTNumMaxStepsEntry.class, + ASTTargetScoreEntry.class, + ASTTrainingIntervalEntry.class, + ASTAgentNameEntry.class, + ASTGymEnvironmentNameEntry.class, + ASTEnvironmentEntry.class, + ASTReplayMemoryEntry.class, + ASTMemorySizeEntry.class, + ASTSampleSizeEntry.class, + ASTActionSelectionEntry.class, + ASTGreedyEpsilonEntry.class, + ASTMinEpsilonEntry.class, + ASTEpsilonDecayEntry.class, + ASTEpsilonDecayMethodEntry.class, + ASTNumEpisodesEntry.class, + ASTRosEnvironmentActionTopicEntry.class, + ASTRosEnvironmentStateTopicEntry.class, + ASTRosEnvironmentMetaTopicEntry.class, + ASTRosEnvironmentResetTopicEntry.class, + ASTRosEnvironmentTerminalStateTopicEntry.class, + ASTRosEnvironmentGreetingTopicEntry.class + ); + + private static final List EXCLUSIVE_DQN_PARAMETERS = Lists.newArrayList( + ASTUseFixTargetNetworkEntry.class, + ASTTargetNetworkUpdateIntervalEntry.class, + ASTUseDoubleDQNEntry.class, + ASTLossEntry.class + ); + + private static final List EXCLUSIVE_DDPG_PARAMETERS = Lists.newArrayList( + ASTCriticNetworkEntry.class, + ASTSoftTargetUpdateRateEntry.class + ); + + ParameterAlgorithmMapping() { + + } + + boolean isReinforcementLearningParameter(Class entryClazz) { + return GENERAL_PARAMETERS.contains(entryClazz) + || GENERAL_REINFORCEMENT_PARAMETERS.contains(entryClazz) + || EXCLUSIVE_DQN_PARAMETERS.contains(entryClazz) + || EXCLUSIVE_DDPG_PARAMETERS.contains(entryClazz); + } + + boolean isSupervisedLearningParameter(Class entryClazz) { + return GENERAL_PARAMETERS.contains(entryClazz) + || EXCLUSIVE_SUPERVISED_PARAMETERS.contains(entryClazz); + } + + boolean isDqnParameter(Class entryClazz) { + return GENERAL_PARAMETERS.contains(entryClazz) + || GENERAL_REINFORCEMENT_PARAMETERS.contains(entryClazz) + || EXCLUSIVE_DQN_PARAMETERS.contains(entryClazz); + } + + boolean isDdpgParameter(Class entryClazz) { + return GENERAL_PARAMETERS.contains(entryClazz) + || GENERAL_REINFORCEMENT_PARAMETERS.contains(entryClazz) + || EXCLUSIVE_DDPG_PARAMETERS.contains(entryClazz); + } + + List getAllReinforcementParameters() { + return ImmutableList. builder() + .addAll(GENERAL_PARAMETERS) + .addAll(GENERAL_REINFORCEMENT_PARAMETERS) + .addAll(EXCLUSIVE_DQN_PARAMETERS) + .addAll(EXCLUSIVE_DDPG_PARAMETERS) + .build(); + } + + List getAllSupervisedParameters() { + return ImmutableList. builder() + .addAll(GENERAL_PARAMETERS) + .addAll(EXCLUSIVE_SUPERVISED_PARAMETERS) + .build(); + } +} \ No newline at end of file -- GitLab From 75b5d5d4138cbf29b32a32bf88e1c99fde3eb6c8 Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Tue, 28 May 2019 23:20:22 +0200 Subject: [PATCH 04/16] Implement CoCo check for rl algorithm parameter --- .../cnntrain/_cocos/CNNTrainCocos.java | 3 +- .../_cocos/CheckRlAlgorithmParameter.java | 90 +++++++++++++++++++ .../monticar/cnntrain/cocos/AllCoCoTest.java | 9 ++ .../CheckRLAlgorithmParameter1.cnnt | 7 ++ .../CheckRLAlgorithmParameter2.cnnt | 9 ++ .../CheckRLAlgorithmParameter3.cnnt | 19 ++++ 6 files changed, 136 insertions(+), 1 deletion(-) create mode 100644 src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckRlAlgorithmParameter.java create mode 100644 src/test/resources/invalid_cocos_tests/CheckRLAlgorithmParameter1.cnnt create mode 100644 src/test/resources/invalid_cocos_tests/CheckRLAlgorithmParameter2.cnnt create mode 100644 src/test/resources/invalid_cocos_tests/CheckRLAlgorithmParameter3.cnnt diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java index 3d04980..35e3d7b 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java @@ -34,7 +34,8 @@ public class CNNTrainCocos { .addCoCo(new CheckReinforcementRequiresEnvironment()) .addCoCo(new CheckLearningParameterCombination()) .addCoCo(new CheckRosEnvironmentRequiresRewardFunction()) - .addCoCo(new CheckDdpgRequiresCriticNetwork()); + .addCoCo(new CheckDdpgRequiresCriticNetwork()) + .addCoCo(new CheckRlAlgorithmParameter()); } public static void checkAll(CNNTrainCompilationUnitSymbol compilationUnit){ diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckRlAlgorithmParameter.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckRlAlgorithmParameter.java new file mode 100644 index 0000000..85d0bd1 --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckRlAlgorithmParameter.java @@ -0,0 +1,90 @@ +/** + * + * ****************************************************************************** + * MontiCAR Modeling Family, www.se-rwth.de + * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, + * All rights reserved. + * + * This project is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 3.0 of the License, or (at your option) any later version. + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this project. If not, see . + * ******************************************************************************* + */ +package de.monticore.lang.monticar.cnntrain._cocos; + +import de.monticore.lang.monticar.cnntrain._ast.ASTEntry; +import de.monticore.lang.monticar.cnntrain._ast.ASTRLAlgorithmEntry; +import de.monticore.lang.monticar.cnntrain._symboltable.RLAlgorithm; +import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes; +import de.se_rwth.commons.logging.Log; + +public class CheckRlAlgorithmParameter implements CNNTrainASTEntryCoCo { + private final ParameterAlgorithmMapping parameterAlgorithmMapping; + + boolean algorithmKnown; + RLAlgorithm algorithm; + + public CheckRlAlgorithmParameter() { + parameterAlgorithmMapping = new ParameterAlgorithmMapping(); + algorithmKnown = false; + algorithm = null; + } + + + @Override + public void check(ASTEntry node) { + final boolean isDdpgParameter = parameterAlgorithmMapping.isDdpgParameter(node.getClass()); + final boolean isDqnParameter = parameterAlgorithmMapping.isDqnParameter(node.getClass()); + + if (node instanceof ASTRLAlgorithmEntry) { + ASTRLAlgorithmEntry algorithmEntry = (ASTRLAlgorithmEntry)node; + if (algorithmEntry.getValue().isPresentDdpg()) { + setAlgorithmToDdpg(node); + } else { + setAlgorithmToDqn(node); + } + } else { + if (isDdpgParameter && !isDqnParameter) { + setAlgorithmToDdpg(node); + } else if (!isDdpgParameter && isDqnParameter) { + setAlgorithmToDqn(node); + } + } + } + + private void logErrorIfAlgorithmIsDqn(final ASTEntry node) { + if (algorithmKnown && algorithm.equals(RLAlgorithm.DQN)) { + Log.error("0" + ErrorCodes.UNSUPPORTED_PARAMETER + + " DDPG Parameter " + node.getName() + " used but algorithm is " + algorithm + ".", + node.get_SourcePositionStart()); + } + } + + private void setAlgorithmToDdpg(final ASTEntry node) { + logErrorIfAlgorithmIsDqn(node); + algorithmKnown = true; + algorithm = RLAlgorithm.DDPG; + } + + private void setAlgorithmToDqn(final ASTEntry node) { + logErrorIfAlgorithmIsDdpg(node); + algorithmKnown = true; + algorithm = RLAlgorithm.DQN; + } + + private void logErrorIfAlgorithmIsDdpg(final ASTEntry node) { + if (algorithmKnown && algorithm.equals(RLAlgorithm.DDPG)) { + Log.error("0" + ErrorCodes.UNSUPPORTED_PARAMETER + + " DQN Parameter " + node.getName() + " used but algorithm is " + algorithm + ".", + node.get_SourcePositionStart()); + } + } +} \ No newline at end of file diff --git a/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java b/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java index 57cf32f..251058e 100644 --- a/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java +++ b/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java @@ -76,5 +76,14 @@ public class AllCoCoTest extends AbstractCoCoTest{ checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckRosEnvironmentRequiresRewardFunction()), "invalid_cocos_tests", "CheckRosEnvironmentRequiresRewardFunction", new ExpectedErrorInfo(1, ErrorCodes.REQUIRED_PARAMETER_MISSING)); + checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckRlAlgorithmParameter()), + "invalid_cocos_tests", "CheckRLAlgorithmParameter1", + new ExpectedErrorInfo(1, ErrorCodes.UNSUPPORTED_PARAMETER)); + checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckRlAlgorithmParameter()), + "invalid_cocos_tests", "CheckRLAlgorithmParameter2", + new ExpectedErrorInfo(1, ErrorCodes.UNSUPPORTED_PARAMETER)); + checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckRlAlgorithmParameter()), + "invalid_cocos_tests", "CheckRLAlgorithmParameter3", + new ExpectedErrorInfo(1, ErrorCodes.UNSUPPORTED_PARAMETER)); } } diff --git a/src/test/resources/invalid_cocos_tests/CheckRLAlgorithmParameter1.cnnt b/src/test/resources/invalid_cocos_tests/CheckRLAlgorithmParameter1.cnnt new file mode 100644 index 0000000..b38eb8e --- /dev/null +++ b/src/test/resources/invalid_cocos_tests/CheckRLAlgorithmParameter1.cnnt @@ -0,0 +1,7 @@ +configuration CheckRLAlgorithmParameter1 { + learning_method : reinforcement + critic : path.to.component + environment : gym { name:"CartPole-v1" } + soft_target_update_rate: 0.001 + use_double_dqn: true +} \ No newline at end of file diff --git a/src/test/resources/invalid_cocos_tests/CheckRLAlgorithmParameter2.cnnt b/src/test/resources/invalid_cocos_tests/CheckRLAlgorithmParameter2.cnnt new file mode 100644 index 0000000..89a14c4 --- /dev/null +++ b/src/test/resources/invalid_cocos_tests/CheckRLAlgorithmParameter2.cnnt @@ -0,0 +1,9 @@ +configuration CheckRLAlgorithmParameter2 { + learning_method : reinforcement + rl_algorithm : ddpg-algorithm + critic : path.to.component + environment : gym { name:"CartPole-v1" } + soft_target_update_rate: 0.001 + target_network_update_interval: 400 + use_fix_target_network: true +} \ No newline at end of file diff --git a/src/test/resources/invalid_cocos_tests/CheckRLAlgorithmParameter3.cnnt b/src/test/resources/invalid_cocos_tests/CheckRLAlgorithmParameter3.cnnt new file mode 100644 index 0000000..711c895 --- /dev/null +++ b/src/test/resources/invalid_cocos_tests/CheckRLAlgorithmParameter3.cnnt @@ -0,0 +1,19 @@ +configuration CheckRLAlgorithmParameter3 { + learning_method : reinforcement + + rl_algorithm: dqn-algorithm + + agent_name : "reinforcement-agent" + + environment : gym { name:"CartPole-v1" } + + context : cpu + + num_episodes : 300 + num_max_steps : 9999 + discount_factor : 0.998 + target_score : 1000 + training_interval : 10 + + soft_target_update_rate: 0.001 +} \ No newline at end of file -- GitLab From 0bb9994e1c4af2323214d3155a38b5a213db7a95 Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Tue, 28 May 2019 23:32:08 +0200 Subject: [PATCH 05/16] Add start training at parameter --- src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 | 1 + .../cnntrain/_cocos/ParameterAlgorithmMapping.java | 3 ++- .../cnntrain/_symboltable/CNNTrainSymbolTableCreator.java | 8 ++++++++ src/test/resources/valid_tests/ReinforcementConfig.cnnt | 2 ++ 4 files changed, 13 insertions(+), 1 deletion(-) diff --git a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 index 3ea8c37..a722507 100644 --- a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 +++ b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 @@ -107,6 +107,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number SnapshotIntervalEntry implements ConfigEntry = name:"snapshot_interval" ":" value:IntegerValue; AgentNameEntry implements ConfigEntry = name:"agent_name" ":" value:StringValue; RewardFunctionEntry implements ConfigEntry = name:"reward_function" ":" value:ComponentNameValue; + StartTrainingAtEntry implements ConfigEntry = name:"start_training_at" ":" value:IntegerValue; ComponentNameValue implements ConfigValue = Name ("."Name)*; diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java index cee3d1d..eff843f 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java @@ -80,7 +80,8 @@ public class ParameterAlgorithmMapping { ASTRosEnvironmentMetaTopicEntry.class, ASTRosEnvironmentResetTopicEntry.class, ASTRosEnvironmentTerminalStateTopicEntry.class, - ASTRosEnvironmentGreetingTopicEntry.class + ASTRosEnvironmentGreetingTopicEntry.class, + ASTStartTrainingAtEntry.class ); private static final List EXCLUSIVE_DQN_PARAMETERS = Lists.newArrayList( diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java index fd0eb7d..50af059 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java @@ -477,6 +477,14 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { configuration.getEntryMap().put(node.getName(), entry); } + @Override + public void visit(ASTStartTrainingAtEntry node) { + EntrySymbol entry = new EntrySymbol(node.getName()); + entry.setValue(getValueSymbolForInteger(node.getValue())); + addToScopeAndLinkWithNode(entry, node); + configuration.getEntryMap().put(node.getName(), entry); + } + private void processMultiParamConfigVisit(ASTMultiParamConfigEntry node, Object value) { EntrySymbol entry = new EntrySymbol(node.getName()); MultiParamValueSymbol valueSymbol = new MultiParamValueSymbol(); diff --git a/src/test/resources/valid_tests/ReinforcementConfig.cnnt b/src/test/resources/valid_tests/ReinforcementConfig.cnnt index 3f03ff5..9b6e366 100644 --- a/src/test/resources/valid_tests/ReinforcementConfig.cnnt +++ b/src/test/resources/valid_tests/ReinforcementConfig.cnnt @@ -13,6 +13,8 @@ configuration ReinforcementConfig { target_score : 1000 training_interval : 10 + start_training_at : 20 + loss : huber_loss use_fix_target_network : true -- GitLab From 64857312934084f2fed5a3dfb6a91bc7b8e87237 Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Wed, 29 May 2019 00:01:22 +0200 Subject: [PATCH 06/16] Add evaluation_samples parameter --- src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 | 1 + .../cnntrain/_cocos/ParameterAlgorithmMapping.java | 3 ++- .../cnntrain/_symboltable/CNNTrainSymbolTableCreator.java | 8 ++++++++ src/test/resources/valid_tests/ReinforcementConfig.cnnt | 2 ++ 4 files changed, 13 insertions(+), 1 deletion(-) diff --git a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 index a722507..ebcec5b 100644 --- a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 +++ b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 @@ -108,6 +108,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number AgentNameEntry implements ConfigEntry = name:"agent_name" ":" value:StringValue; RewardFunctionEntry implements ConfigEntry = name:"reward_function" ":" value:ComponentNameValue; StartTrainingAtEntry implements ConfigEntry = name:"start_training_at" ":" value:IntegerValue; + EvaluationSamplesEntry implements ConfigEntry = name:"evaluation_samples" ":" value:IntegerValue; ComponentNameValue implements ConfigValue = Name ("."Name)*; diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java index eff843f..ade8318 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java @@ -81,7 +81,8 @@ public class ParameterAlgorithmMapping { ASTRosEnvironmentResetTopicEntry.class, ASTRosEnvironmentTerminalStateTopicEntry.class, ASTRosEnvironmentGreetingTopicEntry.class, - ASTStartTrainingAtEntry.class + ASTStartTrainingAtEntry.class, + ASTEvaluationSamplesEntry.class ); private static final List EXCLUSIVE_DQN_PARAMETERS = Lists.newArrayList( diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java index 50af059..b3172ef 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java @@ -485,6 +485,14 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { configuration.getEntryMap().put(node.getName(), entry); } + @Override + public void visit(ASTEvaluationSamplesEntry node) { + EntrySymbol entry = new EntrySymbol(node.getName()); + entry.setValue(getValueSymbolForInteger(node.getValue())); + addToScopeAndLinkWithNode(entry, node); + configuration.getEntryMap().put(node.getName(), entry); + } + private void processMultiParamConfigVisit(ASTMultiParamConfigEntry node, Object value) { EntrySymbol entry = new EntrySymbol(node.getName()); MultiParamValueSymbol valueSymbol = new MultiParamValueSymbol(); diff --git a/src/test/resources/valid_tests/ReinforcementConfig.cnnt b/src/test/resources/valid_tests/ReinforcementConfig.cnnt index 9b6e366..1abe572 100644 --- a/src/test/resources/valid_tests/ReinforcementConfig.cnnt +++ b/src/test/resources/valid_tests/ReinforcementConfig.cnnt @@ -15,6 +15,8 @@ configuration ReinforcementConfig { start_training_at : 20 + evaluation_samples: 100 + loss : huber_loss use_fix_target_network : true -- GitLab From b190679cd1e74b12b2845e7ca08540d25fb8a183 Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Wed, 29 May 2019 00:29:42 +0200 Subject: [PATCH 07/16] Add epsilon_decay_start parameter --- src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 | 1 + .../monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java | 3 ++- src/test/resources/valid_tests/ReinforcementConfig.cnnt | 1 + 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 index ebcec5b..4584f19 100644 --- a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 +++ b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 @@ -144,6 +144,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number GreedyEpsilonEntry implements ActionSelectionEpsGreedyEntry = name:"epsilon" ":" value:NumberValue; MinEpsilonEntry implements ActionSelectionEpsGreedyEntry = name:"min_epsilon" ":" value:NumberValue; + EpsilonDecayStartEntry implements ActionSelectionEpsGreedyEntry = name:"epsilon_decay_start" ":" value:IntegerValue; EpsilonDecayMethodEntry implements ActionSelectionEpsGreedyEntry = name:"epsilon_decay_method" ":" value:EpsilonDecayMethodValue; EpsilonDecayMethodValue implements ConfigValue = (linear:"linear" | no:"no"); EpsilonDecayEntry implements ActionSelectionEpsGreedyEntry = name:"epsilon_decay" ":" value:NumberValue; diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java index ade8318..ab07ca0 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java @@ -82,7 +82,8 @@ public class ParameterAlgorithmMapping { ASTRosEnvironmentTerminalStateTopicEntry.class, ASTRosEnvironmentGreetingTopicEntry.class, ASTStartTrainingAtEntry.class, - ASTEvaluationSamplesEntry.class + ASTEvaluationSamplesEntry.class, + ASTEpsilonDecayStartEntry.class ); private static final List EXCLUSIVE_DQN_PARAMETERS = Lists.newArrayList( diff --git a/src/test/resources/valid_tests/ReinforcementConfig.cnnt b/src/test/resources/valid_tests/ReinforcementConfig.cnnt index 1abe572..02ab5fa 100644 --- a/src/test/resources/valid_tests/ReinforcementConfig.cnnt +++ b/src/test/resources/valid_tests/ReinforcementConfig.cnnt @@ -34,6 +34,7 @@ configuration ReinforcementConfig { min_epsilon : 0.01 epsilon_decay_method: linear epsilon_decay : 0.0001 + epsilon_decay_start: 50 } optimizer : rmsprop{ -- GitLab From f1a3f4b67bbd72dc79076e5a1c570a040e1439b4 Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Wed, 29 May 2019 13:10:51 +0200 Subject: [PATCH 08/16] Add snapshot interval to parameter list --- .../monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java | 3 ++- src/test/resources/valid_tests/ReinforcementConfig.cnnt | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java index ab07ca0..2e35ac7 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java @@ -83,7 +83,8 @@ public class ParameterAlgorithmMapping { ASTRosEnvironmentGreetingTopicEntry.class, ASTStartTrainingAtEntry.class, ASTEvaluationSamplesEntry.class, - ASTEpsilonDecayStartEntry.class + ASTEpsilonDecayStartEntry.class, + ASTSnapshotIntervalEntry.class ); private static final List EXCLUSIVE_DQN_PARAMETERS = Lists.newArrayList( diff --git a/src/test/resources/valid_tests/ReinforcementConfig.cnnt b/src/test/resources/valid_tests/ReinforcementConfig.cnnt index 02ab5fa..1dc025f 100644 --- a/src/test/resources/valid_tests/ReinforcementConfig.cnnt +++ b/src/test/resources/valid_tests/ReinforcementConfig.cnnt @@ -12,6 +12,7 @@ configuration ReinforcementConfig { discount_factor : 0.998 target_score : 1000 training_interval : 10 + snapshot_interval: 100 start_training_at : 20 -- GitLab From 1a085ac2dff02b292c232874a688b3a894ee3f42 Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Wed, 29 May 2019 16:24:47 +0200 Subject: [PATCH 09/16] Add optimizer parameter for critic network --- .../de/monticore/lang/monticar/CNNTrain.mc4 | 15 +++---- .../cnntrain/_cocos/CheckEntryRepetition.java | 39 ++++++++++++------- .../_cocos/ParameterAlgorithmMapping.java | 3 +- .../CNNTrainSymbolTableCreator.java | 20 +++++++++- .../_symboltable/ConfigurationSymbol.java | 10 +++++ .../resources/valid_tests/DdpgConfig.cnnt | 12 ++++++ 6 files changed, 77 insertions(+), 22 deletions(-) diff --git a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 index 4584f19..2d76e41 100644 --- a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 +++ b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 @@ -20,7 +20,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number BatchSizeEntry implements ConfigEntry = name:"batch_size" ":" value:IntegerValue; LoadCheckpointEntry implements ConfigEntry = name:"load_checkpoint" ":" value:BooleanValue; NormalizeEntry implements ConfigEntry = name:"normalize" ":" value:BooleanValue; - OptimizerEntry implements ConfigEntry = name:"optimizer" ":" value:OptimizerValue; + OptimizerEntry implements ConfigEntry = (name:"optimizer" | name:"actor_optimizer") ":" value:OptimizerValue; TrainContextEntry implements ConfigEntry = name:"context" ":" value:TrainContextValue; EvalMetricEntry implements ConfigEntry = name:"eval_metric" ":" value:EvalMetricValue; LossEntry implements ConfigEntry = name:"loss" ":" value:LossValue; @@ -53,23 +53,23 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number StringValue implements ConfigValue = StringLiteral; BooleanValue implements ConfigValue = (TRUE:"true" | FALSE:"false"); + interface OptimizerParamEntry extends Entry; interface OptimizerValue extends ConfigValue; - - interface SGDEntry extends Entry; + interface SGDEntry extends OptimizerParamEntry; SGDOptimizer implements OptimizerValue = name:"sgd" ("{" params:SGDEntry* "}")?; - interface AdamEntry extends Entry; + interface AdamEntry extends OptimizerParamEntry; AdamOptimizer implements OptimizerValue = name:"adam" ("{" params:AdamEntry* "}")?; - interface RmsPropEntry extends Entry; + interface RmsPropEntry extends OptimizerParamEntry; RmsPropOptimizer implements OptimizerValue = name:"rmsprop" ("{" params:RmsPropEntry* "}")?; - interface AdaGradEntry extends Entry; + interface AdaGradEntry extends OptimizerParamEntry; AdaGradOptimizer implements OptimizerValue = name:"adagrad" ("{" params:AdaGradEntry* "}")?; NesterovOptimizer implements OptimizerValue = name:"nag" ("{" params:SGDEntry* "}")?; - interface AdaDeltaEntry extends Entry; + interface AdaDeltaEntry extends OptimizerParamEntry; AdaDeltaOptimizer implements OptimizerValue = name:"adadelta" ("{" params:AdaDeltaEntry* "}")?; interface GeneralOptimizerEntry extends SGDEntry,AdamEntry,RmsPropEntry,AdaGradEntry,AdaDeltaEntry; @@ -175,4 +175,5 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number // DDPG exclusive parameters CriticNetworkEntry implements ConfigEntry = name:"critic" ":" value:ComponentNameValue; SoftTargetUpdateRateEntry implements ConfigEntry = name:"soft_target_update_rate" ":" value:NumberValue; + CriticOptimizerEntry implements ConfigEntry = name:"critic_optimizer" ":" value:OptimizerValue; } \ No newline at end of file diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckEntryRepetition.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckEntryRepetition.java index 07337b8..5a3ce20 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckEntryRepetition.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckEntryRepetition.java @@ -20,8 +20,9 @@ */ package de.monticore.lang.monticar.cnntrain._cocos; -import de.monticore.lang.monticar.cnntrain._ast.ASTEntry; -import de.monticore.lang.monticar.cnntrain._ast.ASTGreedyEpsilonEntry; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Sets; +import de.monticore.lang.monticar.cnntrain._ast.*; import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes; import de.se_rwth.commons.logging.Log; @@ -29,23 +30,35 @@ import java.util.HashSet; import java.util.Set; public class CheckEntryRepetition implements CNNTrainASTEntryCoCo { + private final static Set> REPEATABLE_ENTRIES = ImmutableSet + .>builder() + .add(ASTOptimizerParamEntry.class) + .build(); + + private Set entryNameSet = new HashSet<>(); @Override public void check(ASTEntry node) { - String parameterPrefix = ""; - if (node instanceof ASTGreedyEpsilonEntry) { - parameterPrefix = "greedy_"; - } - if (entryNameSet.contains(parameterPrefix + node.getName())){ - Log.error("0" + ErrorCodes.ENTRY_REPETITION_CODE +" The parameter '" + node.getName() + "' has multiple values. " + - "Multiple assignments of the same parameter are not allowed", - node.get_SourcePositionStart()); - } - else { - entryNameSet.add(parameterPrefix + node.getName()); + if (!isRepeatable(node)) { + String parameterPrefix = ""; + + if (node instanceof ASTGreedyEpsilonEntry) { + parameterPrefix = "greedy_"; + } + if (entryNameSet.contains(parameterPrefix + node.getName())) { + Log.error("0" + ErrorCodes.ENTRY_REPETITION_CODE + " The parameter '" + node.getName() + "' has multiple values. " + + "Multiple assignments of the same parameter are not allowed", + node.get_SourcePositionStart()); + } else { + entryNameSet.add(parameterPrefix + node.getName()); + } } } + private boolean isRepeatable(final ASTEntry node) { + return REPEATABLE_ENTRIES.stream().anyMatch(i -> i.isInstance(node)); + } + } diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java index 2e35ac7..62aa360 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java @@ -96,7 +96,8 @@ public class ParameterAlgorithmMapping { private static final List EXCLUSIVE_DDPG_PARAMETERS = Lists.newArrayList( ASTCriticNetworkEntry.class, - ASTSoftTargetUpdateRateEntry.class + ASTSoftTargetUpdateRateEntry.class, + ASTCriticOptimizerEntry.class ); ParameterAlgorithmMapping() { diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java index b3172ef..ccf994f 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java @@ -96,11 +96,29 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { OptimizerParamSymbol param = new OptimizerParamSymbol(); OptimizerParamValueSymbol valueSymbol = (OptimizerParamValueSymbol) nodeParam.getValue().getSymbolOpt().get(); param.setValue(valueSymbol); - configuration.getOptimizer().getOptimizerParamMap().put(nodeParam.getName(), param);; + configuration.getOptimizer().getOptimizerParamMap().put(nodeParam.getName(), param); } } + @Override + public void visit(ASTCriticOptimizerEntry node) { + OptimizerSymbol optimizerSymbol = new OptimizerSymbol(node.getValue().getName()); + configuration.setCriticOptimizer(optimizerSymbol); + addToScopeAndLinkWithNode(optimizerSymbol, node); + } + + @Override + public void endVisit(ASTCriticOptimizerEntry node) { + assert configuration.getCriticOptimizer().isPresent(): "Critic optimizer not present"; + for (ASTEntry paramNode : node.getValue().getParamsList()) { + OptimizerParamSymbol param = new OptimizerParamSymbol(); + OptimizerParamValueSymbol valueSymbol = (OptimizerParamValueSymbol)paramNode.getValue().getSymbolOpt().get(); + param.setValue(valueSymbol); + configuration.getCriticOptimizer().get().getOptimizerParamMap().put(paramNode.getName(), param); + } + } + @Override public void endVisit(ASTNumEpochEntry node) { EntrySymbol entry = new EntrySymbol(node.getName()); diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java index f137c47..a026fbd 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java @@ -24,12 +24,14 @@ import com.google.common.collect.Lists; import de.monticore.lang.monticar.cnntrain.annotations.TrainedArchitecture; import de.monticore.symboltable.CommonScopeSpanningSymbol; +import javax.swing.text.html.Option; import java.util.*; public class ConfigurationSymbol extends CommonScopeSpanningSymbol { private Map entryMap = new HashMap<>(); private OptimizerSymbol optimizer; + private OptimizerSymbol criticOptimizer; private RewardFunctionSymbol rlRewardFunctionSymbol; private TrainedArchitecture trainedArchitecture; @@ -49,6 +51,14 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol { this.optimizer = optimizer; } + public void setCriticOptimizer(OptimizerSymbol criticOptimizer) { + this.criticOptimizer = criticOptimizer; + } + + public Optional getCriticOptimizer() { + return Optional.ofNullable(criticOptimizer); + } + protected void setRlRewardFunction(RewardFunctionSymbol rlRewardFunctionSymbol) { this.rlRewardFunctionSymbol = rlRewardFunctionSymbol; } diff --git a/src/test/resources/valid_tests/DdpgConfig.cnnt b/src/test/resources/valid_tests/DdpgConfig.cnnt index 3553f2d..6df4e7b 100644 --- a/src/test/resources/valid_tests/DdpgConfig.cnnt +++ b/src/test/resources/valid_tests/DdpgConfig.cnnt @@ -4,4 +4,16 @@ configuration DdpgConfig { critic : path.to.component environment : gym { name:"CartPole-v1" } soft_target_update_rate: 0.001 + actor_optimizer : adam{ + learning_rate : 0.0001 + learning_rate_minimum : 0.00005 + learning_rate_decay : 0.9 + learning_rate_policy : step + } + critic_optimizer : rmsprop{ + learning_rate : 0.001 + learning_rate_minimum : 0.0001 + learning_rate_decay : 0.5 + learning_rate_policy : step + } } \ No newline at end of file -- GitLab From 36eff3c334c14b2fb73e36c83dd53b79878f3af4 Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Wed, 29 May 2019 22:41:28 +0200 Subject: [PATCH 10/16] Rename parameter action_selection to strategy --- .../de/monticore/lang/monticar/CNNTrain.mc4 | 20 +++++++++---------- .../_cocos/ParameterAlgorithmMapping.java | 2 +- .../CNNTrainSymbolTableCreator.java | 4 ++-- .../CheckLearningParameterCombination1.cnnt | 2 +- ...CheckReinforcementRequiresEnvironment.cnnt | 2 +- ...kRosEnvironmentRequiresRewardFunction.cnnt | 2 +- .../valid_tests/ReinforcementConfig.cnnt | 2 +- .../valid_tests/ReinforcementConfig2.cnnt | 2 +- 8 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 index 2d76e41..757c647 100644 --- a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 +++ b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 @@ -135,19 +135,19 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number MemorySizeEntry implements GeneralReplayMemoryEntry = name:"memory_size" ":" value:IntegerValue; SampleSizeEntry implements GeneralReplayMemoryEntry = name:"sample_size" ":" value:IntegerValue; - // Action Selection - ActionSelectionEntry implements MultiParamConfigEntry = name:"action_selection" ":" value:ActionSelectionValue; - interface ActionSelectionValue extends MultiParamValue; + // Strategy + StrategyEntry implements MultiParamConfigEntry = name:"strategy" ":" value:StrategyValue; + interface StrategyValue extends MultiParamValue; - interface ActionSelectionEpsGreedyEntry extends Entry; - ActionSelectionEpsGreedyValue implements ActionSelectionValue = name:"epsgreedy" ("{" params:ActionSelectionEpsGreedyEntry* "}")?; + interface StrategyEpsGreedyEntry extends Entry; + StrategyEpsGreedyValue implements StrategyValue = name:"epsgreedy" ("{" params:StrategyEpsGreedyEntry* "}")?; - GreedyEpsilonEntry implements ActionSelectionEpsGreedyEntry = name:"epsilon" ":" value:NumberValue; - MinEpsilonEntry implements ActionSelectionEpsGreedyEntry = name:"min_epsilon" ":" value:NumberValue; - EpsilonDecayStartEntry implements ActionSelectionEpsGreedyEntry = name:"epsilon_decay_start" ":" value:IntegerValue; - EpsilonDecayMethodEntry implements ActionSelectionEpsGreedyEntry = name:"epsilon_decay_method" ":" value:EpsilonDecayMethodValue; + GreedyEpsilonEntry implements StrategyEpsGreedyEntry = name:"epsilon" ":" value:NumberValue; + MinEpsilonEntry implements StrategyEpsGreedyEntry = name:"min_epsilon" ":" value:NumberValue; + EpsilonDecayStartEntry implements StrategyEpsGreedyEntry = name:"epsilon_decay_start" ":" value:IntegerValue; + EpsilonDecayMethodEntry implements StrategyEpsGreedyEntry = name:"epsilon_decay_method" ":" value:EpsilonDecayMethodValue; EpsilonDecayMethodValue implements ConfigValue = (linear:"linear" | no:"no"); - EpsilonDecayEntry implements ActionSelectionEpsGreedyEntry = name:"epsilon_decay" ":" value:NumberValue; + EpsilonDecayEntry implements StrategyEpsGreedyEntry = name:"epsilon_decay" ":" value:NumberValue; // Environment EnvironmentEntry implements MultiParamConfigEntry = name:"environment" ":" value:EnvironmentValue; diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java index 62aa360..4934071 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java @@ -69,7 +69,7 @@ public class ParameterAlgorithmMapping { ASTReplayMemoryEntry.class, ASTMemorySizeEntry.class, ASTSampleSizeEntry.class, - ASTActionSelectionEntry.class, + ASTStrategyEntry.class, ASTGreedyEpsilonEntry.class, ASTMinEpsilonEntry.class, ASTEpsilonDecayEntry.class, diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java index ccf994f..10fc0db 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java @@ -458,12 +458,12 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { } @Override - public void visit(ASTActionSelectionEntry node) { + public void visit(ASTStrategyEntry node) { processMultiParamConfigVisit(node, node.getValue().getName()); } @Override - public void endVisit(ASTActionSelectionEntry node) { + public void endVisit(ASTStrategyEntry node) { processMultiParamConfigEndVisit(node); } diff --git a/src/test/resources/invalid_cocos_tests/CheckLearningParameterCombination1.cnnt b/src/test/resources/invalid_cocos_tests/CheckLearningParameterCombination1.cnnt index 1a2ddec..b428c38 100644 --- a/src/test/resources/invalid_cocos_tests/CheckLearningParameterCombination1.cnnt +++ b/src/test/resources/invalid_cocos_tests/CheckLearningParameterCombination1.cnnt @@ -25,7 +25,7 @@ configuration CheckLearningParameterCombination1 { sample_size : 64 } - action_selection : epsgreedy{ + strategy : epsgreedy{ epsilon : 1.0 min_epsilon : 0.01 epsilon_decay_method: linear diff --git a/src/test/resources/invalid_cocos_tests/CheckReinforcementRequiresEnvironment.cnnt b/src/test/resources/invalid_cocos_tests/CheckReinforcementRequiresEnvironment.cnnt index dc8861a..0cb8e23 100644 --- a/src/test/resources/invalid_cocos_tests/CheckReinforcementRequiresEnvironment.cnnt +++ b/src/test/resources/invalid_cocos_tests/CheckReinforcementRequiresEnvironment.cnnt @@ -23,7 +23,7 @@ configuration CheckReinforcementRequiresEnvironment { sample_size : 64 } - action_selection : epsgreedy{ + strategy : epsgreedy{ epsilon : 1.0 min_epsilon : 0.01 epsilon_decay_method: linear diff --git a/src/test/resources/invalid_cocos_tests/CheckRosEnvironmentRequiresRewardFunction.cnnt b/src/test/resources/invalid_cocos_tests/CheckRosEnvironmentRequiresRewardFunction.cnnt index 7175377..ef61344 100644 --- a/src/test/resources/invalid_cocos_tests/CheckRosEnvironmentRequiresRewardFunction.cnnt +++ b/src/test/resources/invalid_cocos_tests/CheckRosEnvironmentRequiresRewardFunction.cnnt @@ -32,7 +32,7 @@ configuration CheckRosEnvironmentRequiresRewardFunction { sample_size : 64 } - action_selection : epsgreedy{ + strategy : epsgreedy{ epsilon : 1.0 min_epsilon : 0.01 epsilon_decay_method: linear diff --git a/src/test/resources/valid_tests/ReinforcementConfig.cnnt b/src/test/resources/valid_tests/ReinforcementConfig.cnnt index 1dc025f..a1b59db 100644 --- a/src/test/resources/valid_tests/ReinforcementConfig.cnnt +++ b/src/test/resources/valid_tests/ReinforcementConfig.cnnt @@ -30,7 +30,7 @@ configuration ReinforcementConfig { sample_size : 64 } - action_selection : epsgreedy{ + strategy : epsgreedy{ epsilon : 1.0 min_epsilon : 0.01 epsilon_decay_method: linear diff --git a/src/test/resources/valid_tests/ReinforcementConfig2.cnnt b/src/test/resources/valid_tests/ReinforcementConfig2.cnnt index 537be2d..470df30 100644 --- a/src/test/resources/valid_tests/ReinforcementConfig2.cnnt +++ b/src/test/resources/valid_tests/ReinforcementConfig2.cnnt @@ -34,7 +34,7 @@ configuration ReinforcementConfig2 { sample_size : 64 } - action_selection : epsgreedy{ + strategy : epsgreedy{ epsilon : 1.0 min_epsilon : 0.01 epsilon_decay_method: linear -- GitLab From 4225e326620def8f6c71a64de94873f0f8ccc774 Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Thu, 30 May 2019 01:49:37 +0200 Subject: [PATCH 11/16] Add ornstein_uhlenbeck strategy --- .../de/monticore/lang/monticar/CNNTrain.mc4 | 38 +++++++++------ .../_cocos/ASTConfigurationUtils.java | 14 ++++++ .../cnntrain/_cocos/CNNTrainCocos.java | 4 +- ...uousRLAlgorithmUsesContinuousStrategy.java | 47 +++++++++++++++++++ ...screteRLAlgorithmUsesDiscreteStrategy.java | 47 +++++++++++++++++++ .../_cocos/ParameterAlgorithmMapping.java | 5 +- .../CNNTrainSymbolTableCreator.java | 7 +++ .../monticar/cnntrain/helper/ErrorCodes.java | 1 + .../lang/monticar/cnntrain/ParserTest.java | 3 +- .../monticar/cnntrain/cocos/AllCoCoTest.java | 6 +++ ...uousRLAlgorithmUsesContinuousStrategy.cnnt | 12 +++++ ...screteRLAlgorithmUsesDiscreteStrategy.cnnt | 15 ++++++ .../WrongStrategyParameter.cnnt | 11 +++++ .../resources/valid_tests/DdpgConfig.cnnt | 10 ++++ 14 files changed, 202 insertions(+), 18 deletions(-) create mode 100644 src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckContinuousRLAlgorithmUsesContinuousStrategy.java create mode 100644 src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckDiscreteRLAlgorithmUsesDiscreteStrategy.java create mode 100644 src/test/resources/invalid_cocos_tests/CheckContinuousRLAlgorithmUsesContinuousStrategy.cnnt create mode 100644 src/test/resources/invalid_cocos_tests/CheckDiscreteRLAlgorithmUsesDiscreteStrategy.cnnt create mode 100644 src/test/resources/invalid_parser_tests/WrongStrategyParameter.cnnt diff --git a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 index 757c647..177553c 100644 --- a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 +++ b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 @@ -1,7 +1,6 @@ package de.monticore.lang.monticar; grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.NumberUnit{ - symbol scope CNNTrainCompilationUnit = "configuration" name:Name& Configuration; @@ -16,6 +15,15 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number interface VariableReference; ast VariableReference = method String getName(){}; + // General Values + DataVariable implements VariableReference = Name&; + IntegerValue implements ConfigValue = NumberWithUnit; + NumberValue implements ConfigValue = NumberWithUnit; + StringValue implements ConfigValue = StringLiteral; + BooleanValue implements ConfigValue = (TRUE:"true" | FALSE:"false"); + ComponentNameValue implements ConfigValue = Name ("."Name)*; + DoubleVectorValue implements ConfigValue = "(" number:NumberWithUnit ("," number:NumberWithUnit)* ")"; + NumEpochEntry implements ConfigEntry = name:"num_epoch" ":" value:IntegerValue; BatchSizeEntry implements ConfigEntry = name:"batch_size" ":" value:IntegerValue; LoadCheckpointEntry implements ConfigEntry = name:"load_checkpoint" ":" value:BooleanValue; @@ -46,13 +54,6 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number | sigmoid:"sigmoid"); TrainContextValue implements ConfigValue = (cpu:"cpu" | gpu:"gpu"); - - DataVariable implements VariableReference = Name&; - IntegerValue implements ConfigValue = NumberWithUnit; - NumberValue implements ConfigValue = NumberWithUnit; - StringValue implements ConfigValue = StringLiteral; - BooleanValue implements ConfigValue = (TRUE:"true" | FALSE:"false"); - interface OptimizerParamEntry extends Entry; interface OptimizerValue extends ConfigValue; interface SGDEntry extends OptimizerParamEntry; @@ -110,8 +111,6 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number StartTrainingAtEntry implements ConfigEntry = name:"start_training_at" ":" value:IntegerValue; EvaluationSamplesEntry implements ConfigEntry = name:"evaluation_samples" ":" value:IntegerValue; - ComponentNameValue implements ConfigValue = Name ("."Name)*; - LearningMethodValue implements ConfigValue = (supervisedLearning:"supervised" | reinforcement:"reinforcement"); RLAlgorithmValue implements ConfigValue = (dqn:"dqn-algorithm" | ddpg:"ddpg-algorithm"); @@ -142,12 +141,21 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number interface StrategyEpsGreedyEntry extends Entry; StrategyEpsGreedyValue implements StrategyValue = name:"epsgreedy" ("{" params:StrategyEpsGreedyEntry* "}")?; - GreedyEpsilonEntry implements StrategyEpsGreedyEntry = name:"epsilon" ":" value:NumberValue; - MinEpsilonEntry implements StrategyEpsGreedyEntry = name:"min_epsilon" ":" value:NumberValue; - EpsilonDecayStartEntry implements StrategyEpsGreedyEntry = name:"epsilon_decay_start" ":" value:IntegerValue; - EpsilonDecayMethodEntry implements StrategyEpsGreedyEntry = name:"epsilon_decay_method" ":" value:EpsilonDecayMethodValue; + interface StrategyOrnsteinUhlenbeckEntry extends Entry; + StrategyOrnsteinUhlenbeckValue implements StrategyValue = name:"ornstein_uhlenbeck" ("{" params:StrategyOrnsteinUhlenbeckEntry* "}")?; + + StrategyOUMu implements StrategyOrnsteinUhlenbeckEntry = name: "mu" ":" value:DoubleVectorValue; + StrategyOUTheta implements StrategyOrnsteinUhlenbeckEntry = name: "theta" ":" value:DoubleVectorValue; + StrategyOUSigma implements StrategyOrnsteinUhlenbeckEntry = name: "sigma" ":" value:DoubleVectorValue; + + interface GeneralStrategyEntry extends StrategyEpsGreedyEntry, StrategyOrnsteinUhlenbeckEntry; + + GreedyEpsilonEntry implements GeneralStrategyEntry = name:"epsilon" ":" value:NumberValue; + MinEpsilonEntry implements GeneralStrategyEntry = name:"min_epsilon" ":" value:NumberValue; + EpsilonDecayStartEntry implements GeneralStrategyEntry = name:"epsilon_decay_start" ":" value:IntegerValue; + EpsilonDecayMethodEntry implements GeneralStrategyEntry = name:"epsilon_decay_method" ":" value:EpsilonDecayMethodValue; EpsilonDecayMethodValue implements ConfigValue = (linear:"linear" | no:"no"); - EpsilonDecayEntry implements StrategyEpsGreedyEntry = name:"epsilon_decay" ":" value:NumberValue; + EpsilonDecayEntry implements GeneralStrategyEntry = name:"epsilon_decay" ":" value:NumberValue; // Environment EnvironmentEntry implements MultiParamConfigEntry = name:"environment" ":" value:EnvironmentValue; diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ASTConfigurationUtils.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ASTConfigurationUtils.java index 1fb1f6a..6ff1535 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ASTConfigurationUtils.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ASTConfigurationUtils.java @@ -22,6 +22,8 @@ package de.monticore.lang.monticar.cnntrain._cocos; import de.monticore.lang.monticar.cnntrain._ast.*; +import java.util.Optional; + class ASTConfigurationUtils { static boolean isReinforcementLearning(final ASTConfiguration configuration) { return configuration.getEntriesList().stream().anyMatch(e -> @@ -46,4 +48,16 @@ class ASTConfigurationUtils { static boolean hasEntry(final ASTConfiguration configuration, final Class entryClazz) { return configuration.getEntriesList().stream().anyMatch(entryClazz::isInstance); } + + static boolean hasStrategy(final ASTConfiguration configuration) { + return configuration.getEntriesList().stream().anyMatch(e -> e instanceof ASTStrategyEntry); + } + + static Optional getStrategyMethod(final ASTConfiguration configuration) { + return configuration.getEntriesList().stream() + .filter(e -> e instanceof ASTStrategyEntry) + .map(e -> (ASTStrategyEntry)e) + .findFirst() + .map(astStrategyEntry -> astStrategyEntry.getValue().getName()); + } } diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java index 35e3d7b..5038a14 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java @@ -35,7 +35,9 @@ public class CNNTrainCocos { .addCoCo(new CheckLearningParameterCombination()) .addCoCo(new CheckRosEnvironmentRequiresRewardFunction()) .addCoCo(new CheckDdpgRequiresCriticNetwork()) - .addCoCo(new CheckRlAlgorithmParameter()); + .addCoCo(new CheckRlAlgorithmParameter()) + .addCoCo(new CheckDiscreteRLAlgorithmUsesDiscreteStrategy()) + .addCoCo(new CheckContinuousRLAlgorithmUsesContinuousStrategy()); } public static void checkAll(CNNTrainCompilationUnitSymbol compilationUnit){ diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckContinuousRLAlgorithmUsesContinuousStrategy.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckContinuousRLAlgorithmUsesContinuousStrategy.java new file mode 100644 index 0000000..5d0a68c --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckContinuousRLAlgorithmUsesContinuousStrategy.java @@ -0,0 +1,47 @@ +/** + * + * ****************************************************************************** + * MontiCAR Modeling Family, www.se-rwth.de + * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, + * All rights reserved. + * + * This project is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 3.0 of the License, or (at your option) any later version. + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this project. If not, see . + * ******************************************************************************* + */ +package de.monticore.lang.monticar.cnntrain._cocos; + +import com.google.common.collect.ImmutableSet; +import de.monticore.lang.monticar.cnntrain._ast.ASTConfiguration; +import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes; +import de.se_rwth.commons.logging.Log; + +import java.util.Set; + +public class CheckContinuousRLAlgorithmUsesContinuousStrategy implements CNNTrainASTConfigurationCoCo{ + private static final Set CONTINUOUS_STRATEGIES = ImmutableSet.builder() + .add("ornstein_uhlenbeck") + .build(); + + @Override + public void check(ASTConfiguration node) { + if (ASTConfigurationUtils.isDdpgAlgorithm(node) + && ASTConfigurationUtils.hasStrategy(node) + && ASTConfigurationUtils.getStrategyMethod(node).isPresent()) { + final String usedStrategy = ASTConfigurationUtils.getStrategyMethod(node).get(); + if (!CONTINUOUS_STRATEGIES.contains(usedStrategy)) { + Log.error("0" + ErrorCodes.STRATEGY_NOT_APPLICABLE + " Strategy " + usedStrategy + " used but" + + " continuous algorithm used.", node.get_SourcePositionStart()); + } + } + } +} \ No newline at end of file diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckDiscreteRLAlgorithmUsesDiscreteStrategy.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckDiscreteRLAlgorithmUsesDiscreteStrategy.java new file mode 100644 index 0000000..64263bb --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckDiscreteRLAlgorithmUsesDiscreteStrategy.java @@ -0,0 +1,47 @@ +/** + * + * ****************************************************************************** + * MontiCAR Modeling Family, www.se-rwth.de + * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, + * All rights reserved. + * + * This project is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 3.0 of the License, or (at your option) any later version. + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this project. If not, see . + * ******************************************************************************* + */ +package de.monticore.lang.monticar.cnntrain._cocos; + +import com.google.common.collect.ImmutableSet; +import de.monticore.lang.monticar.cnntrain._ast.ASTConfiguration; +import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes; +import de.se_rwth.commons.logging.Log; + +import java.util.Set; + +public class CheckDiscreteRLAlgorithmUsesDiscreteStrategy implements CNNTrainASTConfigurationCoCo{ + private static final Set DISCRETE_STRATEGIES = ImmutableSet.builder() + .add("epsgreedy") + .build(); + + @Override + public void check(ASTConfiguration node) { + if (ASTConfigurationUtils.isDqnAlgorithm(node) + && ASTConfigurationUtils.hasStrategy(node) + && ASTConfigurationUtils.getStrategyMethod(node).isPresent()) { + final String usedStrategy = ASTConfigurationUtils.getStrategyMethod(node).get(); + if (!DISCRETE_STRATEGIES.contains(usedStrategy)) { + Log.error("0" + ErrorCodes.STRATEGY_NOT_APPLICABLE + " Strategy " + usedStrategy + " used but" + + " discrete algorithm used.", node.get_SourcePositionStart()); + } + } + } +} diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java index 4934071..a33813b 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java @@ -97,7 +97,10 @@ public class ParameterAlgorithmMapping { private static final List EXCLUSIVE_DDPG_PARAMETERS = Lists.newArrayList( ASTCriticNetworkEntry.class, ASTSoftTargetUpdateRateEntry.class, - ASTCriticOptimizerEntry.class + ASTCriticOptimizerEntry.class, + ASTStrategyOUMu.class, + ASTStrategyOUTheta.class, + ASTStrategyOUSigma.class ); ParameterAlgorithmMapping() { diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java index 10fc0db..50b4ab4 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java @@ -29,6 +29,7 @@ import de.monticore.symboltable.ResolvingConfiguration; import de.se_rwth.commons.logging.Log; import java.util.*; +import java.util.stream.Collectors; public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { @@ -546,6 +547,12 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { } else { return EpsilonDecayMethod.NO; } + } else if (configValue instanceof ASTDoubleVectorValue) { + ASTDoubleVectorValue astDoubleVectorValue = (ASTDoubleVectorValue)configValue; + return astDoubleVectorValue.getNumberList().stream() + .filter(n -> n.getNumber().isPresent()) + .map(n -> n.getNumber().get()) + .collect(Collectors.toList()); } throw new UnsupportedOperationException("Unknown Value type: " + configValue.getClass()); } diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ErrorCodes.java b/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ErrorCodes.java index f21894e..98c123b 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ErrorCodes.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ErrorCodes.java @@ -29,4 +29,5 @@ public class ErrorCodes { public static final String UNSUPPORTED_PARAMETER = "xC8854"; public static final String RANGE_ERROR = "xC8855"; public static final String REQUIRED_PARAMETER_MISSING = "xC8856"; + public static final String STRATEGY_NOT_APPLICABLE = "xC8857"; } \ No newline at end of file diff --git a/src/test/java/de/monticore/lang/monticar/cnntrain/ParserTest.java b/src/test/java/de/monticore/lang/monticar/cnntrain/ParserTest.java index 6e10eca..8adfc8d 100644 --- a/src/test/java/de/monticore/lang/monticar/cnntrain/ParserTest.java +++ b/src/test/java/de/monticore/lang/monticar/cnntrain/ParserTest.java @@ -46,7 +46,8 @@ public class ParserTest { "src/test/resources/invalid_parser_tests/WrongParameterName2.cnnt", "src/test/resources/invalid_parser_tests/InvalidType.cnnt", "src/test/resources/invalid_parser_tests/InvalidOptimizer.cnnt", - "src/test/resources/invalid_parser_tests/MissingColon.cnnt") + "src/test/resources/invalid_parser_tests/MissingColon.cnnt", + "src/test/resources/invalid_parser_tests/WrongStrategyParameter.cnnt") .map(s -> Paths.get(s).toString()) .collect(Collectors.toList()); diff --git a/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java b/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java index 251058e..f200e94 100644 --- a/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java +++ b/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java @@ -85,5 +85,11 @@ public class AllCoCoTest extends AbstractCoCoTest{ checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckRlAlgorithmParameter()), "invalid_cocos_tests", "CheckRLAlgorithmParameter3", new ExpectedErrorInfo(1, ErrorCodes.UNSUPPORTED_PARAMETER)); + checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckDiscreteRLAlgorithmUsesDiscreteStrategy()), + "invalid_cocos_tests", "CheckDiscreteRLAlgorithmUsesDiscreteStrategy", + new ExpectedErrorInfo(1, ErrorCodes.STRATEGY_NOT_APPLICABLE)); + checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckContinuousRLAlgorithmUsesContinuousStrategy()), + "invalid_cocos_tests", "CheckContinuousRLAlgorithmUsesContinuousStrategy", + new ExpectedErrorInfo(1, ErrorCodes.STRATEGY_NOT_APPLICABLE)); } } diff --git a/src/test/resources/invalid_cocos_tests/CheckContinuousRLAlgorithmUsesContinuousStrategy.cnnt b/src/test/resources/invalid_cocos_tests/CheckContinuousRLAlgorithmUsesContinuousStrategy.cnnt new file mode 100644 index 0000000..9a14a32 --- /dev/null +++ b/src/test/resources/invalid_cocos_tests/CheckContinuousRLAlgorithmUsesContinuousStrategy.cnnt @@ -0,0 +1,12 @@ +configuration CheckContinuousRLAlgorithmUsesContinuousStrategy { + learning_method : reinforcement + rl_algorithm: ddpg-algorithm + + environment : gym { name:"CartPole-v1" } + + strategy: epsgreedy { + epsilon: 1.0 + epsilon_decay_method: linear + epsilon_decay: 0.01 + } +} \ No newline at end of file diff --git a/src/test/resources/invalid_cocos_tests/CheckDiscreteRLAlgorithmUsesDiscreteStrategy.cnnt b/src/test/resources/invalid_cocos_tests/CheckDiscreteRLAlgorithmUsesDiscreteStrategy.cnnt new file mode 100644 index 0000000..26fa2b4 --- /dev/null +++ b/src/test/resources/invalid_cocos_tests/CheckDiscreteRLAlgorithmUsesDiscreteStrategy.cnnt @@ -0,0 +1,15 @@ +configuration CheckDiscreteRLAlgorithmUsesDiscreteStrategy { + learning_method : reinforcement + rl_algorithm: dqn-algorithm + + environment : gym { name:"CartPole-v1" } + + strategy: ornstein_uhlenbeck { + epsilon: 1.0 + epsilon_decay_method: linear + epsilon_decay: 0.01 + mu: (0.0, 0.1, 0.3) + theta: (0.5, 0.0, 0.8) + sigma: (0.3, 0.6, -0.9) + } +} \ No newline at end of file diff --git a/src/test/resources/invalid_parser_tests/WrongStrategyParameter.cnnt b/src/test/resources/invalid_parser_tests/WrongStrategyParameter.cnnt new file mode 100644 index 0000000..5b40e7b --- /dev/null +++ b/src/test/resources/invalid_parser_tests/WrongStrategyParameter.cnnt @@ -0,0 +1,11 @@ +configuration WrongStrategyParameter { + learning_method : reinforcement + rl_algorithm: ddpg-algorithm + environment : gym { name:"CartPole-v1" } + strategy: epsgreedy { + epsilon: 1.0 + epsilon_decay_method: linear + epsilon_decay: 0.01 + mu: (0.01) + } +} \ No newline at end of file diff --git a/src/test/resources/valid_tests/DdpgConfig.cnnt b/src/test/resources/valid_tests/DdpgConfig.cnnt index 6df4e7b..e75c6e5 100644 --- a/src/test/resources/valid_tests/DdpgConfig.cnnt +++ b/src/test/resources/valid_tests/DdpgConfig.cnnt @@ -16,4 +16,14 @@ configuration DdpgConfig { learning_rate_decay : 0.5 learning_rate_policy : step } + strategy : ornstein_uhlenbeck{ + epsilon: 1.0 + min_epsilon: 0.001 + epsilon_decay_method: linear + epsilon_decay : 0.0001 + epsilon_decay_start: 50 + mu: (0.0, 0.1, 0.3) + theta: (0.5, 0.0, 0.8) + sigma: (0.3, 0.6, -0.9) + } } \ No newline at end of file -- GitLab From 6116e9f6ec19a3691d5ec1583f354367aded63e4 Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Fri, 31 May 2019 16:17:09 +0200 Subject: [PATCH 12/16] Use an existing common-monticar version --- pom.xml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pom.xml b/pom.xml index b91f77c..bf1e941 100644 --- a/pom.xml +++ b/pom.xml @@ -40,7 +40,7 @@ 5.0.1 1.7.8 0.0.6 - 0.0.17-20180824.094114-1 + 0.0.17-SNAPSHOT 18.0 @@ -350,4 +350,4 @@ https://nexus.se.rwth-aachen.de/content/repositories/embeddedmontiarc-snapshots/ - \ No newline at end of file + -- GitLab From 845baf14e4fb2100a50bea6dc5ae342816a9ef69 Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Sat, 1 Jun 2019 18:41:07 +0200 Subject: [PATCH 13/16] Remove unused ROS env. topics --- src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 | 2 -- .../monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java | 2 -- .../CheckRosEnvironmentRequiresRewardFunction.cnnt | 2 -- src/test/resources/valid_tests/ReinforcementConfig2.cnnt | 2 -- 4 files changed, 8 deletions(-) diff --git a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 index 177553c..51442db 100644 --- a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 +++ b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 @@ -170,8 +170,6 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number RosEnvironmentStateTopicEntry implements RosEnvironmentEntry = name:"state_topic" ":" value:StringValue; RosEnvironmentActionTopicEntry implements RosEnvironmentEntry = name:"action_topic" ":" value:StringValue; RosEnvironmentResetTopicEntry implements RosEnvironmentEntry = name:"reset_topic" ":" value:StringValue; - RosEnvironmentGreetingTopicEntry implements RosEnvironmentEntry = name:"greeting_topic" ":" value:StringValue; - RosEnvironmentMetaTopicEntry implements RosEnvironmentEntry = name:"meta_topic" ":" value:StringValue; RosEnvironmentTerminalStateTopicEntry implements RosEnvironmentEntry = name:"terminal_state_topic" ":" value:StringValue; // DQN exclusive parameters diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java index a33813b..2c6d755 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java @@ -77,10 +77,8 @@ public class ParameterAlgorithmMapping { ASTNumEpisodesEntry.class, ASTRosEnvironmentActionTopicEntry.class, ASTRosEnvironmentStateTopicEntry.class, - ASTRosEnvironmentMetaTopicEntry.class, ASTRosEnvironmentResetTopicEntry.class, ASTRosEnvironmentTerminalStateTopicEntry.class, - ASTRosEnvironmentGreetingTopicEntry.class, ASTStartTrainingAtEntry.class, ASTEvaluationSamplesEntry.class, ASTEpsilonDecayStartEntry.class, diff --git a/src/test/resources/invalid_cocos_tests/CheckRosEnvironmentRequiresRewardFunction.cnnt b/src/test/resources/invalid_cocos_tests/CheckRosEnvironmentRequiresRewardFunction.cnnt index ef61344..fdb28a7 100644 --- a/src/test/resources/invalid_cocos_tests/CheckRosEnvironmentRequiresRewardFunction.cnnt +++ b/src/test/resources/invalid_cocos_tests/CheckRosEnvironmentRequiresRewardFunction.cnnt @@ -8,8 +8,6 @@ configuration CheckRosEnvironmentRequiresRewardFunction { action_topic : "action" reset_topic : "reset" terminal_state_topic : "is_terminal" - greeting_topic : "greeting" - meta_topic : "meta" } context : cpu diff --git a/src/test/resources/valid_tests/ReinforcementConfig2.cnnt b/src/test/resources/valid_tests/ReinforcementConfig2.cnnt index 470df30..0462aaa 100644 --- a/src/test/resources/valid_tests/ReinforcementConfig2.cnnt +++ b/src/test/resources/valid_tests/ReinforcementConfig2.cnnt @@ -8,8 +8,6 @@ configuration ReinforcementConfig2 { action_topic : "action" reset_topic : "reset" terminal_state_topic : "is_terminal" - greeting_topic : "greeting" - meta_topic : "meta" } reward_function : path.to.reward.component -- GitLab From 4de36d1d5e648caeb30f72df411d8eba9f2f126d Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Sat, 1 Jun 2019 23:59:56 +0200 Subject: [PATCH 14/16] Add reward topic parameter for ROS --- .../de/monticore/lang/monticar/CNNTrain.mc4 | 1 + .../_cocos/ASTConfigurationUtils.java | 24 ++++++++ ...kRosEnvironmentRequiresRewardFunction.java | 31 ++++------ .../_cocos/ParameterAlgorithmMapping.java | 3 +- .../monticar/cnntrain/cocos/AllCoCoTest.java | 1 + .../ReinforcementWithRosReward.cnnt | 56 +++++++++++++++++++ 6 files changed, 94 insertions(+), 22 deletions(-) create mode 100644 src/test/resources/valid_tests/ReinforcementWithRosReward.cnnt diff --git a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 index 51442db..22c2e48 100644 --- a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 +++ b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 @@ -171,6 +171,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number RosEnvironmentActionTopicEntry implements RosEnvironmentEntry = name:"action_topic" ":" value:StringValue; RosEnvironmentResetTopicEntry implements RosEnvironmentEntry = name:"reset_topic" ":" value:StringValue; RosEnvironmentTerminalStateTopicEntry implements RosEnvironmentEntry = name:"terminal_state_topic" ":" value:StringValue; + RosEnvironmentRewardTopicEntry implements RosEnvironmentEntry = name:"reward_topic" ":" value:StringValue; // DQN exclusive parameters UseFixTargetNetworkEntry implements ConfigEntry = name:"use_fix_target_network" ":" value:BooleanValue; diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ASTConfigurationUtils.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ASTConfigurationUtils.java index 6ff1535..a955b73 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ASTConfigurationUtils.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ASTConfigurationUtils.java @@ -60,4 +60,28 @@ class ASTConfigurationUtils { .findFirst() .map(astStrategyEntry -> astStrategyEntry.getValue().getName()); } + + static boolean hasRewardFunction(final ASTConfiguration node) { + return node.getEntriesList().stream().anyMatch(e -> e instanceof ASTRewardFunctionEntry); + } + + static boolean hasRosEnvironment(final ASTConfiguration node) { + return ASTConfigurationUtils.hasEnvironment(node) + && node.getEntriesList().stream() + .anyMatch(e -> (e instanceof ASTEnvironmentEntry) + && ((ASTEnvironmentEntry)e).getValue().getName().equals("ros_interface")); + } + + static boolean hasRewardTopic(final ASTConfiguration node) { + if (ASTConfigurationUtils.isReinforcementLearning(node) && ASTConfigurationUtils.hasEnvironment(node)) { + return node.getEntriesList().stream() + .filter(ASTEnvironmentEntry.class::isInstance) + .map(e -> (ASTEnvironmentEntry)e) + .reduce((element, other) -> { throw new IllegalStateException("More than one entry");}) + .map(astEnvironmentEntry -> astEnvironmentEntry.getValue().getParamsList().stream() + .anyMatch(e -> e instanceof ASTRosEnvironmentRewardTopicEntry)).orElse(false); + + } + return false; + } } diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckRosEnvironmentRequiresRewardFunction.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckRosEnvironmentRequiresRewardFunction.java index 9903507..1f6b3ec 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckRosEnvironmentRequiresRewardFunction.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckRosEnvironmentRequiresRewardFunction.java @@ -21,34 +21,23 @@ package de.monticore.lang.monticar.cnntrain._cocos; import de.monticore.lang.monticar.cnntrain._ast.*; -import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol; -import de.monticore.lang.monticar.cnntrain._symboltable.EntrySymbol; -import de.monticore.lang.monticar.cnntrain._symboltable.Environment; -import de.monticore.lang.monticar.cnntrain._symboltable.LearningMethod; import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes; import de.se_rwth.commons.logging.Log; -import java.util.Map; +import static de.monticore.lang.monticar.cnntrain._cocos.ASTConfigurationUtils.*; public class CheckRosEnvironmentRequiresRewardFunction implements CNNTrainASTConfigurationCoCo { @Override public void check(final ASTConfiguration node) { - if (ASTConfigurationUtils.isReinforcementLearning(node) - && hasRosEnvironment(node) - && !hasRewardFunction(node)) { - Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING + " The required parameter reward_function" + - " is missing"); + // Specification of reward function only required for reinforcement learning via ROS since OpenAI Gym defines + // their own reward functions + if (isReinforcementLearning(node) && hasRosEnvironment(node)) { + // Reward needs to be either be calculated with a custom component or + if (!hasRewardFunction(node) && !hasRewardTopic(node)) { + Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING + + " Reward function is missing. Either add a reward function component with parameter " + + "reward_function or add a ROS topic with parameter reward_topic."); + } } } - - private boolean hasRewardFunction(final ASTConfiguration node) { - return node.getEntriesList().stream().anyMatch(e -> e instanceof ASTRewardFunctionEntry); - } - - private boolean hasRosEnvironment(final ASTConfiguration node) { - return ASTConfigurationUtils.hasEnvironment(node) - && node.getEntriesList().stream() - .anyMatch(e -> (e instanceof ASTEnvironmentEntry) - && ((ASTEnvironmentEntry)e).getValue().getName().equals("ros_interface")); - } } \ No newline at end of file diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java index 2c6d755..6c930ff 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java @@ -26,7 +26,7 @@ import de.monticore.lang.monticar.cnntrain._ast.*; import java.util.List; -public class ParameterAlgorithmMapping { +class ParameterAlgorithmMapping { private static final List GENERAL_PARAMETERS = Lists.newArrayList( ASTTrainContextEntry.class, ASTOptimizerEntry.class, @@ -79,6 +79,7 @@ public class ParameterAlgorithmMapping { ASTRosEnvironmentStateTopicEntry.class, ASTRosEnvironmentResetTopicEntry.class, ASTRosEnvironmentTerminalStateTopicEntry.class, + ASTRosEnvironmentRewardTopicEntry.class, ASTStartTrainingAtEntry.class, ASTEvaluationSamplesEntry.class, ASTEpsilonDecayStartEntry.class, diff --git a/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java b/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java index f200e94..abfd28f 100644 --- a/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java +++ b/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java @@ -42,6 +42,7 @@ public class AllCoCoTest extends AbstractCoCoTest{ checkValid("valid_tests", "ReinforcementConfig"); checkValid("valid_tests", "ReinforcementConfig2"); checkValid("valid_tests", "DdpgConfig"); + checkValid("valid_tests", "ReinforcementWithRosReward"); } @Test diff --git a/src/test/resources/valid_tests/ReinforcementWithRosReward.cnnt b/src/test/resources/valid_tests/ReinforcementWithRosReward.cnnt new file mode 100644 index 0000000..30e0a4d --- /dev/null +++ b/src/test/resources/valid_tests/ReinforcementWithRosReward.cnnt @@ -0,0 +1,56 @@ +configuration ReinforcementWithRosReward { + learning_method : reinforcement + + agent_name : "reinforcement-agent" + + environment : ros_interface { + state_topic : "state" + action_topic : "action" + reset_topic : "reset" + terminal_state_topic : "is_terminal" + reward_topic: "reward" + } + + context : cpu + + num_episodes : 300 + num_max_steps : 9999 + discount_factor : 0.998 + target_score : 1000 + training_interval : 10 + + loss : huber_loss + + use_fix_target_network : true + target_network_update_interval : 100 + + use_double_dqn : true + + replay_memory : buffer{ + memory_size : 1000000 + sample_size : 64 + } + + strategy : epsgreedy{ + epsilon : 1.0 + min_epsilon : 0.01 + epsilon_decay_method: linear + epsilon_decay : 0.0001 + } + + optimizer : rmsprop{ + learning_rate : 0.001 + learning_rate_minimum : 0.00001 + weight_decay : 0.01 + learning_rate_decay : 0.9 + learning_rate_policy : step + step_size : 1000 + rescale_grad : 1.1 + clip_gradient : 10 + gamma1 : 0.9 + gamma2 : 0.9 + epsilon : 0.000001 + centered : true + clip_weights : 10 + } +} \ No newline at end of file -- GitLab From 2fc3d29d680040d19b6e65925e9e6f6d85b9185e Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Sun, 2 Jun 2019 00:33:38 +0200 Subject: [PATCH 15/16] Add Cocos: Allow only one reward specification --- .../cnntrain/_cocos/CNNTrainCocos.java | 3 +- ...ironmentHasOnlyOneRewardSpecification.java | 42 ++++++++++++++ .../monticar/cnntrain/helper/ErrorCodes.java | 1 + .../monticar/cnntrain/cocos/AllCoCoTest.java | 3 + ...ironmentHasOnlyOneRewardSpecification.cnnt | 58 +++++++++++++++++++ 5 files changed, 106 insertions(+), 1 deletion(-) create mode 100644 src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckRosEnvironmentHasOnlyOneRewardSpecification.java create mode 100644 src/test/resources/invalid_cocos_tests/CheckRosEnvironmentHasOnlyOneRewardSpecification.cnnt diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java index 5038a14..b465042 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java @@ -37,7 +37,8 @@ public class CNNTrainCocos { .addCoCo(new CheckDdpgRequiresCriticNetwork()) .addCoCo(new CheckRlAlgorithmParameter()) .addCoCo(new CheckDiscreteRLAlgorithmUsesDiscreteStrategy()) - .addCoCo(new CheckContinuousRLAlgorithmUsesContinuousStrategy()); + .addCoCo(new CheckContinuousRLAlgorithmUsesContinuousStrategy()) + .addCoCo(new CheckRosEnvironmentHasOnlyOneRewardSpecification()); } public static void checkAll(CNNTrainCompilationUnitSymbol compilationUnit){ diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckRosEnvironmentHasOnlyOneRewardSpecification.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckRosEnvironmentHasOnlyOneRewardSpecification.java new file mode 100644 index 0000000..d3ef290 --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckRosEnvironmentHasOnlyOneRewardSpecification.java @@ -0,0 +1,42 @@ +/** + * + * ****************************************************************************** + * MontiCAR Modeling Family, www.se-rwth.de + * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, + * All rights reserved. + * + * This project is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 3.0 of the License, or (at your option) any later version. + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this project. If not, see . + * ******************************************************************************* + */ +package de.monticore.lang.monticar.cnntrain._cocos; + +import de.monticore.lang.monticar.cnntrain._ast.ASTConfiguration; +import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes; +import de.se_rwth.commons.logging.Log; + +import static de.monticore.lang.monticar.cnntrain._cocos.ASTConfigurationUtils.*; + +public class CheckRosEnvironmentHasOnlyOneRewardSpecification implements CNNTrainASTConfigurationCoCo { + @Override + public void check(final ASTConfiguration node) { + if (isReinforcementLearning(node) + && hasRosEnvironment(node) + && hasRewardFunction(node) + && hasRewardTopic(node)) { + Log.error("0" + ErrorCodes.CONTRADICTING_PARAMETERS + + " Multiple reward calculation method specified. Either use a reward function component with " + + "parameter reward_function or use ROS topic with parameter reward_topic. " + + "Both is not possible"); + } + } +} \ No newline at end of file diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ErrorCodes.java b/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ErrorCodes.java index 98c123b..6aa9e50 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ErrorCodes.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ErrorCodes.java @@ -30,4 +30,5 @@ public class ErrorCodes { public static final String RANGE_ERROR = "xC8855"; public static final String REQUIRED_PARAMETER_MISSING = "xC8856"; public static final String STRATEGY_NOT_APPLICABLE = "xC8857"; + public static final String CONTRADICTING_PARAMETERS = "xC8858"; } \ No newline at end of file diff --git a/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java b/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java index abfd28f..cd7d8f6 100644 --- a/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java +++ b/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java @@ -92,5 +92,8 @@ public class AllCoCoTest extends AbstractCoCoTest{ checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckContinuousRLAlgorithmUsesContinuousStrategy()), "invalid_cocos_tests", "CheckContinuousRLAlgorithmUsesContinuousStrategy", new ExpectedErrorInfo(1, ErrorCodes.STRATEGY_NOT_APPLICABLE)); + checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckRosEnvironmentHasOnlyOneRewardSpecification()), + "invalid_cocos_tests", "CheckRosEnvironmentHasOnlyOneRewardSpecification", + new ExpectedErrorInfo(1, ErrorCodes.CONTRADICTING_PARAMETERS)); } } diff --git a/src/test/resources/invalid_cocos_tests/CheckRosEnvironmentHasOnlyOneRewardSpecification.cnnt b/src/test/resources/invalid_cocos_tests/CheckRosEnvironmentHasOnlyOneRewardSpecification.cnnt new file mode 100644 index 0000000..2f114df --- /dev/null +++ b/src/test/resources/invalid_cocos_tests/CheckRosEnvironmentHasOnlyOneRewardSpecification.cnnt @@ -0,0 +1,58 @@ +configuration CheckRosEnvironmentHasOnlyOneRewardSpecification { + learning_method : reinforcement + + agent_name : "reinforcement-agent" + + environment : ros_interface { + state_topic : "state" + action_topic : "action" + reset_topic : "reset" + terminal_state_topic : "is_terminal" + reward_topic: "reward" + } + + reward_function : path.to.reward.component + + context : cpu + + num_episodes : 300 + num_max_steps : 9999 + discount_factor : 0.998 + target_score : 1000 + training_interval : 10 + + loss : huber_loss + + use_fix_target_network : true + target_network_update_interval : 100 + + use_double_dqn : true + + replay_memory : buffer{ + memory_size : 1000000 + sample_size : 64 + } + + strategy : epsgreedy{ + epsilon : 1.0 + min_epsilon : 0.01 + epsilon_decay_method: linear + epsilon_decay : 0.0001 + } + + optimizer : rmsprop{ + learning_rate : 0.001 + learning_rate_minimum : 0.00001 + weight_decay : 0.01 + learning_rate_decay : 0.9 + learning_rate_policy : step + step_size : 1000 + rescale_grad : 1.1 + clip_gradient : 10 + gamma1 : 0.9 + gamma2 : 0.9 + epsilon : 0.000001 + centered : true + clip_weights : 10 + } +} \ No newline at end of file -- GitLab From 6eb431134ec7ec4a891a164ec6f517d548cbb4f4 Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Mon, 3 Jun 2019 18:45:59 +0200 Subject: [PATCH 16/16] Update README.md --- README.md | 67 ++++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 47 insertions(+), 20 deletions(-) diff --git a/README.md b/README.md index f1e4b87..f019ce9 100644 --- a/README.md +++ b/README.md @@ -114,24 +114,31 @@ configuration ReinforcementConfig { ### Available Parameters for Reinforcement Learning -| Parameter | Value | Default | Required | Description | -|------------|--------|---------|----------|-------------| -|learning_method| reinforcement,supervised | supervised | No | Determines that this CNNTrain configuration is a reinforcement or supervised learning configuration | -| agent_name | String | "agent" | No | Names the agent (e.g. for logging output) | -|environment | gym, ros_interface | Yes | / | If *ros_interface* is selected, then the agent and the environment communicates via [ROS](http://www.ros.org/). The gym environment comes with a set of environments which are listed [here](https://gym.openai.com/) | -| context | cpu, gpu | cpu | No | Determines whether the GPU is used during training or the CPU | -| num_episodes | Integer | 50 | No | Number of episodes the agent is trained. An episode is a full passing of a game from an initial state to a terminal state.| -| num_max_steps | Integer | 99999 | No | Number of steps within an episodes before the environment is forced to reset the state (e.g. to avoid a state in which the agent is stuck) | -|discount_factor | Float | 0.9 | No | Discount factor | -| target_score | Float | None | No | If set, the agent stops the training when the average score of the last 100 episodes is greater than the target score. | -| training_interval | Integer | 1 | No | Number of steps between two trainings | -| loss | euclidean, l1, softmax_cross_entropy, sigmoid_cross_entropy, huber_loss | euclidean | No | Selects the loss function -| use_fix_target_network | bool | false | No | If set, an extra network with fixed parameters is used to estimate the Q values | -| target_network_update_interval | Integer | / | Yes, if fixed target network is true | If *use_fix_target_network* is set, it determines the number of steps after the target network is updated (Minh et. al. "Human Level Control through Deep Reinforcement Learning")| +| Parameter | Value | Default | Required | Algorithm | Description | +|------------|--------|---------|----------|-----------|-------------| +|learning_method| reinforcement,supervised | supervised | No | All | Determines that this CNNTrain configuration is a reinforcement or supervised learning configuration | +| rl_algorithm | ddpg-algorithm, dqn-algorithm | dqn-algorithm | No | All | Determines the RL algorithm that is used to train the agent +| agent_name | String | "agent" | No | All | Names the agent (e.g. for logging output) | +|environment | gym, ros_interface | Yes | / | All | If *ros_interface* is selected, then the agent and the environment communicates via [ROS](http://www.ros.org/). The gym environment comes with a set of environments which are listed [here](https://gym.openai.com/) | +| context | cpu, gpu | cpu | No | All | Determines whether the GPU is used during training or the CPU | +| num_episodes | Integer | 50 | No | All | Number of episodes the agent is trained. An episode is a full passing of a game from an initial state to a terminal state.| +| num_max_steps | Integer | 99999 | No | All | Number of steps within an episodes before the environment is forced to reset the state (e.g. to avoid a state in which the agent is stuck) | +|discount_factor | Float | 0.9 | No | All | Discount factor | +| target_score | Float | None | No | All | If set, the agent stops the training when the average score of the last 100 episodes is greater than the target score. | +| training_interval | Integer | 1 | No | All | Number of steps between two trainings | +| loss | euclidean, l1, softmax_cross_entropy, sigmoid_cross_entropy, huber_loss | euclidean | No | DQN | Selects the loss function +| use_fix_target_network | bool | false | No | DQN | If set, an extra network with fixed parameters is used to estimate the Q values | +| target_network_update_interval | Integer | / | DQN | Yes, if fixed target network is true | If *use_fix_target_network* is set, it determines the number of steps after the target network is updated (Minh et. al. "Human Level Control through Deep Reinforcement Learning")| | use_double_dqn | bool | false | No | If set, two value functions are used to determine the action values (Hasselt et. al. "Deep Reinforcement Learning with Double Q Learning") | -| replay_memory | buffer, online, combined | buffer | No | Determines the behaviour of the replay memory | -| action_selection | epsgreedy | epsgreedy | No | Determines the action selection policy during the training | -| reward_function | Full name of an EMAM component | / | Yes, if *ros_interface* is selected as the environment | The EMAM component that is used to calculate the reward. It must have two inputs, one for the current state and one boolean input that determines if the current state is terminal. It must also have exactly one output which represents the reward. | +| replay_memory | buffer, online, combined | buffer | No | All | Determines the behaviour of the replay memory | +| strategy | epsgreedy, ornstein_uhlenbeck | epsgreedy (discrete), ornstein_uhlenbeck (continuous) | No | All | Determines the action selection policy during the training | +| reward_function | Full name of an EMAM component | / | Yes, if *ros_interface* is selected as the environment and no reward topic is given | All | The EMAM component that is used to calculate the reward. It must have two inputs, one for the current state and one boolean input that determines if the current state is terminal. It must also have exactly one output which represents the reward. | +critic | Full name of architecture definition | / | Yes, if DDPG is selected | DDPG | The architecture definition which specifies the architecture of the critic network | +soft_target_update_rate | Float | 0.001 | No | DDPG | Determines the update rate of the critic and actor target network | +actor_optimizer | See supervised learning | adam with LR .0001 | No | DDPG | Determines the optimizer parameters of the actor network | +critic_optimizer | See supervised learning | adam with LR .001 | No | DDPG | Determines the optimizer parameters of the critic network | +| start_training_at | Integer | 0 | No | All | Determines at which episode the training starts | +| evaluation_samples | Integer | 100 | No | All | Determines how many epsiodes are run when evaluating the network | #### Environment @@ -169,19 +176,39 @@ No buffer is used. Only the current SARS tuple is used for taining. Combination of *online* and *buffer*. Both the current SARS tuple as well as a sample from the buffer are used for each training step. Parameters are the same as *buffer*. -### Action Selection +### Strategy -Determines the behaviour when selecting an action based on the values. (Currently, only epsilon greedy is available.) +Determines the behaviour when selecting an action based on the values. #### Option: epsgreedy -Selects an action based on Epsilon-Greedy-Policy. This means, based on epsilon, either a random action is choosen or an action with the highest value. Additional parameters: +This strategy is only available for discrete problems. It selects an action based on Epsilon-Greedy-Policy. This means, based on epsilon, either a random action is choosen or an action with the highest Q-value. Additional parameters: - **epsilon**: Probability of choosing an action randomly - **epsilon_decay_method**: Method which determines how epsilon decreases after each step. Can be *linear* for linear decrease or *no* for no decrease. +- **epsilon_decay_start**: Number of Episodes after the decay of epsilon starts - **epsilon_decay**: The actual decay of epsilon after each step. - **min_epsilon**: After *min_epsilon* is reached, epsilon is not decreased further. + +#### Option: ornstein_uhlenbeck +This strategy is only available for continuous problems. The action is selected based on the actor network. Based on the current epsilon, noise is added based on the [Ornstein-Uhlenbeck](https://en.wikipedia.org/wiki/Ornstein%E2%80%93Uhlenbeck_process) process. Additional parameters: + +All epsilon parameters from epsgreedy strategy can be used. Additionally, **mu**, **theta**, and **sigma** needs to be specified. For each action output you can specify the corresponding value with a tuple-style notation: `(x,y,z)` + +Example: Given an actor network with action output of shape (3,), we can write + +```EMADL + strategy: ornstein_uhlenbeck{ + ... + mu: (0.0, 0.1, 0.3) + theta: (0.5, 0.0, 0.8) + sigma: (0.3, 0.6, -0.9) + } +``` + +to specify the parameters for each place. + ## Generation To execute generation in your project, use the following code to generate a separate Config file: -- GitLab