diff --git a/pom.xml b/pom.xml index 1e3a1d10a07c4fb0436d29176dc5ba2df3180e74..037c85e99a570538570c8ea77f95eb4dcd1dafa6 100644 --- a/pom.xml +++ b/pom.xml @@ -30,7 +30,7 @@ de.monticore.lang.monticar cnn-train - 0.3.0-SNAPSHOT + 0.3.1-SNAPSHOT diff --git a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 index 4e81ce1e756dd9f778aaae37b370ca0e316e797e..7409708969bc5bdbaac2f1500888a652be3abfe3 100644 --- a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 +++ b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 @@ -98,6 +98,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number interface MultiParamValue extends ConfigValue; LearningMethodEntry implements ConfigEntry = name:"learning_method" ":" value:LearningMethodValue; + RLAlgorithmEntry implements ConfigEntry = name:"rl_algorithm" ":" value:RLAlgorithmValue; NumEpisodesEntry implements ConfigEntry = name:"num_episodes" ":" value:IntegerValue; DiscountFactorEntry implements ConfigEntry = name:"discount_factor" ":" value:NumberValue; NumMaxStepsEntry implements ConfigEntry = name:"num_max_steps" ":" value:IntegerValue; @@ -109,11 +110,14 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number AgentNameEntry implements ConfigEntry = name:"agent_name" ":" value:StringValue; UseDoubleDQNEntry implements ConfigEntry = name:"use_double_dqn" ":" value:BooleanValue; RewardFunctionEntry implements ConfigEntry = name:"reward_function" ":" value:ComponentNameValue; + CriticNetworkEntry implements ConfigEntry = name:"critic" ":" value:ComponentNameValue; ComponentNameValue implements ConfigValue = Name ("."Name)*; LearningMethodValue implements ConfigValue = (supervisedLearning:"supervised" | reinforcement:"reinforcement"); + RLAlgorithmValue implements ConfigValue = (dqn:"dqn-algorithm" | ddpg:"ddpg-algorithm"); + interface MultiParamConfigEntry extends ConfigEntry; // Replay Memory 
diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ASTConfigurationUtils.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ASTConfigurationUtils.java index 8cf3de8d846b779dfeef71af43f0b5e031c159e1..a0d4ed9b2726010e72787e83adfa9060334f831b 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ASTConfigurationUtils.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ASTConfigurationUtils.java @@ -1,3 +1,23 @@ +/** + * + * ****************************************************************************** + * MontiCAR Modeling Family, www.se-rwth.de + * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, + * All rights reserved. + * + * This project is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 3.0 of the License, or (at your option) any later version. + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this project. If not, see <http://www.gnu.org/licenses/>. 
+ * ******************************************************************************* + */ package de.monticore.lang.monticar.cnntrain._cocos; import de.monticore.lang.monticar.cnntrain._ast.ASTConfiguration; diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java index ac2a8fd21160d2ea753baf6460a09d7b5d18df9b..3d0498079d3186fb7bcff99b17791580a4c3321b 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java @@ -33,7 +33,8 @@ public class CNNTrainCocos { .addCoCo(new CheckFixTargetNetworkRequiresInterval()) .addCoCo(new CheckReinforcementRequiresEnvironment()) .addCoCo(new CheckLearningParameterCombination()) - .addCoCo(new CheckRosEnvironmentRequiresRewardFunction()); + .addCoCo(new CheckRosEnvironmentRequiresRewardFunction()) + .addCoCo(new CheckDdpgRequiresCriticNetwork()); } public static void checkAll(CNNTrainCompilationUnitSymbol compilationUnit){ diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckDdpgRequiresCriticNetwork.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckDdpgRequiresCriticNetwork.java new file mode 100644 index 0000000000000000000000000000000000000000..6ca5412ff0ff5c20441cb8284d4c67d848be26e1 --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckDdpgRequiresCriticNetwork.java @@ -0,0 +1,50 @@ +/** + * + * ****************************************************************************** + * MontiCAR Modeling Family, www.se-rwth.de + * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, + * All rights reserved. 
+ * + * This project is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 3.0 of the License, or (at your option) any later version. + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this project. If not, see <http://www.gnu.org/licenses/>. + * ******************************************************************************* + */ +package de.monticore.lang.monticar.cnntrain._cocos; + +import de.monticore.lang.monticar.cnntrain._ast.ASTConfiguration; +import de.monticore.lang.monticar.cnntrain._ast.ASTCriticNetworkEntry; +import de.monticore.lang.monticar.cnntrain._ast.ASTRLAlgorithmEntry; +import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes; +import de.se_rwth.commons.logging.Log; + +public class CheckDdpgRequiresCriticNetwork implements CNNTrainASTConfigurationCoCo { + + @Override + public void check(ASTConfiguration node) { + boolean isDdpg = node.getEntriesList().stream() + .anyMatch(e -> e instanceof ASTRLAlgorithmEntry + && ((ASTRLAlgorithmEntry)e).getValue().isPresentDdpg()); + boolean hasCriticEntry = node.getEntriesList().stream() + .anyMatch(e -> ((e instanceof ASTCriticNetworkEntry) + && !((ASTCriticNetworkEntry)e).getValue().getNameList().isEmpty())); + + if (isDdpg && !hasCriticEntry) { + ASTRLAlgorithmEntry algorithmEntry = node.getEntriesList().stream() + .filter(e -> e instanceof ASTRLAlgorithmEntry) + .map(e -> (ASTRLAlgorithmEntry)e) + .findFirst() + .orElseThrow(() -> new IllegalStateException("ASTRLAlgorithmEntry entry must be available")); + Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING + " DDPG learning algorithm requires critic" 
+ + " network entry", algorithmEntry.get_SourcePositionStart()); + } + } +} \ No newline at end of file diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckLearningParameterCombination.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckLearningParameterCombination.java index 0dd6d8a04507f53d2be550c6bc2fc3e17be9bf2d..90ea12a0c3900dcba2e40d4389f8ec30cc2cb097 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckLearningParameterCombination.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckLearningParameterCombination.java @@ -61,6 +61,8 @@ public class CheckLearningParameterCombination implements CNNTrainASTEntryCoCo { ); private final static List ALLOWED_REINFORCEMENT_LEARNING = Lists.newArrayList( ASTTrainContextEntry.class, + ASTRLAlgorithmEntry.class, + ASTCriticNetworkEntry.class, ASTOptimizerEntry.class, ASTRewardFunctionEntry.class, ASTMinimumLearningRateEntry.class, diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java index a72c5fc91b244eadb729b3a1123af90f243e8e23..2c45115a84fd4a03f90dcca3a8ec372eb03310dd 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java @@ -277,6 +277,19 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { return value; } + private ValueSymbol getValueSymbolForComponentName(ASTComponentNameValue astComponentNameValue) { + ValueSymbol value = new ValueSymbol(); + List valueAsList = astComponentNameValue.getNameList(); + value.setValue(valueAsList); + return value; + } + + private ValueSymbol getValueSymbolForComponentNameAsString(ASTComponentNameValue astComponentNameValue) { + ValueSymbol value = new ValueSymbol(); + 
value.setValue(String.join(".", astComponentNameValue.getNameList())); + return value; + } + private String getStringFromStringValue(ASTStringValue value) { return value.getStringLiteral().getValue(); } @@ -310,6 +323,22 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { configuration.getEntryMap().put(node.getName(), entry); } + @Override + public void visit(ASTRLAlgorithmEntry node) { + EntrySymbol entry = new EntrySymbol(node.getName()); + ValueSymbol value = new ValueSymbol(); + + if (node.getValue().isPresentDdpg()) { + value.setValue(RLAlgorithm.DDPG); + } else { + value.setValue(RLAlgorithm.DQN); + } + + entry.setValue(value); + addToScopeAndLinkWithNode(entry, node); + configuration.getEntryMap().put(node.getName(), entry); + } + @Override public void visit(ASTNumEpisodesEntry node) { EntrySymbol entry = new EntrySymbol(node.getName()); @@ -390,6 +419,16 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { configuration.getEntryMap().put(node.getName(), entry); } + @Override + public void visit(ASTCriticNetworkEntry node) { + EntrySymbol entry = new EntrySymbol(node.getName()); + entry.setValue(getValueSymbolForComponentNameAsString(node.getValue())); + addToScopeAndLinkWithNode(entry, node); + configuration.getEntryMap().put(node.getName(), entry); + } + + + @Override public void visit(ASTReplayMemoryEntry node) { processMultiParamConfigVisit(node, node.getValue().getName()); diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java index 76ccf68aedc05d160ec757bf610e5a38e98daf47..f137c47a28b4cc1f4b1332a733047c006818f816 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java @@ -21,6 +21,7 @@ package 
de.monticore.lang.monticar.cnntrain._symboltable; import com.google.common.collect.Lists; +import de.monticore.lang.monticar.cnntrain.annotations.TrainedArchitecture; import de.monticore.symboltable.CommonScopeSpanningSymbol; import java.util.*; @@ -30,12 +31,14 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol { private Map<String, EntrySymbol> entryMap = new HashMap<>(); private OptimizerSymbol optimizer; private RewardFunctionSymbol rlRewardFunctionSymbol; + private TrainedArchitecture trainedArchitecture; public static final ConfigurationSymbolKind KIND = new ConfigurationSymbolKind(); public ConfigurationSymbol() { super("", KIND); rlRewardFunctionSymbol = null; + trainedArchitecture = null; } public OptimizerSymbol getOptimizer() { @@ -54,6 +57,14 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol { return Optional.ofNullable(this.rlRewardFunctionSymbol); } + public Optional<TrainedArchitecture> getTrainedArchitecture() { + return Optional.ofNullable(trainedArchitecture); + } + + public void setTrainedArchitecture(TrainedArchitecture trainedArchitecture) { + this.trainedArchitecture = trainedArchitecture; + } + public Map<String, EntrySymbol> getEntryMap() { return entryMap; } @@ -66,4 +77,4 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol { return this.entryMap.containsKey("learning_method") ? 
(LearningMethod)this.entryMap.get("learning_method").getValue().getValue() : LearningMethod.SUPERVISED; } -} +} \ No newline at end of file diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/RLAlgorithm.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/RLAlgorithm.java new file mode 100644 index 0000000000000000000000000000000000000000..f0e38ffffc5f7cebde523b53daaeb0243bbc179b --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/RLAlgorithm.java @@ -0,0 +1,36 @@ +/** + * + * ****************************************************************************** + * MontiCAR Modeling Family, www.se-rwth.de + * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, + * All rights reserved. + * + * This project is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 3.0 of the License, or (at your option) any later version. + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this project. If not, see <http://www.gnu.org/licenses/>. 
+ * ******************************************************************************* + */ +package de.monticore.lang.monticar.cnntrain._symboltable; + +public enum RLAlgorithm { + DQN { + @Override + public String toString() { + return "dqn"; + } + }, + DDPG { + @Override + public String toString() { + return "ddpg"; + } + } +} diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/annotations/Range.java b/src/main/java/de/monticore/lang/monticar/cnntrain/annotations/Range.java new file mode 100644 index 0000000000000000000000000000000000000000..bbbb3fc7041722519c8f2bef8d99b3ceb2dc0767 --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/annotations/Range.java @@ -0,0 +1,69 @@ +/** + * + * ****************************************************************************** + * MontiCAR Modeling Family, www.se-rwth.de + * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, + * All rights reserved. + * + * This project is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 3.0 of the License, or (at your option) any later version. + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this project. If not, see <http://www.gnu.org/licenses/>. 
+ * ******************************************************************************* + */ +package de.monticore.lang.monticar.cnntrain.annotations; + +import java.util.Optional; + +public class Range { + private final boolean lowerLimitIsInfinity; + private final boolean upperLimitIsInfinity; + private final Double lowerLimit; + private final Double upperLimit; + + private Range(boolean lowerLimitIsInfinity, boolean upperLimitIsInfinity, Double lowerLimit, Double upperLimit) { + this.lowerLimitIsInfinity = lowerLimitIsInfinity; + this.upperLimitIsInfinity = upperLimitIsInfinity; + this.lowerLimit = lowerLimit; + this.upperLimit = upperLimit; + } + + public Optional<Double> getLowerLimit() { + return Optional.ofNullable(lowerLimit); + } + + public Optional<Double> getUpperLimit() { + return Optional.ofNullable(upperLimit); + } + + public boolean isLowerLimitInfinity() { + return this.lowerLimitIsInfinity; + } + + public boolean isUpperLimitInfinity() { + return this.upperLimitIsInfinity; + } + + public static Range withLimits(double lowerLimit, double upperLimit) { + return new Range(false, false, lowerLimit, upperLimit); + } + + public static Range withInfinityLimits() { + return new Range(true, true, null, null); + } + + public static Range withUpperInfinityLimit(double lowerLimit) { + return new Range(false, true, lowerLimit, null); + } + + public static Range withLowerInfinityLimit(double upperLimit) { + return new Range(true, false, null, upperLimit); + } +} diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/annotations/TrainedArchitecture.java b/src/main/java/de/monticore/lang/monticar/cnntrain/annotations/TrainedArchitecture.java new file mode 100644 index 0000000000000000000000000000000000000000..0308a504cc97238158d931cab6b85489a05c926c --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/annotations/TrainedArchitecture.java @@ -0,0 +1,33 @@ +/** + * + * ****************************************************************************** + * MontiCAR 
Modeling Family, www.se-rwth.de + * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, + * All rights reserved. + * + * This project is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 3.0 of the License, or (at your option) any later version. + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this project. If not, see <http://www.gnu.org/licenses/>. + * ******************************************************************************* + */ +package de.monticore.lang.monticar.cnntrain.annotations; + +import java.util.List; +import java.util.Map; + +public interface TrainedArchitecture { + public List<String> getInputs(); + public List<String> getOutputs(); + public Map<String, List<Integer>> getDimensions(); + public Map<String, Range> getRanges(); + public Map<String, String> getTypes(); + +} \ No newline at end of file diff --git a/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java b/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java index ad6ee8d49a9deac0487c98cc76bcfd0bcfa40358..57cf32fe6072627657d1015a60a046bdf9405a6f 100644 --- a/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java +++ b/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java @@ -41,6 +41,7 @@ public class AllCoCoTest extends AbstractCoCoTest{ checkValid("valid_tests","FullConfig2"); checkValid("valid_tests", "ReinforcementConfig"); checkValid("valid_tests", "ReinforcementConfig2"); + checkValid("valid_tests", "DdpgConfig"); } @Test diff --git a/src/test/resources/valid_tests/DdpgConfig.cnnt b/src/test/resources/valid_tests/DdpgConfig.cnnt new file mode 100644 index 
0000000000000000000000000000000000000000..45f3b28f40bdd1a2f4efff5e2eface013d2e0df0 --- /dev/null +++ b/src/test/resources/valid_tests/DdpgConfig.cnnt @@ -0,0 +1,6 @@ +configuration DdpgConfig { + learning_method : reinforcement + rl_algorithm : ddpg-algorithm + critic : path.to.component + environment : gym { name:"CartPole-v1" } +} \ No newline at end of file