From 8787d934c8dcb1826552e017b23e917b8eb482e7 Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Wed, 15 May 2019 22:24:42 +0200 Subject: [PATCH 1/7] Update version# --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index 1e3a1d1..037c85e 100644 --- a/pom.xml +++ b/pom.xml @@ -30,7 +30,7 @@ de.monticore.lang.monticar cnn-train - 0.3.0-SNAPSHOT + 0.3.1-SNAPSHOT -- GitLab From ebf2c614bb798478b665ad82ef0a470ccaf9dc4a Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Wed, 15 May 2019 22:36:01 +0200 Subject: [PATCH 2/7] Add rl learning algorithm parameter --- .../de/monticore/lang/monticar/CNNTrain.mc4 | 3 ++ .../_cocos/ASTConfigurationUtils.java | 20 +++++++++++ .../CheckLearningParameterCombination.java | 1 + .../CNNTrainSymbolTableCreator.java | 16 +++++++++ .../cnntrain/_symboltable/RLAlgorithm.java | 36 +++++++++++++++++++ .../monticar/cnntrain/cocos/AllCoCoTest.java | 1 + .../resources/valid_tests/DdpgConfig.cnnt | 5 +++ 7 files changed, 82 insertions(+) create mode 100644 src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/RLAlgorithm.java create mode 100644 src/test/resources/valid_tests/DdpgConfig.cnnt diff --git a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 index 4e81ce1..78aa625 100644 --- a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 +++ b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 @@ -98,6 +98,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number interface MultiParamValue extends ConfigValue; LearningMethodEntry implements ConfigEntry = name:"learning_method" ":" value:LearningMethodValue; + RLAlgorithmEntry implements ConfigEntry = name:"rl-algorithm" ":" value:RLAlgorithmValue; NumEpisodesEntry implements ConfigEntry = name:"num_episodes" ":" value:IntegerValue; DiscountFactorEntry implements ConfigEntry = name:"discount_factor" ":" value:NumberValue; NumMaxStepsEntry implements ConfigEntry = name:"num_max_steps" ":" value:IntegerValue; @@ -114,6 +115,8 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number LearningMethodValue implements ConfigValue = (supervisedLearning:"supervised" | reinforcement:"reinforcement"); + RLAlgorithmValue implements ConfigValue = (dqn:"dqn" | ddpg:"ddpg"); + interface MultiParamConfigEntry extends ConfigEntry; // Replay Memory diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ASTConfigurationUtils.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ASTConfigurationUtils.java index 8cf3de8..a0d4ed9 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ASTConfigurationUtils.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ASTConfigurationUtils.java @@ -1,3 +1,23 @@ +/** + * + * ****************************************************************************** + * MontiCAR Modeling Family, www.se-rwth.de + * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, + * All rights reserved. + * + * This project is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 3.0 of the License, or (at your option) any later version. + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this project. If not, see . + * ******************************************************************************* + */ package de.monticore.lang.monticar.cnntrain._cocos; import de.monticore.lang.monticar.cnntrain._ast.ASTConfiguration; diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckLearningParameterCombination.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckLearningParameterCombination.java index 0dd6d8a..50cba34 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckLearningParameterCombination.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckLearningParameterCombination.java @@ -61,6 +61,7 @@ public class CheckLearningParameterCombination implements CNNTrainASTEntryCoCo { ); private final static List ALLOWED_REINFORCEMENT_LEARNING = Lists.newArrayList( ASTTrainContextEntry.class, + ASTRLAlgorithmEntry.class, ASTOptimizerEntry.class, ASTRewardFunctionEntry.class, ASTMinimumLearningRateEntry.class, diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java index a72c5fc..e370910 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java @@ -310,6 +310,22 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { configuration.getEntryMap().put(node.getName(), entry); } + @Override + public void visit(ASTRLAlgorithmEntry node) { + EntrySymbol entry = new EntrySymbol(node.getName()); + ValueSymbol value = new ValueSymbol(); + + if (node.getValue().isPresentDdpg()) { + value.setValue(RLAlgorithm.DDPG); + } else { + value.setValue(RLAlgorithm.DQN); + } + + entry.setValue(value); + addToScopeAndLinkWithNode(entry, node); + configuration.getEntryMap().put(node.getName(), entry); + } + @Override public void visit(ASTNumEpisodesEntry node) { EntrySymbol entry = new EntrySymbol(node.getName()); diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/RLAlgorithm.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/RLAlgorithm.java new file mode 100644 index 0000000..f0e38ff --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/RLAlgorithm.java @@ -0,0 +1,36 @@ +/** + * + * ****************************************************************************** + * MontiCAR Modeling Family, www.se-rwth.de + * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, + * All rights reserved. + * + * This project is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 3.0 of the License, or (at your option) any later version. + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this project. If not, see . + * ******************************************************************************* + */ +package de.monticore.lang.monticar.cnntrain._symboltable; + +public enum RLAlgorithm { + DQN { + @Override + public String toString() { + return "dqn"; + } + }, + DDPG { + @Override + public String toString() { + return "ddpg"; + } + } +} diff --git a/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java b/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java index ad6ee8d..57cf32f 100644 --- a/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java +++ b/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java @@ -41,6 +41,7 @@ public class AllCoCoTest extends AbstractCoCoTest{ checkValid("valid_tests","FullConfig2"); checkValid("valid_tests", "ReinforcementConfig"); checkValid("valid_tests", "ReinforcementConfig2"); + checkValid("valid_tests", "DdpgConfig"); } @Test diff --git a/src/test/resources/valid_tests/DdpgConfig.cnnt b/src/test/resources/valid_tests/DdpgConfig.cnnt new file mode 100644 index 0000000..984ff8d --- /dev/null +++ b/src/test/resources/valid_tests/DdpgConfig.cnnt @@ -0,0 +1,5 @@ +configuration DdpgConfig { + learning_method : reinforcement + rl-algorithm : ddpg + environment : gym { name:"CartPole-v1" } +} \ No newline at end of file -- GitLab From c2dd5ea10aeee6dafe16e8055a7d6a6163cb267a Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Thu, 16 May 2019 00:21:06 +0200 Subject: [PATCH 3/7] Add critic parameter --- .../de/monticore/lang/monticar/CNNTrain.mc4 | 3 +- .../cnntrain/_cocos/CNNTrainCocos.java | 3 +- .../CheckDdpgRequiresCriticNetwork.java | 50 +++++++++++++++++++ .../CheckLearningParameterCombination.java | 1 + .../CNNTrainSymbolTableCreator.java | 15 ++++++ .../resources/valid_tests/DdpgConfig.cnnt | 3 +- 6 files changed, 72 insertions(+), 3 deletions(-) create mode 100644 src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckDdpgRequiresCriticNetwork.java diff --git a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 index 78aa625..80a8b9e 100644 --- a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 +++ b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 @@ -98,7 +98,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number interface MultiParamValue extends ConfigValue; LearningMethodEntry implements ConfigEntry = name:"learning_method" ":" value:LearningMethodValue; - RLAlgorithmEntry implements ConfigEntry = name:"rl-algorithm" ":" value:RLAlgorithmValue; + RLAlgorithmEntry implements ConfigEntry = name:"rl_algorithm" ":" value:RLAlgorithmValue; NumEpisodesEntry implements ConfigEntry = name:"num_episodes" ":" value:IntegerValue; DiscountFactorEntry implements ConfigEntry = name:"discount_factor" ":" value:NumberValue; NumMaxStepsEntry implements ConfigEntry = name:"num_max_steps" ":" value:IntegerValue; @@ -110,6 +110,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number AgentNameEntry implements ConfigEntry = name:"agent_name" ":" value:StringValue; UseDoubleDQNEntry implements ConfigEntry = name:"use_double_dqn" ":" value:BooleanValue; RewardFunctionEntry implements ConfigEntry = name:"reward_function" ":" value:ComponentNameValue; + CriticNetworkEntry implements ConfigEntry = name:"critic" ":" value:ComponentNameValue; ComponentNameValue implements ConfigValue = Name ("."Name)*; diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java index ac2a8fd..3d04980 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java @@ -33,7 +33,8 @@ public class CNNTrainCocos { .addCoCo(new CheckFixTargetNetworkRequiresInterval()) .addCoCo(new CheckReinforcementRequiresEnvironment()) .addCoCo(new CheckLearningParameterCombination()) - .addCoCo(new CheckRosEnvironmentRequiresRewardFunction()); + .addCoCo(new CheckRosEnvironmentRequiresRewardFunction()) + .addCoCo(new CheckDdpgRequiresCriticNetwork()); } public static void checkAll(CNNTrainCompilationUnitSymbol compilationUnit){ diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckDdpgRequiresCriticNetwork.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckDdpgRequiresCriticNetwork.java new file mode 100644 index 0000000..6ca5412 --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckDdpgRequiresCriticNetwork.java @@ -0,0 +1,50 @@ +/** + * + * ****************************************************************************** + * MontiCAR Modeling Family, www.se-rwth.de + * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, + * All rights reserved. + * + * This project is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 3.0 of the License, or (at your option) any later version. + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this project. If not, see . + * ******************************************************************************* + */ +package de.monticore.lang.monticar.cnntrain._cocos; + +import de.monticore.lang.monticar.cnntrain._ast.ASTConfiguration; +import de.monticore.lang.monticar.cnntrain._ast.ASTCriticNetworkEntry; +import de.monticore.lang.monticar.cnntrain._ast.ASTRLAlgorithmEntry; +import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes; +import de.se_rwth.commons.logging.Log; + +public class CheckDdpgRequiresCriticNetwork implements CNNTrainASTConfigurationCoCo { + + @Override + public void check(ASTConfiguration node) { + boolean isDdpg = node.getEntriesList().stream() + .anyMatch(e -> e instanceof ASTRLAlgorithmEntry + && ((ASTRLAlgorithmEntry)e).getValue().isPresentDdpg()); + boolean hasCriticEntry = node.getEntriesList().stream() + .anyMatch(e -> ((e instanceof ASTCriticNetworkEntry) + && !((ASTCriticNetworkEntry)e).getValue().getNameList().isEmpty())); + + if (isDdpg && !hasCriticEntry) { + ASTRLAlgorithmEntry algorithmEntry = node.getEntriesList().stream() + .filter(e -> e instanceof ASTRLAlgorithmEntry) + .map(e -> (ASTRLAlgorithmEntry)e) + .findFirst() + .orElseThrow(() -> new IllegalStateException("ASTRLAlgorithmEntry entry must be available")); + Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING + " DDPG learning algorithm requires critc" + + " network entry", algorithmEntry.get_SourcePositionStart()); + } + } +} \ No newline at end of file diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckLearningParameterCombination.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckLearningParameterCombination.java index 50cba34..90ea12a 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckLearningParameterCombination.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckLearningParameterCombination.java @@ -62,6 +62,7 @@ public class CheckLearningParameterCombination implements CNNTrainASTEntryCoCo { private final static List ALLOWED_REINFORCEMENT_LEARNING = Lists.newArrayList( ASTTrainContextEntry.class, ASTRLAlgorithmEntry.class, + ASTCriticNetworkEntry.class, ASTOptimizerEntry.class, ASTRewardFunctionEntry.class, ASTMinimumLearningRateEntry.class, diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java index e370910..fbfe065 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java @@ -277,6 +277,13 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { return value; } + private ValueSymbol getValueSymbolForComponentName(ASTComponentNameValue astComponentNameValue) { + ValueSymbol value = new ValueSymbol(); + List valueAsList = astComponentNameValue.getNameList(); + value.setValue(valueAsList); + return value; + } + private String getStringFromStringValue(ASTStringValue value) { return value.getStringLiteral().getValue(); } @@ -406,6 +413,14 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { configuration.getEntryMap().put(node.getName(), entry); } + @Override + public void visit(ASTCriticNetworkEntry node) { + EntrySymbol entry = new EntrySymbol(node.getName()); + entry.setValue(getValueSymbolForComponentName(node.getValue())); + addToScopeAndLinkWithNode(entry, node); + configuration.getEntryMap().put(node.getName(), entry); + } + @Override public void visit(ASTReplayMemoryEntry node) { processMultiParamConfigVisit(node, node.getValue().getName()); diff --git a/src/test/resources/valid_tests/DdpgConfig.cnnt b/src/test/resources/valid_tests/DdpgConfig.cnnt index 984ff8d..45e67f8 100644 --- a/src/test/resources/valid_tests/DdpgConfig.cnnt +++ b/src/test/resources/valid_tests/DdpgConfig.cnnt @@ -1,5 +1,6 @@ configuration DdpgConfig { learning_method : reinforcement - rl-algorithm : ddpg + rl_algorithm : ddpg + critic : path.to.component environment : gym { name:"CartPole-v1" } } \ No newline at end of file -- GitLab From 0b4642f207bd6d770cb766014e8fe47ad0794ecf Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Thu, 16 May 2019 15:12:08 +0200 Subject: [PATCH 4/7] Add architecture annotation --- .../de/monticore/lang/monticar/CNNTrain.mc4 | 2 +- .../_symboltable/ConfigurationSymbol.java | 13 ++++++- .../monticar/cnntrain/annotations/Range.java | 39 +++++++++++++++++++ .../annotations/TrainedArchitecture.java | 32 +++++++++++++++ .../resources/valid_tests/DdpgConfig.cnnt | 2 +- 5 files changed, 85 insertions(+), 3 deletions(-) create mode 100644 src/main/java/de/monticore/lang/monticar/cnntrain/annotations/Range.java create mode 100644 src/main/java/de/monticore/lang/monticar/cnntrain/annotations/TrainedArchitecture.java diff --git a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 index 80a8b9e..7409708 100644 --- a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 +++ b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 @@ -116,7 +116,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number LearningMethodValue implements ConfigValue = (supervisedLearning:"supervised" | reinforcement:"reinforcement"); - RLAlgorithmValue implements ConfigValue = (dqn:"dqn" | ddpg:"ddpg"); + RLAlgorithmValue implements ConfigValue = (dqn:"dqn-algorithm" | ddpg:"ddpg-algorithm"); interface MultiParamConfigEntry extends ConfigEntry; diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java index 76ccf68..f137c47 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java @@ -21,6 +21,7 @@ package de.monticore.lang.monticar.cnntrain._symboltable; import com.google.common.collect.Lists; +import de.monticore.lang.monticar.cnntrain.annotations.TrainedArchitecture; import de.monticore.symboltable.CommonScopeSpanningSymbol; import java.util.*; @@ -30,12 +31,14 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol { private Map entryMap = new HashMap<>(); private OptimizerSymbol optimizer; private RewardFunctionSymbol rlRewardFunctionSymbol; + private TrainedArchitecture trainedArchitecture; public static final ConfigurationSymbolKind KIND = new ConfigurationSymbolKind(); public ConfigurationSymbol() { super("", KIND); rlRewardFunctionSymbol = null; + trainedArchitecture = null; } public OptimizerSymbol getOptimizer() { @@ -54,6 +57,14 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol { return Optional.ofNullable(this.rlRewardFunctionSymbol); } + public Optional getTrainedArchitecture() { + return Optional.ofNullable(trainedArchitecture); + } + + public void setTrainedArchitecture(TrainedArchitecture trainedArchitecture) { + this.trainedArchitecture = trainedArchitecture; + } + public Map getEntryMap() { return entryMap; } @@ -66,4 +77,4 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol { return this.entryMap.containsKey("learning_method") ? (LearningMethod)this.entryMap.get("learning_method").getValue().getValue() : LearningMethod.SUPERVISED; } -} +} \ No newline at end of file diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/annotations/Range.java b/src/main/java/de/monticore/lang/monticar/cnntrain/annotations/Range.java new file mode 100644 index 0000000..a067a44 --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/annotations/Range.java @@ -0,0 +1,39 @@ +/** + * + * ****************************************************************************** + * MontiCAR Modeling Family, www.se-rwth.de + * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, + * All rights reserved. + * + * This project is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 3.0 of the License, or (at your option) any later version. + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this project. If not, see . + * ******************************************************************************* + */ +package de.monticore.lang.monticar.cnntrain.annotations; + +public class Range { + private final double lowerLimit; + private final double upperLimit; + + public Range(final double lowerLimit, final double upperLimit) { + this.lowerLimit = lowerLimit; + this.upperLimit = upperLimit; + } + + public double getLowerLimit() { + return lowerLimit; + } + + public double getUpperLimit() { + return upperLimit; + } +} diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/annotations/TrainedArchitecture.java b/src/main/java/de/monticore/lang/monticar/cnntrain/annotations/TrainedArchitecture.java new file mode 100644 index 0000000..45231d5 --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/annotations/TrainedArchitecture.java @@ -0,0 +1,32 @@ +/** + * + * ****************************************************************************** + * MontiCAR Modeling Family, www.se-rwth.de + * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, + * All rights reserved. + * + * This project is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 3.0 of the License, or (at your option) any later version. + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this project. If not, see . + * ******************************************************************************* + */ +package de.monticore.lang.monticar.cnntrain.annotations; + +import java.util.List; +import java.util.Map; + +public interface TrainedArchitecture { + public List getInputs(); + public List getOutputs(); + public Map> getDimensions(); + public Map getRanges(); + +} \ No newline at end of file diff --git a/src/test/resources/valid_tests/DdpgConfig.cnnt b/src/test/resources/valid_tests/DdpgConfig.cnnt index 45e67f8..45f3b28 100644 --- a/src/test/resources/valid_tests/DdpgConfig.cnnt +++ b/src/test/resources/valid_tests/DdpgConfig.cnnt @@ -1,6 +1,6 @@ configuration DdpgConfig { learning_method : reinforcement - rl_algorithm : ddpg + rl_algorithm : ddpg-algorithm critic : path.to.component environment : gym { name:"CartPole-v1" } } \ No newline at end of file -- GitLab From 15d548b0b7db84dbfff7d7688e72090cf52bf811 Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Thu, 16 May 2019 22:26:05 +0200 Subject: [PATCH 5/7] Add infinity range --- .../monticar/cnntrain/annotations/Range.java | 44 ++++++++++++++++--- 1 file changed, 37 insertions(+), 7 deletions(-) diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/annotations/Range.java b/src/main/java/de/monticore/lang/monticar/cnntrain/annotations/Range.java index a067a44..bbbb3fc 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/annotations/Range.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/annotations/Range.java @@ -20,20 +20,50 @@ */ package de.monticore.lang.monticar.cnntrain.annotations; +import java.util.Optional; + public class Range { - private final double lowerLimit; - private final double upperLimit; + private final boolean lowerLimitIsInfinity; + private final boolean upperLimitIsInfinity; + private final Double lowerLimit; + private final Double upperLimit; - public Range(final double lowerLimit, final double upperLimit) { + private Range(boolean lowerLimitIsInfinity, boolean upperLimitIsInfinity, Double lowerLimit, Double upperLimit) { + this.lowerLimitIsInfinity = lowerLimitIsInfinity; + this.upperLimitIsInfinity = upperLimitIsInfinity; this.lowerLimit = lowerLimit; this.upperLimit = upperLimit; } - public double getLowerLimit() { - return lowerLimit; + public Optional getLowerLimit() { + return Optional.ofNullable(lowerLimit); + } + + public Optional getUpperLimit() { + return Optional.ofNullable(upperLimit); + } + + public boolean isLowerLimitInfinity() { + return this.lowerLimitIsInfinity; + } + + public boolean isUpperLimitInfinity() { + return this.upperLimitIsInfinity; + } + + public static Range withLimits(double lowerLimit, double upperLimit) { + return new Range(false, false, lowerLimit, upperLimit); + } + + public static Range withInfinityLimits() { + return new Range(true, true, null, null); + } + + public static Range withUpperInfinityLimit(double lowerLimit) { + return new Range(false, true, lowerLimit, null); } - public double getUpperLimit() { - return upperLimit; + public static Range withLowerInfinityLimit(double upperLimit) { + return new Range(true, false, null, upperLimit); } } -- GitLab From 52984ecc03997778259845f56f7de46a0533d126 Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Fri, 17 May 2019 16:01:37 +0200 Subject: [PATCH 6/7] Return critic component name as string --- .../_symboltable/CNNTrainSymbolTableCreator.java | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java index fbfe065..2c45115 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java @@ -284,6 +284,12 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { return value; } + private ValueSymbol getValueSymbolForComponentNameAsString(ASTComponentNameValue astComponentNameValue) { + ValueSymbol value = new ValueSymbol(); + value.setValue(String.join(".", astComponentNameValue.getNameList())); + return value; + } + private String getStringFromStringValue(ASTStringValue value) { return value.getStringLiteral().getValue(); } @@ -416,11 +422,13 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { @Override public void visit(ASTCriticNetworkEntry node) { EntrySymbol entry = new EntrySymbol(node.getName()); - entry.setValue(getValueSymbolForComponentName(node.getValue())); + entry.setValue(getValueSymbolForComponentNameAsString(node.getValue())); addToScopeAndLinkWithNode(entry, node); configuration.getEntryMap().put(node.getName(), entry); } + + @Override public void visit(ASTReplayMemoryEntry node) { processMultiParamConfigVisit(node, node.getValue().getName()); -- GitLab From f05a26a9abef1da9a8a8278bde4f1278d2c77b96 Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Fri, 17 May 2019 17:12:06 +0200 Subject: [PATCH 7/7] Add type information for inputs of trained architecture --- .../lang/monticar/cnntrain/annotations/TrainedArchitecture.java | 1 + 1 file changed, 1 insertion(+) diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/annotations/TrainedArchitecture.java b/src/main/java/de/monticore/lang/monticar/cnntrain/annotations/TrainedArchitecture.java index 45231d5..0308a50 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/annotations/TrainedArchitecture.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/annotations/TrainedArchitecture.java @@ -28,5 +28,6 @@ public interface TrainedArchitecture { public List getOutputs(); public Map> getDimensions(); public Map getRanges(); + public Map getTypes(); } \ No newline at end of file -- GitLab