From 2a2abdc558a304e232331e18c7ec26a11ae077d2 Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Mon, 8 Jul 2019 00:18:49 +0200 Subject: [PATCH 01/21] Add parameter for gaussian strategy --- .../de/monticore/lang/monticar/CNNTrain.mc4 | 5 +++- ...uousRLAlgorithmUsesContinuousStrategy.java | 1 + .../monticar/cnntrain/cocos/AllCoCoTest.java | 1 + src/test/resources/valid_tests/TD3Config.cnnt | 26 +++++++++++++++++++ 4 files changed, 32 insertions(+), 1 deletion(-) create mode 100644 src/test/resources/valid_tests/TD3Config.cnnt diff --git a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 index 3503f53..362cce6 100644 --- a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 +++ b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 @@ -176,11 +176,14 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number interface StrategyOrnsteinUhlenbeckEntry extends Entry; StrategyOrnsteinUhlenbeckValue implements StrategyValue = name:"ornstein_uhlenbeck" ("{" params:StrategyOrnsteinUhlenbeckEntry* "}")?; + interface StrategyGaussianEntry extends Entry; + StrategyGaussianValue implements StrategyValue = name:"gaussian" ("{" params:StrategyGaussianEntry* "}")?; + StrategyOUMu implements StrategyOrnsteinUhlenbeckEntry = name: "mu" ":" value:DoubleVectorValue; StrategyOUTheta implements StrategyOrnsteinUhlenbeckEntry = name: "theta" ":" value:DoubleVectorValue; StrategyOUSigma implements StrategyOrnsteinUhlenbeckEntry = name: "sigma" ":" value:DoubleVectorValue; - interface GeneralStrategyEntry extends StrategyEpsGreedyEntry, StrategyOrnsteinUhlenbeckEntry; + interface GeneralStrategyEntry extends StrategyEpsGreedyEntry, StrategyOrnsteinUhlenbeckEntry, StrategyGaussianEntry; GreedyEpsilonEntry implements GeneralStrategyEntry = name:"epsilon" ":" value:NumberValue; MinEpsilonEntry implements GeneralStrategyEntry = name:"min_epsilon" ":" value:NumberValue; diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckContinuousRLAlgorithmUsesContinuousStrategy.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckContinuousRLAlgorithmUsesContinuousStrategy.java index 5d0a68c..16d4f73 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckContinuousRLAlgorithmUsesContinuousStrategy.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckContinuousRLAlgorithmUsesContinuousStrategy.java @@ -30,6 +30,7 @@ import java.util.Set; public class CheckContinuousRLAlgorithmUsesContinuousStrategy implements CNNTrainASTConfigurationCoCo{ private static final Set CONTINUOUS_STRATEGIES = ImmutableSet.builder() .add("ornstein_uhlenbeck") + .add("gaussian") .build(); @Override diff --git a/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java b/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java index cd7d8f6..a8512b5 100644 --- a/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java +++ b/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java @@ -42,6 +42,7 @@ public class AllCoCoTest extends AbstractCoCoTest{ checkValid("valid_tests", "ReinforcementConfig"); checkValid("valid_tests", "ReinforcementConfig2"); checkValid("valid_tests", "DdpgConfig"); + checkValid("valid_tests", "TD3Config"); checkValid("valid_tests", "ReinforcementWithRosReward"); } diff --git a/src/test/resources/valid_tests/TD3Config.cnnt b/src/test/resources/valid_tests/TD3Config.cnnt new file mode 100644 index 0000000..d45b64f --- /dev/null +++ b/src/test/resources/valid_tests/TD3Config.cnnt @@ -0,0 +1,26 @@ +configuration TD3Config { + learning_method : reinforcement + rl_algorithm : ddpg-algorithm + critic : path.to.component + environment : gym { name:"CartPole-v1" } + soft_target_update_rate: 0.001 + actor_optimizer : adam{ + learning_rate : 0.0001 + learning_rate_minimum : 0.00005 + learning_rate_decay : 0.9 + learning_rate_policy : step + } + critic_optimizer : rmsprop{ + learning_rate : 0.001 + learning_rate_minimum : 0.0001 + learning_rate_decay : 0.5 + learning_rate_policy : step + } + strategy : gaussian { + epsilon: 1.0 + min_epsilon: 0.001 + epsilon_decay_method: linear + epsilon_decay : 0.0001 + epsilon_decay_start: 50 + } +} \ No newline at end of file -- GitLab From e3b5108d5fd40d3ca5a1c8c056bfc599942d0f54 Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Mon, 8 Jul 2019 00:26:04 +0200 Subject: [PATCH 02/21] Add parameter for strategy epsilon decay per step --- src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 | 1 + .../lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java | 1 + src/test/resources/valid_tests/TD3Config.cnnt | 1 + 3 files changed, 3 insertions(+) diff --git a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 index 362cce6..e32e500 100644 --- a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 +++ b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 @@ -190,6 +190,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number EpsilonDecayStartEntry implements GeneralStrategyEntry = name:"epsilon_decay_start" ":" value:IntegerValue; EpsilonDecayMethodEntry implements GeneralStrategyEntry = name:"epsilon_decay_method" ":" value:EpsilonDecayMethodValue; EpsilonDecayMethodValue implements ConfigValue = (linear:"linear" | no:"no"); + EpsilonDecayPerStepEntry implements GeneralStrategyEntry = name:"epsilon_decay_per_step" ":" value:BooleanValue; EpsilonDecayEntry implements GeneralStrategyEntry = name:"epsilon_decay" ":" value:NumberValue; // Environment diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java index 6443cf4..a77f4bf 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java @@ -79,6 +79,7 @@ class ParameterAlgorithmMapping { ASTMinEpsilonEntry.class, ASTEpsilonDecayEntry.class, ASTEpsilonDecayMethodEntry.class, + ASTEpsilonDecayPerStepEntry.class, ASTNumEpisodesEntry.class, ASTRosEnvironmentActionTopicEntry.class, ASTRosEnvironmentStateTopicEntry.class, diff --git a/src/test/resources/valid_tests/TD3Config.cnnt b/src/test/resources/valid_tests/TD3Config.cnnt index d45b64f..6d8bbba 100644 --- a/src/test/resources/valid_tests/TD3Config.cnnt +++ b/src/test/resources/valid_tests/TD3Config.cnnt @@ -19,6 +19,7 @@ configuration TD3Config { strategy : gaussian { epsilon: 1.0 min_epsilon: 0.001 + epsilon_decay_per_step: true epsilon_decay_method: linear epsilon_decay : 0.0001 epsilon_decay_start: 50 -- GitLab From 39987bfaf6cc94dc038810d00a180693c6004e65 Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Tue, 9 Jul 2019 00:59:09 +0200 Subject: [PATCH 03/21] Rename cocos check --- ... CheckActorCriticRequiresCriticNetwork.java} | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) rename src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/{CheckDdpgRequiresCriticNetwork.java => CheckActorCriticRequiresCriticNetwork.java} (76%) diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckDdpgRequiresCriticNetwork.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckActorCriticRequiresCriticNetwork.java similarity index 76% rename from src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckDdpgRequiresCriticNetwork.java rename to src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckActorCriticRequiresCriticNetwork.java index 6ca5412..821c6fb 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckDdpgRequiresCriticNetwork.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckActorCriticRequiresCriticNetwork.java @@ -26,24 +26,23 @@ import de.monticore.lang.monticar.cnntrain._ast.ASTRLAlgorithmEntry; import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes; import de.se_rwth.commons.logging.Log; -public class CheckDdpgRequiresCriticNetwork implements CNNTrainASTConfigurationCoCo { +import static de.monticore.lang.monticar.cnntrain._cocos.ASTConfigurationUtils.hasCriticEntry; +import static de.monticore.lang.monticar.cnntrain._cocos.ASTConfigurationUtils.isActorCriticAlgorithm; + +public class CheckActorCriticRequiresCriticNetwork implements CNNTrainASTConfigurationCoCo { @Override public void check(ASTConfiguration node) { - boolean isDdpg = node.getEntriesList().stream() - .anyMatch(e -> e instanceof ASTRLAlgorithmEntry - && ((ASTRLAlgorithmEntry)e).getValue().isPresentDdpg()); - boolean hasCriticEntry = node.getEntriesList().stream() - .anyMatch(e -> ((e instanceof ASTCriticNetworkEntry) - && !((ASTCriticNetworkEntry)e).getValue().getNameList().isEmpty())); + boolean isActorCritic = isActorCriticAlgorithm(node); + boolean hasCriticEntry = hasCriticEntry(node); - if (isDdpg && !hasCriticEntry) { + if (isActorCritic && !hasCriticEntry) { ASTRLAlgorithmEntry algorithmEntry = node.getEntriesList().stream() .filter(e -> e instanceof ASTRLAlgorithmEntry) .map(e -> (ASTRLAlgorithmEntry)e) .findFirst() .orElseThrow(() -> new IllegalStateException("ASTRLAlgorithmEntry entry must be available")); - Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING + " DDPG learning algorithm requires critc" + + Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING + " DDPG learning algorithm requires critic" + " network entry", algorithmEntry.get_SourcePositionStart()); } } -- GitLab From b5b9cdaa1088ce9a27cd0f0ae356f8f7d5485d76 Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Tue, 9 Jul 2019 00:59:31 +0200 Subject: [PATCH 04/21] Isolate test cases --- .../monticar/cnntrain/cocos/AllCoCoTest.java | 106 ++++++++++++++++-- 1 file changed, 99 insertions(+), 7 deletions(-) diff --git a/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java b/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java index a8512b5..85a88cc 100644 --- a/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java +++ b/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java @@ -34,65 +34,157 @@ public class AllCoCoTest extends AbstractCoCoTest{ } @Test - public void testValidCoCos() throws IOException { + public void testValidSimpleConfig1() { checkValid("valid_tests","SimpleConfig1"); + } + + @Test + public void testValidSimpleConfig2() { checkValid("valid_tests","SimpleConfig2"); + } + + @Test + public void testValidFullConfig() { checkValid("valid_tests","FullConfig"); + } + + @Test + public void testValidFullConfig2() { checkValid("valid_tests","FullConfig2"); - checkValid("valid_tests", "ReinforcementConfig"); - checkValid("valid_tests", "ReinforcementConfig2"); - checkValid("valid_tests", "DdpgConfig"); - checkValid("valid_tests", "TD3Config"); + } + + @Test + public void testValidReinforcementConfig() { + checkValid("valid_tests","ReinforcementConfig"); + } + + @Test + public void testValidReinforcementConfig2() { + checkValid("valid_tests","ReinforcementConfig2"); + } + + @Test + public void testValidDdpgConfig() { + checkValid("valid_tests","DdpgConfig"); + } + + @Test + public void testValidTD3Config() { + checkValid("valid_tests","TD3Config"); + } + + @Test + public void testValidReinforcementWithRosReward() throws IOException { checkValid("valid_tests", "ReinforcementWithRosReward"); } @Test - public void testInvalidCoCos() throws IOException { + public void testInvalidEntryRepetition() { checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckEntryRepetition()), "invalid_cocos_tests", "EntryRepetition", new ExpectedErrorInfo(1, ErrorCodes.ENTRY_REPETITION_CODE)); + } + + @Test + public void testInvalidIntegerTest() { checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckInteger()), "invalid_cocos_tests", "IntegerTest", new ExpectedErrorInfo(1, ErrorCodes.NOT_INTEGER_CODE)); + } + + @Test + public void testInvalidFixTargetNetworkRequiresInterval1() { checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckFixTargetNetworkRequiresInterval()), "invalid_cocos_tests", "FixTargetNetworkRequiresInterval1", new ExpectedErrorInfo(1, ErrorCodes.REQUIRED_PARAMETER_MISSING)); + } + + @Test + public void testInvalidFixTargetNetworkRequiresInterval2() { checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckFixTargetNetworkRequiresInterval()), "invalid_cocos_tests", "FixTargetNetworkRequiresInterval2", new ExpectedErrorInfo(1, ErrorCodes.REQUIRED_PARAMETER_MISSING)); + } + + @Test + public void testInvalidCheckLearningParameterCombination1() { checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckLearningParameterCombination()), "invalid_cocos_tests", "CheckLearningParameterCombination1", new ExpectedErrorInfo(1, ErrorCodes.UNSUPPORTED_PARAMETER)); + } + + @Test + public void testInvalidCheckLearningParameterCombination2() { checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckLearningParameterCombination()), "invalid_cocos_tests", "CheckLearningParameterCombination2", new ExpectedErrorInfo(3, ErrorCodes.UNSUPPORTED_PARAMETER)); + } + + @Test + public void testInvalidCheckLearningParameterCombination3() { checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckLearningParameterCombination()), "invalid_cocos_tests", "CheckLearningParameterCombination3", new ExpectedErrorInfo(2, ErrorCodes.UNSUPPORTED_PARAMETER)); + } + + @Test + public void testInvalidCheckLearningParameterCombination4() { checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckLearningParameterCombination()), "invalid_cocos_tests", "CheckLearningParameterCombination4", new ExpectedErrorInfo(5, ErrorCodes.UNSUPPORTED_PARAMETER)); + } + + @Test + public void testInvalidCheckReinforcementRequiresEnvironment() { checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckReinforcementRequiresEnvironment()), "invalid_cocos_tests", "CheckReinforcementRequiresEnvironment", new ExpectedErrorInfo(1, ErrorCodes.REQUIRED_PARAMETER_MISSING)); + } + + @Test + public void testInvalidCheckRosEnvironmentRequiresRewardFunction() { checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckRosEnvironmentRequiresRewardFunction()), "invalid_cocos_tests", "CheckRosEnvironmentRequiresRewardFunction", new ExpectedErrorInfo(1, ErrorCodes.REQUIRED_PARAMETER_MISSING)); + } + + @Test + public void testInvalidCheckRLAlgorithmParameter1() { checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckRlAlgorithmParameter()), "invalid_cocos_tests", "CheckRLAlgorithmParameter1", new ExpectedErrorInfo(1, ErrorCodes.UNSUPPORTED_PARAMETER)); + } + + @Test + public void testInvalidCheckRLAlgorithmParameter2() { checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckRlAlgorithmParameter()), "invalid_cocos_tests", "CheckRLAlgorithmParameter2", - new ExpectedErrorInfo(1, ErrorCodes.UNSUPPORTED_PARAMETER)); + new ExpectedErrorInfo(2, ErrorCodes.UNSUPPORTED_PARAMETER)); + } + + @Test + public void testInvalidCheckRLAlgorithmParameter3() { checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckRlAlgorithmParameter()), "invalid_cocos_tests", "CheckRLAlgorithmParameter3", new ExpectedErrorInfo(1, ErrorCodes.UNSUPPORTED_PARAMETER)); + } + + @Test + public void testInvalidCheckDiscreteRLAlgorithmUsesDiscreteStrategy() { checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckDiscreteRLAlgorithmUsesDiscreteStrategy()), "invalid_cocos_tests", "CheckDiscreteRLAlgorithmUsesDiscreteStrategy", new ExpectedErrorInfo(1, ErrorCodes.STRATEGY_NOT_APPLICABLE)); + } + + @Test + public void testInvalidCheckContinuousRLAlgorithmUsesContinuousStrategy() { checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckContinuousRLAlgorithmUsesContinuousStrategy()), "invalid_cocos_tests", "CheckContinuousRLAlgorithmUsesContinuousStrategy", new ExpectedErrorInfo(1, ErrorCodes.STRATEGY_NOT_APPLICABLE)); + } + + @Test + public void testInvalidCheckRosEnvironmentHasOnlyOneRewardSpecification() { checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckRosEnvironmentHasOnlyOneRewardSpecification()), "invalid_cocos_tests", "CheckRosEnvironmentHasOnlyOneRewardSpecification", new ExpectedErrorInfo(1, ErrorCodes.CONTRADICTING_PARAMETERS)); -- GitLab From cdc5657aabe86571464b14e7132ca6ebf113c3d7 Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Tue, 9 Jul 2019 01:02:50 +0200 Subject: [PATCH 05/21] Add TD3 reinforcement learning parameter --- .../de/monticore/lang/monticar/CNNTrain.mc4 | 2 +- .../_cocos/ASTConfigurationUtils.java | 24 ++++++- .../cnntrain/_cocos/CNNTrainCocos.java | 2 +- ...uousRLAlgorithmUsesContinuousStrategy.java | 2 +- .../_cocos/CheckRlAlgorithmParameter.java | 68 +++++++++---------- .../_cocos/ParameterAlgorithmMapping.java | 9 +++ .../CNNTrainSymbolTableCreator.java | 2 + .../cnntrain/_symboltable/RLAlgorithm.java | 8 ++- src/test/resources/valid_tests/TD3Config.cnnt | 2 +- 9 files changed, 78 insertions(+), 41 deletions(-) diff --git a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 index e32e500..f9db79b 100644 --- a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 +++ b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 @@ -145,7 +145,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number LearningMethodValue implements ConfigValue = (supervisedLearning:"supervised" | reinforcement:"reinforcement"); - RLAlgorithmValue implements ConfigValue = (dqn:"dqn-algorithm" | ddpg:"ddpg-algorithm"); + RLAlgorithmValue implements ConfigValue = (dqn:"dqn-algorithm" | ddpg:"ddpg-algorithm" | tdThree:"td3-algorithm"); interface MultiParamConfigEntry extends ConfigEntry; diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ASTConfigurationUtils.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ASTConfigurationUtils.java index a955b73..673d620 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ASTConfigurationUtils.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ASTConfigurationUtils.java @@ -41,8 +41,16 @@ class ASTConfigurationUtils { e -> (e instanceof ASTRLAlgorithmEntry) && ((ASTRLAlgorithmEntry)e).getValue().isPresentDdpg()); } + static boolean isTd3Algorithm(final ASTConfiguration configuration) { + return isReinforcementLearning(configuration) + && configuration.getEntriesList().stream().anyMatch( + e -> (e instanceof ASTRLAlgorithmEntry) && ((ASTRLAlgorithmEntry)e).getValue().isPresentTdThree()); + } + static boolean isDqnAlgorithm(final ASTConfiguration configuration) { - return isReinforcementLearning(configuration) && !isDdpgAlgorithm(configuration); + return isReinforcementLearning(configuration) + && !isDdpgAlgorithm(configuration) + && !isTd3Algorithm(configuration); } static boolean hasEntry(final ASTConfiguration configuration, final Class entryClazz) { @@ -84,4 +92,18 @@ class ASTConfigurationUtils { } return false; } + + static boolean isActorCriticAlgorithm(final ASTConfiguration node) { + return isDdpgAlgorithm(node) || isTd3Algorithm(node); + } + + static boolean hasCriticEntry(final ASTConfiguration node) { + return node.getEntriesList().stream() + .anyMatch(e -> ((e instanceof ASTCriticNetworkEntry) + && !((ASTCriticNetworkEntry)e).getValue().getNameList().isEmpty())); + } + + public static boolean isContinuousAlgorithm(final ASTConfiguration node) { + return isDdpgAlgorithm(node) || isTd3Algorithm(node); + } } diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java index b465042..bd67172 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java @@ -34,7 +34,7 @@ public class CNNTrainCocos { .addCoCo(new CheckReinforcementRequiresEnvironment()) .addCoCo(new CheckLearningParameterCombination()) .addCoCo(new CheckRosEnvironmentRequiresRewardFunction()) - .addCoCo(new CheckDdpgRequiresCriticNetwork()) + .addCoCo(new CheckActorCriticRequiresCriticNetwork()) .addCoCo(new CheckRlAlgorithmParameter()) .addCoCo(new CheckDiscreteRLAlgorithmUsesDiscreteStrategy()) .addCoCo(new CheckContinuousRLAlgorithmUsesContinuousStrategy()) diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckContinuousRLAlgorithmUsesContinuousStrategy.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckContinuousRLAlgorithmUsesContinuousStrategy.java index 16d4f73..008bcb6 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckContinuousRLAlgorithmUsesContinuousStrategy.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckContinuousRLAlgorithmUsesContinuousStrategy.java @@ -35,7 +35,7 @@ public class CheckContinuousRLAlgorithmUsesContinuousStrategy implements CNNTrai @Override public void check(ASTConfiguration node) { - if (ASTConfigurationUtils.isDdpgAlgorithm(node) + if (ASTConfigurationUtils.isContinuousAlgorithm(node) && ASTConfigurationUtils.hasStrategy(node) && ASTConfigurationUtils.getStrategyMethod(node).isPresent()) { final String usedStrategy = ASTConfigurationUtils.getStrategyMethod(node).get(); diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckRlAlgorithmParameter.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckRlAlgorithmParameter.java index 85d0bd1..4b70fe9 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckRlAlgorithmParameter.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckRlAlgorithmParameter.java @@ -20,6 +20,7 @@ */ package de.monticore.lang.monticar.cnntrain._cocos; +import de.monticore.lang.monticar.cnntrain._ast.ASTConfiguration; import de.monticore.lang.monticar.cnntrain._ast.ASTEntry; import de.monticore.lang.monticar.cnntrain._ast.ASTRLAlgorithmEntry; import de.monticore.lang.monticar.cnntrain._symboltable.RLAlgorithm; @@ -29,61 +30,58 @@ import de.se_rwth.commons.logging.Log; public class CheckRlAlgorithmParameter implements CNNTrainASTEntryCoCo { private final ParameterAlgorithmMapping parameterAlgorithmMapping; - boolean algorithmKnown; + private boolean isDqn = true; + private boolean isDdpg = true; + private boolean isTd3 = true; + RLAlgorithm algorithm; public CheckRlAlgorithmParameter() { parameterAlgorithmMapping = new ParameterAlgorithmMapping(); - algorithmKnown = false; - algorithm = null; } @Override public void check(ASTEntry node) { - final boolean isDdpgParameter = parameterAlgorithmMapping.isDdpgParameter(node.getClass()); - final boolean isDqnParameter = parameterAlgorithmMapping.isDqnParameter(node.getClass()); - + if (!parameterAlgorithmMapping.isReinforcementLearningParameter(node.getClass())) { + return; + } if (node instanceof ASTRLAlgorithmEntry) { ASTRLAlgorithmEntry algorithmEntry = (ASTRLAlgorithmEntry)node; if (algorithmEntry.getValue().isPresentDdpg()) { - setAlgorithmToDdpg(node); + logWrongParameterIfCheckFails(isDdpg, node); + isTd3 = false; + isDqn = false; + } else if(algorithmEntry.getValue().isPresentTdThree()) { + logWrongParameterIfCheckFails(isTd3, node); + isDdpg = false; + isDqn = false; } else { - setAlgorithmToDqn(node); + logWrongParameterIfCheckFails(isDqn, node); + isDdpg = false; + isTd3 = false; } } else { - if (isDdpgParameter && !isDqnParameter) { - setAlgorithmToDdpg(node); - } else if (!isDdpgParameter && isDqnParameter) { - setAlgorithmToDqn(node); + final boolean isDdpgParameter = parameterAlgorithmMapping.isDdpgParameter(node.getClass()); + final boolean isDqnParameter = parameterAlgorithmMapping.isDqnParameter(node.getClass()); + final boolean isTd3Parameter = parameterAlgorithmMapping.isTd3Parameter(node.getClass()); + if (!isDdpgParameter) { + isDdpg = false; + } + if (!isTd3Parameter) { + isTd3 = false; + } + if (!isDqnParameter) { + isDqn = false; } } + logWrongParameterIfCheckFails(isDqn || isTd3 || isDdpg, node); } - private void logErrorIfAlgorithmIsDqn(final ASTEntry node) { - if (algorithmKnown && algorithm.equals(RLAlgorithm.DQN)) { - Log.error("0" + ErrorCodes.UNSUPPORTED_PARAMETER - + " DDPG Parameter " + node.getName() + " used but algorithm is " + algorithm + ".", - node.get_SourcePositionStart()); - } - } - - private void setAlgorithmToDdpg(final ASTEntry node) { - logErrorIfAlgorithmIsDqn(node); - algorithmKnown = true; - algorithm = RLAlgorithm.DDPG; - } - - private void setAlgorithmToDqn(final ASTEntry node) { - logErrorIfAlgorithmIsDdpg(node); - algorithmKnown = true; - algorithm = RLAlgorithm.DQN; - } - - private void logErrorIfAlgorithmIsDdpg(final ASTEntry node) { - if (algorithmKnown && algorithm.equals(RLAlgorithm.DDPG)) { + private void logWrongParameterIfCheckFails(final boolean condition, final ASTEntry node) { + if (!condition) { Log.error("0" + ErrorCodes.UNSUPPORTED_PARAMETER - + " DQN Parameter " + node.getName() + " used but algorithm is " + algorithm + ".", + + "Parameter " + node.getName() + " used but parameter is not for chosen algorithm.", node.get_SourcePositionStart()); } } diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java index a77f4bf..b082066 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java @@ -108,6 +108,8 @@ class ParameterAlgorithmMapping { ASTStrategyOUSigma.class ); + private static final List EXCLUSIVE_TD3_PARAMETERS = Lists.newArrayList(EXCLUSIVE_DDPG_PARAMETERS); + ParameterAlgorithmMapping() { } @@ -136,12 +138,19 @@ class ParameterAlgorithmMapping { || EXCLUSIVE_DDPG_PARAMETERS.contains(entryClazz); } + boolean isTd3Parameter(Class entryClazz) { + return GENERAL_PARAMETERS.contains(entryClazz) + || GENERAL_REINFORCEMENT_PARAMETERS.contains(entryClazz) + || EXCLUSIVE_TD3_PARAMETERS.contains(entryClazz); + } + List getAllReinforcementParameters() { return ImmutableList. builder() .addAll(GENERAL_PARAMETERS) .addAll(GENERAL_REINFORCEMENT_PARAMETERS) .addAll(EXCLUSIVE_DQN_PARAMETERS) .addAll(EXCLUSIVE_DDPG_PARAMETERS) + .addAll(EXCLUSIVE_TD3_PARAMETERS) .build(); } diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java index f38c711..05396ab 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java @@ -351,6 +351,8 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { if (node.getValue().isPresentDdpg()) { value.setValue(RLAlgorithm.DDPG); + } else if(node.getValue().isPresentTdThree()) { + value.setValue(RLAlgorithm.TD3); } else { value.setValue(RLAlgorithm.DQN); } diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/RLAlgorithm.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/RLAlgorithm.java index f0e38ff..014649b 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/RLAlgorithm.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/RLAlgorithm.java @@ -32,5 +32,11 @@ public enum RLAlgorithm { public String toString() { return "ddpg"; } + }, + TD3 { + @Override + public String toString() { + return "td3"; + } } -} +} \ No newline at end of file diff --git a/src/test/resources/valid_tests/TD3Config.cnnt b/src/test/resources/valid_tests/TD3Config.cnnt index 6d8bbba..ff428a7 100644 --- a/src/test/resources/valid_tests/TD3Config.cnnt +++ b/src/test/resources/valid_tests/TD3Config.cnnt @@ -1,6 +1,6 @@ configuration TD3Config { learning_method : reinforcement - rl_algorithm : ddpg-algorithm + rl_algorithm : td3-algorithm critic : path.to.component environment : gym { name:"CartPole-v1" } soft_target_update_rate: 0.001 -- GitLab From 73e80dc05802f8999a5b435b593228e3bd29dddb Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Tue, 9 Jul 2019 01:07:19 +0200 Subject: [PATCH 06/21] Add invalid test case for td3 --- .../monticar/cnntrain/cocos/AllCoCoTest.java | 7 +++++++ .../CheckRLAlgorithmParameter4.cnnt | 19 +++++++++++++++++++ 2 files changed, 26 insertions(+) create mode 100644 src/test/resources/invalid_cocos_tests/CheckRLAlgorithmParameter4.cnnt diff --git a/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java b/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java index 85a88cc..059932d 100644 --- a/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java +++ b/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java @@ -169,6 +169,13 @@ public class AllCoCoTest extends AbstractCoCoTest{ new ExpectedErrorInfo(1, ErrorCodes.UNSUPPORTED_PARAMETER)); } + @Test + public void testInvalidCheckRLAlgorithmParameter4() { + checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckRlAlgorithmParameter()), + "invalid_cocos_tests", "CheckRLAlgorithmParameter4", + new ExpectedErrorInfo(1, ErrorCodes.UNSUPPORTED_PARAMETER)); + } + @Test public void testInvalidCheckDiscreteRLAlgorithmUsesDiscreteStrategy() { checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckDiscreteRLAlgorithmUsesDiscreteStrategy()), diff --git a/src/test/resources/invalid_cocos_tests/CheckRLAlgorithmParameter4.cnnt b/src/test/resources/invalid_cocos_tests/CheckRLAlgorithmParameter4.cnnt new file mode 100644 index 0000000..a8b0c27 --- /dev/null +++ b/src/test/resources/invalid_cocos_tests/CheckRLAlgorithmParameter4.cnnt @@ -0,0 +1,19 @@ +configuration CheckRLAlgorithmParameter4 { + learning_method : reinforcement + + rl_algorithm: td3-algorithm + + agent_name : "reinforcement-agent" + + environment : gym { name:"CartPole-v1" } + + context : cpu + + num_episodes : 300 + num_max_steps : 9999 + discount_factor : 0.998 + target_score : 1000 + training_interval : 10 + + use_double_dqn: true +} \ No newline at end of file -- GitLab From 11efa54b966af9535d6eb43900c267d0b91acf32 Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Wed, 10 Jul 2019 00:20:57 +0200 Subject: [PATCH 07/21] Add td3 parameters policy_noise, noise_clip, policy_delay --- .../de/monticore/lang/monticar/CNNTrain.mc4 | 7 +++++- .../_cocos/ParameterAlgorithmMapping.java | 15 ++++++++++-- .../CNNTrainSymbolTableCreator.java | 24 +++++++++++++++++++ src/test/resources/valid_tests/TD3Config.cnnt | 4 ++++ 4 files changed, 47 insertions(+), 3 deletions(-) diff --git a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 index f9db79b..3da8fde 100644 --- a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 +++ b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 @@ -215,8 +215,13 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number UseDoubleDQNEntry implements ConfigEntry = name:"use_double_dqn" ":" value:BooleanValue; - // DDPG exclusive parameters + // DDPG and TD3 exclusive parameters CriticNetworkEntry implements ConfigEntry = name:"critic" ":" value:ComponentNameValue; SoftTargetUpdateRateEntry implements ConfigEntry = name:"soft_target_update_rate" ":" value:NumberValue; CriticOptimizerEntry implements ConfigEntry = name:"critic_optimizer" ":" value:OptimizerValue; + + // TD3 exclusive parameters + PolicyNoiseEntry implements ConfigEntry = name:"policy_noise" ":" value:NumberValue; + NoiseClipEntry implements ConfigEntry = name:"noise_clip" ":" value:NumberValue; + PolicyDelayEntry implements ConfigEntry = name:"policy_delay" ":" value:IntegerValue; } \ No newline at end of file diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java index b082066..9b75d54 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java @@ -108,7 +108,17 @@ class ParameterAlgorithmMapping { ASTStrategyOUSigma.class ); - private static final List EXCLUSIVE_TD3_PARAMETERS = Lists.newArrayList(EXCLUSIVE_DDPG_PARAMETERS); + private static final List EXCLUSIVE_TD3_PARAMETERS = Lists.newArrayList( + ASTCriticNetworkEntry.class, + ASTSoftTargetUpdateRateEntry.class, + ASTCriticOptimizerEntry.class, + ASTStrategyOUMu.class, + ASTStrategyOUTheta.class, + ASTStrategyOUSigma.class, + ASTPolicyNoiseEntry.class, + ASTNoiseClipEntry.class, + ASTPolicyDelayEntry.class + ); ParameterAlgorithmMapping() { @@ -118,7 +128,8 @@ class ParameterAlgorithmMapping { return GENERAL_PARAMETERS.contains(entryClazz) || GENERAL_REINFORCEMENT_PARAMETERS.contains(entryClazz) || EXCLUSIVE_DQN_PARAMETERS.contains(entryClazz) - || EXCLUSIVE_DDPG_PARAMETERS.contains(entryClazz); + || EXCLUSIVE_DDPG_PARAMETERS.contains(entryClazz) + || EXCLUSIVE_TD3_PARAMETERS.contains(entryClazz); } boolean isSupervisedLearningParameter(Class entryClazz) { diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java index 05396ab..495bec5 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java @@ -516,6 +516,30 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { configuration.getEntryMap().put(node.getName(), entry); } + @Override + public void visit(ASTPolicyNoiseEntry node) { + EntrySymbol entry = new EntrySymbol(node.getName()); + entry.setValue(getValueSymbolForDouble(node.getValue())); + addToScopeAndLinkWithNode(entry, node); + configuration.getEntryMap().put(node.getName(), entry); + } + + @Override + public void visit(ASTNoiseClipEntry node) { + EntrySymbol entry = new EntrySymbol(node.getName()); + entry.setValue(getValueSymbolForDouble(node.getValue())); + addToScopeAndLinkWithNode(entry, node); + configuration.getEntryMap().put(node.getName(), entry); + } + + @Override + public void visit(ASTPolicyDelayEntry node) { + EntrySymbol entry = new EntrySymbol(node.getName()); + entry.setValue(getValueSymbolForInteger(node.getValue())); + addToScopeAndLinkWithNode(entry, node); + configuration.getEntryMap().put(node.getName(), entry); + } + private void processMultiParamConfigVisit(ASTMultiParamConfigEntry node, Object value) { EntrySymbol entry = new EntrySymbol(node.getName()); MultiParamValueSymbol valueSymbol = new MultiParamValueSymbol(); diff --git a/src/test/resources/valid_tests/TD3Config.cnnt b/src/test/resources/valid_tests/TD3Config.cnnt index ff428a7..c3dcc31 100644 --- a/src/test/resources/valid_tests/TD3Config.cnnt +++ b/src/test/resources/valid_tests/TD3Config.cnnt @@ -4,6 +4,10 @@ configuration TD3Config { critic : path.to.component environment : gym { name:"CartPole-v1" } soft_target_update_rate: 0.001 + policy_noise: 0.2 + noise_clip: 0.5 + policy_delay: 2 + actor_optimizer : adam{ learning_rate : 0.0001 learning_rate_minimum : 0.00005 -- GitLab From 8a0652b828ff822da4c1492ff45bfb18a1a4ea9b Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Wed, 17 Jul 2019 00:23:48 +0200 Subject: [PATCH 08/21] Add parameter noise variance --- src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 | 2 ++ .../cnntrain/_cocos/ParameterAlgorithmMapping.java | 6 ++++-- .../cnntrain/_symboltable/CNNTrainSymbolTableCreator.java | 8 ++++++++ src/test/resources/valid_tests/TD3Config.cnnt | 1 + 4 files changed, 15 insertions(+), 2 deletions(-) diff --git a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 index 3da8fde..efc7507 100644 --- a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 +++ b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 @@ -179,6 +179,8 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number interface StrategyGaussianEntry extends Entry; StrategyGaussianValue implements StrategyValue = name:"gaussian" ("{" params:StrategyGaussianEntry* "}")?; + StrategyGaussianNoiseVarianceEntry implements StrategyGaussianEntry = name: "noise_variance" ":" value:NumberValue; + StrategyOUMu implements StrategyOrnsteinUhlenbeckEntry = name: "mu" ":" value:DoubleVectorValue; StrategyOUTheta implements StrategyOrnsteinUhlenbeckEntry = name: "theta" ":" value:DoubleVectorValue; StrategyOUSigma implements StrategyOrnsteinUhlenbeckEntry = name: "sigma" ":" value:DoubleVectorValue; diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java index 9b75d54..e05aadb 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java @@ -105,7 +105,8 @@ class ParameterAlgorithmMapping { ASTCriticOptimizerEntry.class, ASTStrategyOUMu.class, ASTStrategyOUTheta.class, - ASTStrategyOUSigma.class + ASTStrategyOUSigma.class, + ASTStrategyGaussianNoiseVarianceEntry.class ); private static final List EXCLUSIVE_TD3_PARAMETERS = Lists.newArrayList( @@ -117,7 +118,8 @@ class ParameterAlgorithmMapping { ASTStrategyOUSigma.class, ASTPolicyNoiseEntry.class, ASTNoiseClipEntry.class, - ASTPolicyDelayEntry.class + ASTPolicyDelayEntry.class, + ASTStrategyGaussianNoiseVarianceEntry.class ); ParameterAlgorithmMapping() { diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java index 495bec5..25bc4b9 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java @@ -532,6 +532,14 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { configuration.getEntryMap().put(node.getName(), entry); } + @Override + public void visit(ASTStrategyGaussianNoiseVarianceEntry node) { + EntrySymbol entry = new EntrySymbol(node.getName()); + entry.setValue(getValueSymbolForDouble(node.getValue())); + addToScopeAndLinkWithNode(entry, node); + configuration.getEntryMap().put(node.getName(), entry); + } + @Override public void visit(ASTPolicyDelayEntry node) { EntrySymbol entry = new EntrySymbol(node.getName()); diff --git a/src/test/resources/valid_tests/TD3Config.cnnt b/src/test/resources/valid_tests/TD3Config.cnnt index c3dcc31..0ed09ce 100644 --- a/src/test/resources/valid_tests/TD3Config.cnnt +++ b/src/test/resources/valid_tests/TD3Config.cnnt @@ -23,6 +23,7 @@ configuration TD3Config { strategy : gaussian { epsilon: 1.0 min_epsilon: 0.001 + noise_variance: 0.5 epsilon_decay_per_step: true epsilon_decay_method: linear epsilon_decay : 0.0001 -- GitLab From 3592294d86d2a749c72bc77ae4076d5dbdb1f679 Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Wed, 17 Jul 2019 23:31:37 +0200 Subject: [PATCH 09/21] Increment version --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index 2e1be20..9493e23 100644 --- a/pom.xml +++ b/pom.xml @@ -30,7 +30,7 @@ de.monticore.lang.monticar cnn-train - 0.3.4-SNAPSHOT + 0.3.5-SNAPSHOT -- GitLab From e69c5e5099ae9db21374b5bd30ffcf5abbc8ee37 Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Thu, 18 Jul 2019 12:58:07 +0200 Subject: [PATCH 10/21] Make neural network architecture a symbol --- .../_symboltable/ConfigurationSymbol.java | 9 ++-- .../_symboltable/NNArchitectureSymbol.java | 41 +++++++++++++++++++ .../NNArchitectureSymbolKind.java} | 23 +++++++---- 3 files changed, 58 insertions(+), 15 deletions(-) create mode 100644 src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/NNArchitectureSymbol.java rename src/main/java/de/monticore/lang/monticar/cnntrain/{annotations/TrainedArchitecture.java => _symboltable/NNArchitectureSymbolKind.java} (62%) diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java index 1a7e882..632a1e8 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java @@ -20,11 +20,8 @@ */ package de.monticore.lang.monticar.cnntrain._symboltable; -import com.google.common.collect.Lists; -import de.monticore.lang.monticar.cnntrain.annotations.TrainedArchitecture; import de.monticore.symboltable.CommonScopeSpanningSymbol; -import javax.swing.text.html.Option; import java.util.*; public class ConfigurationSymbol extends CommonScopeSpanningSymbol { @@ -34,7 +31,7 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol { private OptimizerSymbol criticOptimizer; private LossSymbol loss; private RewardFunctionSymbol rlRewardFunctionSymbol; - private TrainedArchitecture trainedArchitecture; + private NNArchitectureSymbol trainedArchitecture; public static final ConfigurationSymbolKind KIND = new ConfigurationSymbolKind(); @@ -76,11 +73,11 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol { return Optional.ofNullable(this.rlRewardFunctionSymbol); } - public Optional getTrainedArchitecture() { + public Optional getTrainedArchitecture() { return Optional.ofNullable(trainedArchitecture); } - public void setTrainedArchitecture(TrainedArchitecture trainedArchitecture) { + public void setTrainedArchitecture(NNArchitectureSymbol trainedArchitecture) { this.trainedArchitecture = trainedArchitecture; } diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/NNArchitectureSymbol.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/NNArchitectureSymbol.java new file mode 100644 index 0000000..c606b90 --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/NNArchitectureSymbol.java @@ -0,0 +1,41 @@ +/** + * + * ****************************************************************************** + * MontiCAR Modeling Family, www.se-rwth.de + * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, + * All rights reserved. + * + * This project is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 3.0 of the License, or (at your option) any later version. + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this project. If not, see . + * ******************************************************************************* + */ +package de.monticore.lang.monticar.cnntrain._symboltable; + +import de.monticore.lang.monticar.cnntrain.annotations.Range; +import de.monticore.symboltable.SymbolKind; + +import java.util.List; +import java.util.Map; + +public abstract class NNArchitectureSymbol extends de.monticore.symboltable.CommonSymbol { + public static final NNArchitectureSymbolKind KIND = NNArchitectureSymbolKind.INSTANCE; + + public NNArchitectureSymbol(String name) { + super(name, KIND); + } + + abstract public List getInputs(); + abstract public List getOutputs(); + abstract public Map> getDimensions(); + abstract public Map getRanges(); + abstract public Map getTypes(); +} \ No newline at end of file diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/annotations/TrainedArchitecture.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/NNArchitectureSymbolKind.java similarity index 62% rename from src/main/java/de/monticore/lang/monticar/cnntrain/annotations/TrainedArchitecture.java rename to src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/NNArchitectureSymbolKind.java index 0308a50..f06e869 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/annotations/TrainedArchitecture.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/NNArchitectureSymbolKind.java @@ -18,16 +18,21 @@ * License along with this project. If not, see . * ******************************************************************************* */ -package de.monticore.lang.monticar.cnntrain.annotations; +package de.monticore.lang.monticar.cnntrain._symboltable; -import java.util.List; -import java.util.Map; +import de.monticore.symboltable.SymbolKind; -public interface TrainedArchitecture { - public List getInputs(); - public List getOutputs(); - public Map> getDimensions(); - public Map getRanges(); - public Map getTypes(); +public class NNArchitectureSymbolKind implements SymbolKind { + public static final NNArchitectureSymbolKind INSTANCE = new NNArchitectureSymbolKind(); + private static final String NAME = "de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbolKind"; + @Override + public String getName() { + return NAME; + } + + @Override + public boolean isKindOf(SymbolKind kind) { + return NAME.equals(kind.getName()) || SymbolKind.super.isKindOf(kind); + } } \ No newline at end of file -- GitLab From 18344e12960cbf466d674ff2c9531c3d427adff3 Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Thu, 18 Jul 2019 13:06:10 +0200 Subject: [PATCH 11/21] Add critic network to configuration symbol --- .../cnntrain/_symboltable/ConfigurationSymbol.java | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java index 632a1e8..cf4f4d0 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java @@ -32,6 +32,7 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol { private LossSymbol loss; private RewardFunctionSymbol rlRewardFunctionSymbol; private NNArchitectureSymbol trainedArchitecture; + private NNArchitectureSymbol criticNetwork; public static final ConfigurationSymbolKind KIND = new ConfigurationSymbolKind(); @@ -81,6 +82,14 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol { this.trainedArchitecture = trainedArchitecture; } + public Optional getCriticNetwork() { + return Optional.ofNullable(criticNetwork); + } + + public void setCriticNetwork(NNArchitectureSymbol criticNetwork) { + this.criticNetwork = criticNetwork; + } + public Map getEntryMap() { return entryMap; } -- GitLab From adf69ddf9994e3229f785d7dbb4112e7198c7675 Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Fri, 19 Jul 2019 00:55:35 +0200 Subject: [PATCH 12/21] Add convenient methods for configuration --- .../_symboltable/ConfigurationSymbol.java | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java index cf4f4d0..797e77a 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java @@ -102,4 +102,22 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol { return this.entryMap.containsKey("learning_method") ? (LearningMethod)this.entryMap.get("learning_method").getValue().getValue() : LearningMethod.SUPERVISED; } + + public boolean isReinforcementLearningMethod() { + return getLearningMethod().equals(LearningMethod.REINFORCEMENT); + } + + public boolean hasCritic() { + return getEntryMap().containsKey("critic"); + } + + public Optional getCriticName() { + if (!hasCritic()) { + return Optional.empty(); + } + + final Object criticNameValue = getEntry("critic").getValue().getValue(); + assert criticNameValue instanceof String; + return Optional.of((String)criticNameValue); + } } \ No newline at end of file -- GitLab From 37e759d3d3ce34dc8b66e7e48d7da05fbb0b9030 Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Fri, 19 Jul 2019 01:51:51 +0200 Subject: [PATCH 13/21] Add ConfigurationSymbolChecker --- .../CNNTrainConfigurationSymbolChecker.java | 44 +++++++++++++++++++ .../CNNTrainConfigurationSymbolCoCo.java | 30 +++++++++++++ 2 files changed, 74 insertions(+) create mode 100644 src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainConfigurationSymbolChecker.java create mode 100644 src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainConfigurationSymbolCoCo.java diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainConfigurationSymbolChecker.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainConfigurationSymbolChecker.java new file mode 100644 index 0000000..c59d5a0 --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainConfigurationSymbolChecker.java @@ -0,0 +1,44 @@ +/** + * + * ****************************************************************************** + * MontiCAR Modeling Family, www.se-rwth.de + * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, + * All rights reserved. + * + * This project is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 3.0 of the License, or (at your option) any later version. + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this project. If not, see . + * ******************************************************************************* + */ +package de.monticore.lang.monticar.cnntrain._cocos; + +import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol; + +import java.util.ArrayList; +import java.util.List; + +/** + * + */ +public class CNNTrainConfigurationSymbolChecker { + private List cocos = new ArrayList<>(); + + public CNNTrainConfigurationSymbolChecker addCoCo(CNNTrainConfigurationSymbolCoCo coco) { + cocos.add(coco); + return this; + } + + public void checkAll(ConfigurationSymbol configurationSymbol) { + for (CNNTrainConfigurationSymbolCoCo coco : cocos) { + coco.check(configurationSymbol); + } + } +} \ No newline at end of file diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainConfigurationSymbolCoCo.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainConfigurationSymbolCoCo.java new file mode 100644 index 0000000..320f247 --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainConfigurationSymbolCoCo.java @@ -0,0 +1,30 @@ +/** + * + * ****************************************************************************** + * MontiCAR Modeling Family, www.se-rwth.de + * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, + * All rights reserved. + * + * This project is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 3.0 of the License, or (at your option) any later version. + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this project. If not, see . + * ******************************************************************************* + */ +package de.monticore.lang.monticar.cnntrain._cocos; + +import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol; + +/** + * + */ +public interface CNNTrainConfigurationSymbolCoCo { + void check(ConfigurationSymbol configurationSymbol); +} \ No newline at end of file -- GitLab From 7ff494bf8f4a05e7b21b86619e7090631aebdf00 Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Fri, 19 Jul 2019 01:52:25 +0200 Subject: [PATCH 14/21] Add check critic network --- .../cnntrain/_cocos/CNNTrainCocos.java | 8 +++ ...etworkHasExactlyAOneDimensionalOutput.java | 53 +++++++++++++++++++ .../monticar/cnntrain/helper/ErrorCodes.java | 1 + 3 files changed, 62 insertions(+) create mode 100644 src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckCriticNetworkHasExactlyAOneDimensionalOutput.java diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java index bd67172..081587b 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java @@ -22,6 +22,7 @@ package de.monticore.lang.monticar.cnntrain._cocos; import de.monticore.lang.monticar.cnntrain._ast.ASTCNNTrainNode; import de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainCompilationUnitSymbol; +import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol; import de.se_rwth.commons.logging.Log; public class CNNTrainCocos { @@ -46,4 +47,11 @@ public class CNNTrainCocos { int findings = Log.getFindings().size(); createChecker().checkAll(node); } + + public static void checkCriticCocos(final ConfigurationSymbol configurationSymbol) { + CNNTrainConfigurationSymbolChecker checker = new CNNTrainConfigurationSymbolChecker() + .addCoCo(new CheckCriticNetworkHasExactlyAOneDimensionalOutput()); + int findings = Log.getFindings().size(); + checker.checkAll(configurationSymbol); + } } \ No newline at end of file diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckCriticNetworkHasExactlyAOneDimensionalOutput.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckCriticNetworkHasExactlyAOneDimensionalOutput.java new file mode 100644 index 0000000..5cc3fbb --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckCriticNetworkHasExactlyAOneDimensionalOutput.java @@ -0,0 +1,53 @@ +/** + * + * ****************************************************************************** + * MontiCAR Modeling Family, www.se-rwth.de + * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, + * All rights reserved. + * + * This project is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 3.0 of the License, or (at your option) any later version. + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this project. If not, see . + * ******************************************************************************* + */ +package de.monticore.lang.monticar.cnntrain._cocos; + +import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol; +import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol; +import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes; +import de.se_rwth.commons.logging.Log; + +import java.util.List; + +/** + * + */ +public class CheckCriticNetworkHasExactlyAOneDimensionalOutput implements CNNTrainConfigurationSymbolCoCo { + + @Override + public void check(ConfigurationSymbol configurationSymbol) { + if (configurationSymbol.getCriticNetwork().isPresent()) { + NNArchitectureSymbol criticNetwork = configurationSymbol.getCriticNetwork().get(); + + if (criticNetwork.getOutputs().size() > 1) { + Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR + + " The critic network has more than one outputs", criticNetwork.getSourcePosition()); + } + final String outputName = criticNetwork.getOutputs().get(0); + List dimensions = criticNetwork.getDimensions().get(outputName); + + if (dimensions.size() != 1 || dimensions.get(0) != 1) { + Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR + " The output " + outputName + + " of critic network is not a one-dimensional vector", configurationSymbol.getSourcePosition()); + } + } + } +} diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ErrorCodes.java b/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ErrorCodes.java index 6aa9e50..b9850b8 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ErrorCodes.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ErrorCodes.java @@ -31,4 +31,5 @@ public class ErrorCodes { public static final String REQUIRED_PARAMETER_MISSING = "xC8856"; public static final String STRATEGY_NOT_APPLICABLE = "xC8857"; public static final String CONTRADICTING_PARAMETERS = "xC8858"; + public static final String CRITIC_NETWORK_ERROR = "xC7100"; } \ No newline at end of file -- GitLab From 3d15a4deb22ed2b7fe668b6577aab7460891be2e Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Sat, 20 Jul 2019 14:10:47 +0200 Subject: [PATCH 15/21] Add critic network input cocos --- .../cnntrain/_cocos/CNNTrainCocos.java | 3 +- .../_cocos/CheckCriticNetworkInputs.java | 98 +++++++++++++++++++ .../monticar/cnntrain/annotations/Range.java | 25 +++++ .../monticar/cnntrain/helper/ErrorCodes.java | 1 + 4 files changed, 126 insertions(+), 1 deletion(-) create mode 100644 src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckCriticNetworkInputs.java diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java index 081587b..93d5ed9 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java @@ -50,7 +50,8 @@ public class CNNTrainCocos { public static void checkCriticCocos(final ConfigurationSymbol configurationSymbol) { CNNTrainConfigurationSymbolChecker checker = new CNNTrainConfigurationSymbolChecker() - .addCoCo(new CheckCriticNetworkHasExactlyAOneDimensionalOutput()); + .addCoCo(new CheckCriticNetworkHasExactlyAOneDimensionalOutput()) + .addCoCo(new CheckCriticNetworkInputs()); int findings = Log.getFindings().size(); checker.checkAll(configurationSymbol); } diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckCriticNetworkInputs.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckCriticNetworkInputs.java new file mode 100644 index 0000000..6f86a12 --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckCriticNetworkInputs.java @@ -0,0 +1,98 @@ +package de.monticore.lang.monticar.cnntrain._cocos; + +import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol; +import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol; +import de.monticore.lang.monticar.cnntrain.annotations.Range; +import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes; +import de.se_rwth.commons.logging.Log; + +import java.util.List; +import java.util.stream.Collectors; + +/** + * + */ +public class CheckCriticNetworkInputs implements CNNTrainConfigurationSymbolCoCo { + + @Override + public void check(ConfigurationSymbol configurationSymbol) { + if (configurationSymbol.getCriticNetwork().isPresent()) { + if (!configurationSymbol.getTrainedArchitecture().isPresent()) { + Log.error("0" + ErrorCodes.MISSING_TRAINED_ARCHITECTURE + + "No architecture found that is trained by this configuration.", configurationSymbol.getSourcePosition()); + } + NNArchitectureSymbol trainedArchitecture = configurationSymbol.getTrainedArchitecture().get(); + NNArchitectureSymbol criticNetwork = configurationSymbol.getCriticNetwork().get(); + + if (trainedArchitecture.getInputs().size() != 1 || trainedArchitecture.getOutputs().size() != 1) { + Log.error("Malformed trained architecture"); + } + + if (trainedArchitecture.getInputs().size() != 2) { + Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR + + "Number of critic network inputs is wrong. Critic network has two inputs," + + "first needs to be a state input and second needs to be the action input."); + } + + final String stateInput = trainedArchitecture.getInputs().get(0); + final String actionOutput = trainedArchitecture.getOutputs().get(0); + final List stateDimensions = trainedArchitecture.getDimensions().get(stateInput); + final List actionDimensions = trainedArchitecture.getDimensions().get(actionOutput); + final Range stateRange = trainedArchitecture.getRanges().get(stateInput); + final Range actionRange = trainedArchitecture.getRanges().get(actionOutput); + final String stateType = trainedArchitecture.getTypes().get(stateInput); + final String actionType = trainedArchitecture.getTypes().get(actionOutput); + + String criticInput1 = criticNetwork.getInputs().get(0); + String criticInput2 = criticNetwork.getInputs().get(1); + + if (criticNetwork.getDimensions().get(criticInput1).equals(stateDimensions)) { + Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR + + " Declared critic network is not a critic: Dimensions of first input of critic architecture must be" + + " equal to state's dimensions " + + stateDimensions.stream().map(Object::toString).collect(Collectors.joining("{", ",", "}")) + + ".", configurationSymbol.getSourcePosition()); + } + + if (criticNetwork.getDimensions().get(criticInput2).equals(actionDimensions)) { + Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR + + " Declared critic network is not a critic: Dimensions of second input of critic architecture must be" + + " equal to action's dimensions " + + actionDimensions.stream().map(Object::toString).collect(Collectors.joining("{", ",", "}")) + + ".", configurationSymbol.getSourcePosition()); + } + + if (criticNetwork.getRanges().get(criticInput1).equals(stateRange)) { + Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR + + " Declared critic network is not a critic: Ranges of first input of critic architecture must be" + + " equal to state's ranges " + + stateRange.toString() + + ".", configurationSymbol.getSourcePosition()); + } + + if (criticNetwork.getRanges().get(criticInput2).equals(actionRange)) { + Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR + + " Declared critic network is not a critic: Ranges of second input of critic architecture must be" + + " equal to action's ranges " + + actionRange.toString() + + ".", configurationSymbol.getSourcePosition()); + } + + if (criticNetwork.getTypes().get(criticInput1).equals(stateType)) { + Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR + + " Declared critic network is not a critic: Type of first input of critic architecture must be" + + " equal to state's types " + + stateType + + ".", configurationSymbol.getSourcePosition()); + } + + if (criticNetwork.getTypes().get(criticInput2).equals(actionType)) { + Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR + + " Declared critic network is not a critic: Type of second input of critic architecture must be" + + " equal to action's types " + + stateType + + ".", configurationSymbol.getSourcePosition()); + } + } + } +} diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/annotations/Range.java b/src/main/java/de/monticore/lang/monticar/cnntrain/annotations/Range.java index bbbb3fc..06a0582 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/annotations/Range.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/annotations/Range.java @@ -20,6 +20,7 @@ */ package de.monticore.lang.monticar.cnntrain.annotations; +import java.util.Objects; import java.util.Optional; public class Range { @@ -66,4 +67,28 @@ public class Range { public static Range withLowerInfinityLimit(double upperLimit) { return new Range(true, false, null, upperLimit); } + + @Override + public String toString() { + final String lowerLimit = isLowerLimitInfinity() || !getLowerLimit().isPresent() ? "-oo" : getLowerLimit().get().toString(); + final String upperLimit = isUpperLimitInfinity() || !getUpperLimit().isPresent() ? "oo" : getUpperLimit().get().toString(); + + return "[" + lowerLimit + ", " + upperLimit + "]"; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (!(o instanceof Range)) return false; + Range range = (Range) o; + return lowerLimitIsInfinity == range.lowerLimitIsInfinity && + upperLimitIsInfinity == range.upperLimitIsInfinity && + Objects.equals(lowerLimit, range.lowerLimit) && + Objects.equals(upperLimit, range.upperLimit); + } + + @Override + public int hashCode() { + return Objects.hash(lowerLimitIsInfinity, upperLimitIsInfinity, lowerLimit, upperLimit); + } } diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ErrorCodes.java b/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ErrorCodes.java index b9850b8..1e6f9d2 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ErrorCodes.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ErrorCodes.java @@ -32,4 +32,5 @@ public class ErrorCodes { public static final String STRATEGY_NOT_APPLICABLE = "xC8857"; public static final String CONTRADICTING_PARAMETERS = "xC8858"; public static final String CRITIC_NETWORK_ERROR = "xC7100"; + public static final String MISSING_TRAINED_ARCHITECTURE = "xC7101"; } \ No newline at end of file -- GitLab From 46333043abf3d983f047af41a57d86a549d5b6b1 Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Sat, 20 Jul 2019 15:16:12 +0200 Subject: [PATCH 16/21] Checks for trained network inputs and outputs --- .../cnntrain/_cocos/CNNTrainCocos.java | 8 +++++- ...eckTrainedRlNetworkHasExactlyOneInput.java | 27 +++++++++++++++++++ ...ckTrainedRlNetworkHasExactlyOneOutput.java | 27 +++++++++++++++++++ .../monticar/cnntrain/helper/ErrorCodes.java | 1 + 4 files changed, 62 insertions(+), 1 deletion(-) create mode 100644 src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckTrainedRlNetworkHasExactlyOneInput.java create mode 100644 src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckTrainedRlNetworkHasExactlyOneOutput.java diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java index 93d5ed9..1423d34 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java @@ -48,11 +48,17 @@ public class CNNTrainCocos { createChecker().checkAll(node); } + public static void checkTrainedArchitectureCoCos(final ConfigurationSymbol configurationSymbol) { + CNNTrainConfigurationSymbolChecker checker = new CNNTrainConfigurationSymbolChecker() + .addCoCo(new CheckTrainedRlNetworkHasExactlyOneInput()) + .addCoCo(new CheckTrainedRlNetworkHasExactlyOneOutput()); + checker.checkAll(configurationSymbol); + } + public static void checkCriticCocos(final ConfigurationSymbol configurationSymbol) { CNNTrainConfigurationSymbolChecker checker = new CNNTrainConfigurationSymbolChecker() .addCoCo(new CheckCriticNetworkHasExactlyAOneDimensionalOutput()) .addCoCo(new CheckCriticNetworkInputs()); - int findings = Log.getFindings().size(); checker.checkAll(configurationSymbol); } } \ No newline at end of file diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckTrainedRlNetworkHasExactlyOneInput.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckTrainedRlNetworkHasExactlyOneInput.java new file mode 100644 index 0000000..31af51f --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckTrainedRlNetworkHasExactlyOneInput.java @@ -0,0 +1,27 @@ +package de.monticore.lang.monticar.cnntrain._cocos; + +import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol; +import de.monticore.lang.monticar.cnntrain._symboltable.RLAlgorithm; +import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes; +import de.se_rwth.commons.logging.Log; + +/** + * + */ +public class CheckTrainedRlNetworkHasExactlyOneInput implements CNNTrainConfigurationSymbolCoCo { + @Override + public void check(ConfigurationSymbol configurationSymbol) { + if (configurationSymbol.isReinforcementLearningMethod() + && configurationSymbol.getTrainedArchitecture().isPresent()) { + final int numberOfInputs = configurationSymbol.getTrainedArchitecture().get().getInputs().size(); + if (numberOfInputs != 1) { + final String networkName + = configurationSymbol.getEntry("rl_algorithm").getValue().getValue() + .equals(RLAlgorithm.DQN) ? "Q-Network" : "Actor-Network"; + Log.error("x0" + ErrorCodes.TRAINED_ARCHITECTURE_ERROR + + networkName + " " +configurationSymbol.getTrainedArchitecture().get().getName() + +" has " + numberOfInputs + " inputs but 1 is only allowed.", configurationSymbol.getSourcePosition()); + } + } + } +} \ No newline at end of file diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckTrainedRlNetworkHasExactlyOneOutput.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckTrainedRlNetworkHasExactlyOneOutput.java new file mode 100644 index 0000000..47d7cbe --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckTrainedRlNetworkHasExactlyOneOutput.java @@ -0,0 +1,27 @@ +package de.monticore.lang.monticar.cnntrain._cocos; + +import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol; +import de.monticore.lang.monticar.cnntrain._symboltable.RLAlgorithm; +import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes; +import de.se_rwth.commons.logging.Log; + +/** + * + */ +public class CheckTrainedRlNetworkHasExactlyOneOutput implements CNNTrainConfigurationSymbolCoCo { + @Override + public void check(final ConfigurationSymbol configurationSymbol) { + if (configurationSymbol.isReinforcementLearningMethod() + && configurationSymbol.getTrainedArchitecture().isPresent()) { + final int numberOfOutputs = configurationSymbol.getTrainedArchitecture().get().getOutputs().size(); + if (numberOfOutputs != 1) { + final String networkName + = configurationSymbol.getEntry("rl_algorithm").getValue().getValue() + .equals(RLAlgorithm.DQN) ? "Q-Network" : "Actor-Network"; + Log.error("x0" + ErrorCodes.TRAINED_ARCHITECTURE_ERROR + + networkName + " " +configurationSymbol.getTrainedArchitecture().get().getName() + +" has " + numberOfOutputs + " outputs but 1 is only allowed.", configurationSymbol.getSourcePosition()); + } + } + } +} \ No newline at end of file diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ErrorCodes.java b/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ErrorCodes.java index 1e6f9d2..fa7c21b 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ErrorCodes.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ErrorCodes.java @@ -33,4 +33,5 @@ public class ErrorCodes { public static final String CONTRADICTING_PARAMETERS = "xC8858"; public static final String CRITIC_NETWORK_ERROR = "xC7100"; public static final String MISSING_TRAINED_ARCHITECTURE = "xC7101"; + public static final String TRAINED_ARCHITECTURE_ERROR = "xC7102"; } \ No newline at end of file -- GitLab From d3a10207a74ba779ce1b31b02d5a10875a42eb43 Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Sat, 20 Jul 2019 16:29:31 +0200 Subject: [PATCH 17/21] Add constants for parameter entries --- .../_cocos/ASTConfigurationUtils.java | 3 +- ...uousRLAlgorithmUsesContinuousStrategy.java | 7 +++- ...screteRLAlgorithmUsesDiscreteStrategy.java | 3 +- ...CheckFixTargetNetworkRequiresInterval.java | 9 ++--- ...CheckReinforcementRequiresEnvironment.java | 5 +-- ...eckTrainedRlNetworkHasExactlyOneInput.java | 3 +- ...ckTrainedRlNetworkHasExactlyOneOutput.java | 3 +- .../_symboltable/ConfigurationSymbol.java | 10 +++-- .../helper/ConfigEntryNameConstants.java | 40 +++++++++++++++++++ 9 files changed, 65 insertions(+), 18 deletions(-) create mode 100644 src/main/java/de/monticore/lang/monticar/cnntrain/helper/ConfigEntryNameConstants.java diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ASTConfigurationUtils.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ASTConfigurationUtils.java index 673d620..a31e150 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ASTConfigurationUtils.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ASTConfigurationUtils.java @@ -21,6 +21,7 @@ package de.monticore.lang.monticar.cnntrain._cocos; import de.monticore.lang.monticar.cnntrain._ast.*; +import static de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants.*; import java.util.Optional; @@ -77,7 +78,7 @@ class ASTConfigurationUtils { return ASTConfigurationUtils.hasEnvironment(node) && node.getEntriesList().stream() .anyMatch(e -> (e instanceof ASTEnvironmentEntry) - && ((ASTEnvironmentEntry)e).getValue().getName().equals("ros_interface")); + && ((ASTEnvironmentEntry)e).getValue().getName().equals(ENVIRONMENT_ROS)); } static boolean hasRewardTopic(final ASTConfiguration node) { diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckContinuousRLAlgorithmUsesContinuousStrategy.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckContinuousRLAlgorithmUsesContinuousStrategy.java index 008bcb6..03c0979 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckContinuousRLAlgorithmUsesContinuousStrategy.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckContinuousRLAlgorithmUsesContinuousStrategy.java @@ -22,15 +22,18 @@ package de.monticore.lang.monticar.cnntrain._cocos; import com.google.common.collect.ImmutableSet; import de.monticore.lang.monticar.cnntrain._ast.ASTConfiguration; +import de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants; import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes; import de.se_rwth.commons.logging.Log; import java.util.Set; +import static de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants.*; + public class CheckContinuousRLAlgorithmUsesContinuousStrategy implements CNNTrainASTConfigurationCoCo{ private static final Set CONTINUOUS_STRATEGIES = ImmutableSet.builder() - .add("ornstein_uhlenbeck") - .add("gaussian") + .add(STRATEGY_OU) + .add(STRATEGY_GAUSSIAN) .build(); @Override diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckDiscreteRLAlgorithmUsesDiscreteStrategy.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckDiscreteRLAlgorithmUsesDiscreteStrategy.java index 64263bb..edf32c7 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckDiscreteRLAlgorithmUsesDiscreteStrategy.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckDiscreteRLAlgorithmUsesDiscreteStrategy.java @@ -22,6 +22,7 @@ package de.monticore.lang.monticar.cnntrain._cocos; import com.google.common.collect.ImmutableSet; import de.monticore.lang.monticar.cnntrain._ast.ASTConfiguration; +import de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants; import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes; import de.se_rwth.commons.logging.Log; @@ -29,7 +30,7 @@ import java.util.Set; public class CheckDiscreteRLAlgorithmUsesDiscreteStrategy implements CNNTrainASTConfigurationCoCo{ private static final Set DISCRETE_STRATEGIES = ImmutableSet.builder() - .add("epsgreedy") + .add(ConfigEntryNameConstants.STRATEGY_EPSGREEDY) .build(); @Override diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckFixTargetNetworkRequiresInterval.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckFixTargetNetworkRequiresInterval.java index 221c38e..435cf74 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckFixTargetNetworkRequiresInterval.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckFixTargetNetworkRequiresInterval.java @@ -23,6 +23,7 @@ package de.monticore.lang.monticar.cnntrain._cocos; import de.monticore.lang.monticar.cnntrain._ast.*; import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol; import de.monticore.lang.monticar.cnntrain._symboltable.EntrySymbol; +import de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants; import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes; import de.se_rwth.commons.logging.Log; @@ -33,8 +34,6 @@ import java.util.Map; * */ public class CheckFixTargetNetworkRequiresInterval implements CNNTrainASTConfigurationCoCo { - private static final String PARAMETER_USE_FIX_TARGET_NETWORK = "use_fix_target_network"; - private static final String PARAMETER_TARGET_NETWORK_UPDATE_INTERVAL = "target_network_update_interval"; @Override public void check(ASTConfiguration node) { @@ -50,8 +49,8 @@ public class CheckFixTargetNetworkRequiresInterval implements CNNTrainASTConfigu .map(e -> (ASTUseFixTargetNetworkEntry)e) .findFirst() .orElseThrow(() -> new IllegalStateException("ASTUseFixTargetNetwork entry must be available")); - Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING + " Parameter " + Boolean.toString(useFixTargetNetwork) - + " requires parameter " + PARAMETER_TARGET_NETWORK_UPDATE_INTERVAL, + Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING + " Parameter " + ConfigEntryNameConstants.USE_FIX_TARGET_NETWORK + + " requires parameter " + ConfigEntryNameConstants.TARGET_NETWORK_UPDATE_INTERVAL, useFixTargetNetworkEntry.get_SourcePositionStart()); } else if (!useFixTargetNetwork && hasTargetNetworkUpdateInterval) { ASTTargetNetworkUpdateIntervalEntry targetNetworkUpdateIntervalEntry = node.getEntriesList().stream() @@ -62,7 +61,7 @@ public class CheckFixTargetNetworkRequiresInterval implements CNNTrainASTConfigu () -> new IllegalStateException("ASTTargetNetworkUpdateInterval entry must be available")); Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING + " Parameter " + targetNetworkUpdateIntervalEntry.getName() + " requires that parameter " - + PARAMETER_USE_FIX_TARGET_NETWORK + " to be true.", + + ConfigEntryNameConstants.USE_FIX_TARGET_NETWORK + " to be true.", targetNetworkUpdateIntervalEntry.get_SourcePositionStart()); } } diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckReinforcementRequiresEnvironment.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckReinforcementRequiresEnvironment.java index 19672e7..29c5f06 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckReinforcementRequiresEnvironment.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckReinforcementRequiresEnvironment.java @@ -25,6 +25,7 @@ import de.monticore.lang.monticar.cnntrain._ast.ASTEnvironmentEntry; import de.monticore.lang.monticar.cnntrain._ast.ASTLearningMethodEntry; import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol; import de.monticore.lang.monticar.cnntrain._symboltable.LearningMethod; +import de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants; import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes; import de.se_rwth.commons.logging.Log; @@ -32,8 +33,6 @@ import de.se_rwth.commons.logging.Log; * */ public class CheckReinforcementRequiresEnvironment implements CNNTrainASTConfigurationCoCo { - private static final String PARAMETER_ENVIRONMENT = "environment"; - @Override public void check(ASTConfiguration node) { boolean isReinforcementLearning = ASTConfigurationUtils.isReinforcementLearning(node); @@ -41,7 +40,7 @@ public class CheckReinforcementRequiresEnvironment implements CNNTrainASTConfigu if (isReinforcementLearning && !hasEnvironment) { Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING + " The required parameter " - + PARAMETER_ENVIRONMENT + " is missing"); + + ConfigEntryNameConstants.ENVIRONMENT + " is missing"); } } } \ No newline at end of file diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckTrainedRlNetworkHasExactlyOneInput.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckTrainedRlNetworkHasExactlyOneInput.java index 31af51f..7a79b98 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckTrainedRlNetworkHasExactlyOneInput.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckTrainedRlNetworkHasExactlyOneInput.java @@ -2,6 +2,7 @@ package de.monticore.lang.monticar.cnntrain._cocos; import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol; import de.monticore.lang.monticar.cnntrain._symboltable.RLAlgorithm; +import de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants; import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes; import de.se_rwth.commons.logging.Log; @@ -16,7 +17,7 @@ public class CheckTrainedRlNetworkHasExactlyOneInput implements CNNTrainConfigur final int numberOfInputs = configurationSymbol.getTrainedArchitecture().get().getInputs().size(); if (numberOfInputs != 1) { final String networkName - = configurationSymbol.getEntry("rl_algorithm").getValue().getValue() + = configurationSymbol.getEntry(ConfigEntryNameConstants.RL_ALGORITHM).getValue().getValue() .equals(RLAlgorithm.DQN) ? "Q-Network" : "Actor-Network"; Log.error("x0" + ErrorCodes.TRAINED_ARCHITECTURE_ERROR + networkName + " " +configurationSymbol.getTrainedArchitecture().get().getName() diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckTrainedRlNetworkHasExactlyOneOutput.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckTrainedRlNetworkHasExactlyOneOutput.java index 47d7cbe..437a4d1 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckTrainedRlNetworkHasExactlyOneOutput.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckTrainedRlNetworkHasExactlyOneOutput.java @@ -2,6 +2,7 @@ package de.monticore.lang.monticar.cnntrain._cocos; import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol; import de.monticore.lang.monticar.cnntrain._symboltable.RLAlgorithm; +import de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants; import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes; import de.se_rwth.commons.logging.Log; @@ -16,7 +17,7 @@ public class CheckTrainedRlNetworkHasExactlyOneOutput implements CNNTrainConfigu final int numberOfOutputs = configurationSymbol.getTrainedArchitecture().get().getOutputs().size(); if (numberOfOutputs != 1) { final String networkName - = configurationSymbol.getEntry("rl_algorithm").getValue().getValue() + = configurationSymbol.getEntry(ConfigEntryNameConstants.RL_ALGORITHM).getValue().getValue() .equals(RLAlgorithm.DQN) ? "Q-Network" : "Actor-Network"; Log.error("x0" + ErrorCodes.TRAINED_ARCHITECTURE_ERROR + networkName + " " +configurationSymbol.getTrainedArchitecture().get().getName() diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java index 797e77a..6e41d6f 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java @@ -24,6 +24,8 @@ import de.monticore.symboltable.CommonScopeSpanningSymbol; import java.util.*; +import static de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants.*; + public class ConfigurationSymbol extends CommonScopeSpanningSymbol { private Map entryMap = new HashMap<>(); @@ -99,8 +101,8 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol { } public LearningMethod getLearningMethod() { - return this.entryMap.containsKey("learning_method") - ? (LearningMethod)this.entryMap.get("learning_method").getValue().getValue() : LearningMethod.SUPERVISED; + return this.entryMap.containsKey(LEARNING_METHOD) + ? (LearningMethod)this.entryMap.get(LEARNING_METHOD).getValue().getValue() : LearningMethod.SUPERVISED; } public boolean isReinforcementLearningMethod() { @@ -108,7 +110,7 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol { } public boolean hasCritic() { - return getEntryMap().containsKey("critic"); + return getEntryMap().containsKey(CRITIC); } public Optional getCriticName() { @@ -116,7 +118,7 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol { return Optional.empty(); } - final Object criticNameValue = getEntry("critic").getValue().getValue(); + final Object criticNameValue = getEntry(CRITIC).getValue().getValue(); assert criticNameValue instanceof String; return Optional.of((String)criticNameValue); } diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ConfigEntryNameConstants.java b/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ConfigEntryNameConstants.java new file mode 100644 index 0000000..8e8802e --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ConfigEntryNameConstants.java @@ -0,0 +1,40 @@ +package de.monticore.lang.monticar.cnntrain.helper; + +/** + * + */ +public class ConfigEntryNameConstants { + public static final String LEARNING_METHOD = "learning_method"; + public static final String NUM_EPISODES = "num_episodes"; + public static final String DISCOUNT_FACTOR = "discount_factor"; + public static final String NUM_MAX_STEPS = "num_max_steps"; + public static final String TARGET_SCORE = "target_score"; + public static final String TRAINING_INTERVAL = "training_interval"; + public static final String USE_FIX_TARGET_NETWORK = "use_fix_target_network"; + public static final String TARGET_NETWORK_UPDATE_INTERVAL = "target_network_update_interval"; + public static final String SNAPSHOT_INTERVAL = "snapshot_interval"; + public static final String AGENT_NAME = "agent_name"; + public static final String USE_DOUBLE_DQN = "use_double_dqn"; + public static final String LOSS = "loss"; + public static final String RL_ALGORITHM = "rl_algorithm"; + public static final String REPLAY_MEMORY = "replay_memory"; + public static final String ENVIRONMENT = "environment"; + public static final String START_TRAINING_AT = "start_training_at"; + public static final String SOFT_TARGET_UPDATE_RATE = "soft_target_update_rate"; + public static final String EVALUATION_SAMPLES = "evaluation_samples"; + public static final String POLICY_NOISE = "policy_noise"; + public static final String NOISE_CLIP = "noise_clip"; + public static final String POLICY_DELAY = "policy_delay"; + + public static final String ENVIRONMENT_REWARD_TOPIC = "reward_topic"; + public static final String ENVIRONMENT_ROS = "ros_interface"; + public static final String ENVIRONMENT_GYM = "gym"; + + public static final String STRATEGY = "strategy"; + public static final String STRATEGY_OU = "ornstein_uhlenbeck"; + public static final String STRATEGY_GAUSSIAN = "gaussian"; + public static final String STRATEGY_EPSGREEDY = "epsgreedy"; + public static final String STRATEGY_EPSDECAY = "epsdecay"; + + public static final String CRITIC = "critic"; +} -- GitLab From 944c9ed87cabfc2cc5e5b5ea3c3a44847e5b8d74 Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Sat, 20 Jul 2019 16:58:26 +0200 Subject: [PATCH 18/21] Add Ornstein Uhlenbeck dimension check --- .../cnntrain/_cocos/CNNTrainCocos.java | 3 +- .../_cocos/CheckCriticNetworkInputs.java | 20 +++++ ...rameterDimensionEqualsActionDimension.java | 79 +++++++++++++++++++ ...eckTrainedRlNetworkHasExactlyOneInput.java | 20 +++++ ...ckTrainedRlNetworkHasExactlyOneOutput.java | 20 +++++ .../_symboltable/MultiParamValueSymbol.java | 4 + .../helper/ConfigEntryNameConstants.java | 23 ++++++ 7 files changed, 168 insertions(+), 1 deletion(-) create mode 100644 src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckOUParameterDimensionEqualsActionDimension.java diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java index 1423d34..c994809 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java @@ -51,7 +51,8 @@ public class CNNTrainCocos { public static void checkTrainedArchitectureCoCos(final ConfigurationSymbol configurationSymbol) { CNNTrainConfigurationSymbolChecker checker = new CNNTrainConfigurationSymbolChecker() .addCoCo(new CheckTrainedRlNetworkHasExactlyOneInput()) - .addCoCo(new CheckTrainedRlNetworkHasExactlyOneOutput()); + .addCoCo(new CheckTrainedRlNetworkHasExactlyOneOutput()) + .addCoCo(new CheckOUParameterDimensionEqualsActionDimension()); checker.checkAll(configurationSymbol); } diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckCriticNetworkInputs.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckCriticNetworkInputs.java index 6f86a12..7688402 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckCriticNetworkInputs.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckCriticNetworkInputs.java @@ -1,3 +1,23 @@ +/** + * + * ****************************************************************************** + * MontiCAR Modeling Family, www.se-rwth.de + * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, + * All rights reserved. + * + * This project is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 3.0 of the License, or (at your option) any later version. + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this project. If not, see . + * ******************************************************************************* + */ package de.monticore.lang.monticar.cnntrain._cocos; import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol; diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckOUParameterDimensionEqualsActionDimension.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckOUParameterDimensionEqualsActionDimension.java new file mode 100644 index 0000000..c466165 --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckOUParameterDimensionEqualsActionDimension.java @@ -0,0 +1,79 @@ +/** + * + * ****************************************************************************** + * MontiCAR Modeling Family, www.se-rwth.de + * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, + * All rights reserved. + * + * This project is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 3.0 of the License, or (at your option) any later version. + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this project. If not, see . + * ******************************************************************************* + */ +package de.monticore.lang.monticar.cnntrain._cocos; + +import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol; +import de.monticore.lang.monticar.cnntrain._symboltable.MultiParamValueSymbol; +import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol; +import de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants; +import de.se_rwth.commons.logging.Log; + +import java.util.Collection; +import java.util.List; + +import static de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants.*; + +/** + * + */ +public class CheckOUParameterDimensionEqualsActionDimension implements CNNTrainConfigurationSymbolCoCo { + @Override + public void check(final ConfigurationSymbol configurationSymbol) { + if (configurationSymbol.getTrainedArchitecture().isPresent() + && configurationSymbol.isReinforcementLearningMethod() + && configurationSymbol.getEntry(STRATEGY).getValue().getValue().equals(STRATEGY_OU)) { + final MultiParamValueSymbol strategyParameters + = (MultiParamValueSymbol)configurationSymbol.getEntry(STRATEGY).getValue(); + final NNArchitectureSymbol architectureSymbol = configurationSymbol.getTrainedArchitecture().get(); + final String outputNameOfTrainedArchitecture = architectureSymbol.getOutputs().get(0); + final int actionDimension = architectureSymbol.getDimensions().get(outputNameOfTrainedArchitecture).size(); + + if (strategyParameters.hasParameter(STRATEGY_OU_MU)) { + logIfDimensionIsUnequal(configurationSymbol, strategyParameters, outputNameOfTrainedArchitecture, + actionDimension, STRATEGY_OU_MU); + } + + if (strategyParameters.hasParameter(STRATEGY_OU_SIGMA)) { + logIfDimensionIsUnequal(configurationSymbol, strategyParameters, outputNameOfTrainedArchitecture, + actionDimension, STRATEGY_OU_SIGMA); + } + + if (strategyParameters.hasParameter(STRATEGY_OU_THETA)) { + logIfDimensionIsUnequal(configurationSymbol, strategyParameters, outputNameOfTrainedArchitecture, + actionDimension, STRATEGY_OU_THETA); + } + } + } + + private void logIfDimensionIsUnequal(ConfigurationSymbol configurationSymbol, + MultiParamValueSymbol strategyParameters, + String outputNameOfTrainedArchitecture, + int actionDimension, + String ouParameterName) { + final int ouParameterDimension = ((Collection) strategyParameters.getParameter(ouParameterName)).size(); + if (ouParameterDimension != actionDimension) { + Log.error("Vector parameter " + ouParameterName + " of parameter " + STRATEGY_OU + " must have" + + " the same dimensions as the action dimension of output " + + outputNameOfTrainedArchitecture + " which is " + actionDimension, + configurationSymbol.getSourcePosition()); + } + } +} diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckTrainedRlNetworkHasExactlyOneInput.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckTrainedRlNetworkHasExactlyOneInput.java index 7a79b98..a0e5726 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckTrainedRlNetworkHasExactlyOneInput.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckTrainedRlNetworkHasExactlyOneInput.java @@ -1,3 +1,23 @@ +/** + * + * ****************************************************************************** + * MontiCAR Modeling Family, www.se-rwth.de + * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, + * All rights reserved. + * + * This project is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 3.0 of the License, or (at your option) any later version. + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this project. If not, see . + * ******************************************************************************* + */ package de.monticore.lang.monticar.cnntrain._cocos; import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol; diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckTrainedRlNetworkHasExactlyOneOutput.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckTrainedRlNetworkHasExactlyOneOutput.java index 437a4d1..a0d7487 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckTrainedRlNetworkHasExactlyOneOutput.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckTrainedRlNetworkHasExactlyOneOutput.java @@ -1,3 +1,23 @@ +/** + * + * ****************************************************************************** + * MontiCAR Modeling Family, www.se-rwth.de + * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, + * All rights reserved. + * + * This project is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 3.0 of the License, or (at your option) any later version. + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this project. If not, see . + * ******************************************************************************* + */ package de.monticore.lang.monticar.cnntrain._cocos; import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol; diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/MultiParamValueSymbol.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/MultiParamValueSymbol.java index bff39c9..5f152d8 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/MultiParamValueSymbol.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/MultiParamValueSymbol.java @@ -44,6 +44,10 @@ public class MultiParamValueSymbol extends ValueSymbol { return parameters.get(parameterName); } + public boolean hasParameter(final String parameterName) { + return parameters.containsKey(parameterName); + } + public void addParameter(final String parameterName, final Object value) { parameters.put(parameterName, value); } diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ConfigEntryNameConstants.java b/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ConfigEntryNameConstants.java index 8e8802e..58c2274 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ConfigEntryNameConstants.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ConfigEntryNameConstants.java @@ -1,3 +1,23 @@ +/** + * + * ****************************************************************************** + * MontiCAR Modeling Family, www.se-rwth.de + * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, + * All rights reserved. + * + * This project is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 3.0 of the License, or (at your option) any later version. + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this project. If not, see . + * ******************************************************************************* + */ package de.monticore.lang.monticar.cnntrain.helper; /** @@ -32,6 +52,9 @@ public class ConfigEntryNameConstants { public static final String STRATEGY = "strategy"; public static final String STRATEGY_OU = "ornstein_uhlenbeck"; + public static final String STRATEGY_OU_MU = "mu"; + public static final String STRATEGY_OU_THETA = "theta"; + public static final String STRATEGY_OU_SIGMA = "sigma"; public static final String STRATEGY_GAUSSIAN = "gaussian"; public static final String STRATEGY_EPSGREEDY = "epsgreedy"; public static final String STRATEGY_EPSDECAY = "epsdecay"; -- GitLab From 53fe8e9c65e0cfbe89c2e06d39e8fefe857e0d9d Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Tue, 23 Jul 2019 17:57:27 +0200 Subject: [PATCH 19/21] Fix cocos for critic inputs --- .../_cocos/CheckCriticNetworkInputs.java | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckCriticNetworkInputs.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckCriticNetworkInputs.java index 7688402..e2d745d 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckCriticNetworkInputs.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckCriticNetworkInputs.java @@ -48,12 +48,6 @@ public class CheckCriticNetworkInputs implements CNNTrainConfigurationSymbolCoCo Log.error("Malformed trained architecture"); } - if (trainedArchitecture.getInputs().size() != 2) { - Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR - + "Number of critic network inputs is wrong. Critic network has two inputs," + - "first needs to be a state input and second needs to be the action input."); - } - final String stateInput = trainedArchitecture.getInputs().get(0); final String actionOutput = trainedArchitecture.getOutputs().get(0); final List stateDimensions = trainedArchitecture.getDimensions().get(stateInput); @@ -66,23 +60,29 @@ public class CheckCriticNetworkInputs implements CNNTrainConfigurationSymbolCoCo String criticInput1 = criticNetwork.getInputs().get(0); String criticInput2 = criticNetwork.getInputs().get(1); - if (criticNetwork.getDimensions().get(criticInput1).equals(stateDimensions)) { + if (criticNetwork.getInputs().size() != 2) { + Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR + + "Number of critic network inputs is wrong. Critic network has two inputs," + + "first needs to be a state input and second needs to be the action input."); + } + + if (!criticNetwork.getDimensions().get(criticInput1).equals(stateDimensions)) { Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR + " Declared critic network is not a critic: Dimensions of first input of critic architecture must be" + " equal to state's dimensions " - + stateDimensions.stream().map(Object::toString).collect(Collectors.joining("{", ",", "}")) + + stateDimensions.stream().map(Object::toString).collect(Collectors.joining(",", "{", "}")) + ".", configurationSymbol.getSourcePosition()); } - if (criticNetwork.getDimensions().get(criticInput2).equals(actionDimensions)) { + if (!criticNetwork.getDimensions().get(criticInput2).equals(actionDimensions)) { Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR + " Declared critic network is not a critic: Dimensions of second input of critic architecture must be" + " equal to action's dimensions " - + actionDimensions.stream().map(Object::toString).collect(Collectors.joining("{", ",", "}")) + + actionDimensions.stream().map(Object::toString).collect(Collectors.joining(",", "{", "}")) + ".", configurationSymbol.getSourcePosition()); } - if (criticNetwork.getRanges().get(criticInput1).equals(stateRange)) { + if (!criticNetwork.getRanges().get(criticInput1).equals(stateRange)) { Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR + " Declared critic network is not a critic: Ranges of first input of critic architecture must be" + " equal to state's ranges " @@ -90,7 +90,7 @@ public class CheckCriticNetworkInputs implements CNNTrainConfigurationSymbolCoCo + ".", configurationSymbol.getSourcePosition()); } - if (criticNetwork.getRanges().get(criticInput2).equals(actionRange)) { + if (!criticNetwork.getRanges().get(criticInput2).equals(actionRange)) { Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR + " Declared critic network is not a critic: Ranges of second input of critic architecture must be" + " equal to action's ranges " @@ -98,7 +98,7 @@ public class CheckCriticNetworkInputs implements CNNTrainConfigurationSymbolCoCo + ".", configurationSymbol.getSourcePosition()); } - if (criticNetwork.getTypes().get(criticInput1).equals(stateType)) { + if (!criticNetwork.getTypes().get(criticInput1).equals(stateType)) { Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR + " Declared critic network is not a critic: Type of first input of critic architecture must be" + " equal to state's types " @@ -106,7 +106,7 @@ public class CheckCriticNetworkInputs implements CNNTrainConfigurationSymbolCoCo + ".", configurationSymbol.getSourcePosition()); } - if (criticNetwork.getTypes().get(criticInput2).equals(actionType)) { + if (!criticNetwork.getTypes().get(criticInput2).equals(actionType)) { Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR + " Declared critic network is not a critic: Type of second input of critic architecture must be" + " equal to action's types " -- GitLab From 5fc2a4587bfc52f580d06222042a476abc5eca95 Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Tue, 23 Jul 2019 18:09:08 +0200 Subject: [PATCH 20/21] Increase version --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index 9493e23..3797a1f 100644 --- a/pom.xml +++ b/pom.xml @@ -30,7 +30,7 @@ de.monticore.lang.monticar cnn-train - 0.3.5-SNAPSHOT + 0.3.6-SNAPSHOT -- GitLab From e19e3385152e0d44db9cd13f93428bb407f56cce Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Mon, 5 Aug 2019 11:27:47 +0200 Subject: [PATCH 21/21] Adapt readme file --- README.md | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 91102da..c8c4b38 100644 --- a/README.md +++ b/README.md @@ -117,7 +117,7 @@ configuration ReinforcementConfig { | Parameter | Value | Default | Required | Algorithm | Description | |------------|--------|---------|----------|-----------|-------------| |learning_method| reinforcement,supervised | supervised | No | All | Determines that this CNNTrain configuration is a reinforcement or supervised learning configuration | -| rl_algorithm | ddpg-algorithm, dqn-algorithm | dqn-algorithm | No | All | Determines the RL algorithm that is used to train the agent +| rl_algorithm | ddpg-algorithm, dqn-algorithm, td3-algorithm | dqn-algorithm | No | All | Determines the RL algorithm that is used to train the agent | agent_name | String | "agent" | No | All | Names the agent (e.g. for logging output) | |environment | gym, ros_interface | Yes | / | All | If *ros_interface* is selected, then the agent and the environment communicates via [ROS](http://www.ros.org/). The gym environment comes with a set of environments which are listed [here](https://gym.openai.com/) | | context | cpu, gpu | cpu | No | All | Determines whether the GPU is used during training or the CPU | @@ -133,12 +133,15 @@ configuration ReinforcementConfig { | replay_memory | buffer, online, combined | buffer | No | All | Determines the behaviour of the replay memory | | strategy | epsgreedy, ornstein_uhlenbeck | epsgreedy (discrete), ornstein_uhlenbeck (continuous) | No | All | Determines the action selection policy during the training | | reward_function | Full name of an EMAM component | / | Yes, if *ros_interface* is selected as the environment and no reward topic is given | All | The EMAM component that is used to calculate the reward. It must have two inputs, one for the current state and one boolean input that determines if the current state is terminal. It must also have exactly one output which represents the reward. | -critic | Full name of architecture definition | / | Yes, if DDPG is selected | DDPG | The architecture definition which specifies the architecture of the critic network | -soft_target_update_rate | Float | 0.001 | No | DDPG | Determines the update rate of the critic and actor target network | -actor_optimizer | See supervised learning | adam with LR .0001 | No | DDPG | Determines the optimizer parameters of the actor network | -critic_optimizer | See supervised learning | adam with LR .001 | No | DDPG | Determines the optimizer parameters of the critic network | +critic | Full name of architecture definition | / | Yes, if DDPG or TD3 is selected | DDPG, TD3 | The architecture definition which specifies the architecture of the critic network | +soft_target_update_rate | Float | 0.001 | No | DDPG, TD3 | Determines the update rate of the critic and actor target network | +actor_optimizer | See supervised learning | adam with LR .0001 | No | DDPG, TD3 | Determines the optimizer parameters of the actor network | +critic_optimizer | See supervised learning | adam with LR .001 | No | DDPG, TD3 | Determines the optimizer parameters of the critic network | | start_training_at | Integer | 0 | No | All | Determines at which episode the training starts | | evaluation_samples | Integer | 100 | No | All | Determines how many epsiodes are run when evaluating the network | +| policy_noise | Float | 0.1 | No | TD3 | Determines the standard deviation of the noise that is added to the actions predicted by the target actor network when calculating the targets. +| noise_clip | Float | 0.5 | No | TD3 | Sets the upper and lower limit of the policy noise +policy_delay | Integer | 2 | No | TD3 | Every policy_delay of steps, the actor network and targets are updated. #### Environment @@ -189,6 +192,7 @@ This strategy is only available for discrete problems. It selects an action base - **epsilon_decay_start**: Number of Episodes after the decay of epsilon starts - **epsilon_decay**: The actual decay of epsilon after each step. - **min_epsilon**: After *min_epsilon* is reached, epsilon is not decreased further. +- **epsilon_decay_per_step**:Expects either true or false. If true, the decay will be performed for each step the agent executes instead of performing the decay after each episode. The default value is false #### Option: ornstein_uhlenbeck @@ -209,6 +213,9 @@ Example: Given an actor network with action output of shape (3,), we can write to specify the parameters for each place. +### Option: gaussian +This strategy is also only available for continuous problems. If this strat- egy is selected, uncorrelated Gaussian noise with zero mean is added to the current policy action selection. This strategy provides the same parameters as the epsgreedy option and the parameter **noise_variance** that determines the variance of the noise. + ## Generation To execute generation in your project, use the following code to generate a separate Config file: -- GitLab