diff --git a/pom.xml b/pom.xml
index 1e3a1d10a07c4fb0436d29176dc5ba2df3180e74..037c85e99a570538570c8ea77f95eb4dcd1dafa6 100644
--- a/pom.xml
+++ b/pom.xml
@@ -30,7 +30,7 @@
de.monticore.lang.monticar
cnn-train
- 0.3.0-SNAPSHOT
+ 0.3.1-SNAPSHOT
diff --git a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4
index 4e81ce1e756dd9f778aaae37b370ca0e316e797e..7409708969bc5bdbaac2f1500888a652be3abfe3 100644
--- a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4
+++ b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4
@@ -98,6 +98,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
interface MultiParamValue extends ConfigValue;
LearningMethodEntry implements ConfigEntry = name:"learning_method" ":" value:LearningMethodValue;
+ RLAlgorithmEntry implements ConfigEntry = name:"rl_algorithm" ":" value:RLAlgorithmValue;
NumEpisodesEntry implements ConfigEntry = name:"num_episodes" ":" value:IntegerValue;
DiscountFactorEntry implements ConfigEntry = name:"discount_factor" ":" value:NumberValue;
NumMaxStepsEntry implements ConfigEntry = name:"num_max_steps" ":" value:IntegerValue;
@@ -109,11 +110,14 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
AgentNameEntry implements ConfigEntry = name:"agent_name" ":" value:StringValue;
UseDoubleDQNEntry implements ConfigEntry = name:"use_double_dqn" ":" value:BooleanValue;
RewardFunctionEntry implements ConfigEntry = name:"reward_function" ":" value:ComponentNameValue;
+ CriticNetworkEntry implements ConfigEntry = name:"critic" ":" value:ComponentNameValue;
ComponentNameValue implements ConfigValue = Name ("."Name)*;
LearningMethodValue implements ConfigValue = (supervisedLearning:"supervised" | reinforcement:"reinforcement");
+ RLAlgorithmValue implements ConfigValue = (dqn:"dqn-algorithm" | ddpg:"ddpg-algorithm");
+
interface MultiParamConfigEntry extends ConfigEntry;
// Replay Memory
diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ASTConfigurationUtils.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ASTConfigurationUtils.java
index 8cf3de8d846b779dfeef71af43f0b5e031c159e1..a0d4ed9b2726010e72787e83adfa9060334f831b 100644
--- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ASTConfigurationUtils.java
+++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ASTConfigurationUtils.java
@@ -1,3 +1,23 @@
+/**
+ *
+ * ******************************************************************************
+ * MontiCAR Modeling Family, www.se-rwth.de
+ * Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
+ * All rights reserved.
+ *
+ * This project is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 3.0 of the License, or (at your option) any later version.
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this project. If not, see <http://www.gnu.org/licenses/>.
+ * *******************************************************************************
+ */
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._ast.ASTConfiguration;
diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java
index ac2a8fd21160d2ea753baf6460a09d7b5d18df9b..3d0498079d3186fb7bcff99b17791580a4c3321b 100644
--- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java
+++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java
@@ -33,7 +33,8 @@ public class CNNTrainCocos {
.addCoCo(new CheckFixTargetNetworkRequiresInterval())
.addCoCo(new CheckReinforcementRequiresEnvironment())
.addCoCo(new CheckLearningParameterCombination())
- .addCoCo(new CheckRosEnvironmentRequiresRewardFunction());
+ .addCoCo(new CheckRosEnvironmentRequiresRewardFunction())
+ .addCoCo(new CheckDdpgRequiresCriticNetwork());
}
public static void checkAll(CNNTrainCompilationUnitSymbol compilationUnit){
diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckDdpgRequiresCriticNetwork.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckDdpgRequiresCriticNetwork.java
new file mode 100644
index 0000000000000000000000000000000000000000..6ca5412ff0ff5c20441cb8284d4c67d848be26e1
--- /dev/null
+++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckDdpgRequiresCriticNetwork.java
@@ -0,0 +1,50 @@
+/**
+ *
+ * ******************************************************************************
+ * MontiCAR Modeling Family, www.se-rwth.de
+ * Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
+ * All rights reserved.
+ *
+ * This project is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 3.0 of the License, or (at your option) any later version.
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this project. If not, see <http://www.gnu.org/licenses/>.
+ * *******************************************************************************
+ */
+package de.monticore.lang.monticar.cnntrain._cocos;
+
+import de.monticore.lang.monticar.cnntrain._ast.ASTConfiguration;
+import de.monticore.lang.monticar.cnntrain._ast.ASTCriticNetworkEntry;
+import de.monticore.lang.monticar.cnntrain._ast.ASTRLAlgorithmEntry;
+import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
+import de.se_rwth.commons.logging.Log;
+
+/**
+ * Context condition: a configuration that selects the DDPG algorithm through an
+ * {@code rl_algorithm} entry must also contain a non-empty {@code critic} entry.
+ */
+public class CheckDdpgRequiresCriticNetwork implements CNNTrainASTConfigurationCoCo {
+
+    /**
+     * Logs a {@code REQUIRED_PARAMETER_MISSING} error when DDPG is selected but no
+     * critic entry with at least one name part is present.
+     *
+     * @param node the configuration whose entries are inspected
+     */
+    @Override
+    public void check(ASTConfiguration node) {
+        boolean hasCriticEntry = node.getEntriesList().stream()
+                .anyMatch(e -> e instanceof ASTCriticNetworkEntry
+                        && !((ASTCriticNetworkEntry) e).getValue().getNameList().isEmpty());
+        if (hasCriticEntry) {
+            return;
+        }
+
+        // Report at the entry that actually selects DDPG, so the source position
+        // stays meaningful even if several rl_algorithm entries are present.
+        node.getEntriesList().stream()
+                .filter(e -> e instanceof ASTRLAlgorithmEntry)
+                .map(e -> (ASTRLAlgorithmEntry) e)
+                .filter(e -> e.getValue().isPresentDdpg())
+                .findFirst()
+                .ifPresent(ddpgEntry -> Log.error(
+                        "0" + ErrorCodes.REQUIRED_PARAMETER_MISSING
+                                + " DDPG learning algorithm requires critic network entry",
+                        ddpgEntry.get_SourcePositionStart()));
+    }
+}
\ No newline at end of file
diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckLearningParameterCombination.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckLearningParameterCombination.java
index 0dd6d8a04507f53d2be550c6bc2fc3e17be9bf2d..90ea12a0c3900dcba2e40d4389f8ec30cc2cb097 100644
--- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckLearningParameterCombination.java
+++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckLearningParameterCombination.java
@@ -61,6 +61,8 @@ public class CheckLearningParameterCombination implements CNNTrainASTEntryCoCo {
);
private final static List ALLOWED_REINFORCEMENT_LEARNING = Lists.newArrayList(
ASTTrainContextEntry.class,
+ ASTRLAlgorithmEntry.class,
+ ASTCriticNetworkEntry.class,
ASTOptimizerEntry.class,
ASTRewardFunctionEntry.class,
ASTMinimumLearningRateEntry.class,
diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java
index a72c5fc91b244eadb729b3a1123af90f243e8e23..2c45115a84fd4a03f90dcca3a8ec372eb03310dd 100644
--- a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java
+++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java
@@ -277,6 +277,19 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
return value;
}
+    /**
+     * Wraps the name parts of a component reference, kept as a list, in a new ValueSymbol.
+     *
+     * @param astComponentNameValue component name node whose name list is read
+     * @return a fresh ValueSymbol holding that name list
+     */
+    private ValueSymbol getValueSymbolForComponentName(ASTComponentNameValue astComponentNameValue) {
+        ValueSymbol nameParts = new ValueSymbol();
+        nameParts.setValue(astComponentNameValue.getNameList());
+        return nameParts;
+    }
+
+    /**
+     * Joins the name parts of a component reference with '.' and wraps the resulting
+     * string in a new ValueSymbol.
+     *
+     * @param astComponentNameValue component name node whose name list is read
+     * @return a fresh ValueSymbol holding the dot-joined name
+     */
+    private ValueSymbol getValueSymbolForComponentNameAsString(ASTComponentNameValue astComponentNameValue) {
+        String qualifiedName = String.join(".", astComponentNameValue.getNameList());
+        ValueSymbol wrapped = new ValueSymbol();
+        wrapped.setValue(qualifiedName);
+        return wrapped;
+    }
+
private String getStringFromStringValue(ASTStringValue value) {
return value.getStringLiteral().getValue();
}
@@ -310,6 +323,22 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
configuration.getEntryMap().put(node.getName(), entry);
}
+    /**
+     * Registers the rl_algorithm entry: its value is {@code RLAlgorithm.DDPG} when the
+     * node's value has the ddpg alternative present, and {@code RLAlgorithm.DQN} otherwise.
+     *
+     * @param node the rl_algorithm entry being visited
+     */
+    @Override
+    public void visit(ASTRLAlgorithmEntry node) {
+        EntrySymbol algorithmEntry = new EntrySymbol(node.getName());
+        ValueSymbol algorithmValue = new ValueSymbol();
+        algorithmValue.setValue(node.getValue().isPresentDdpg() ? RLAlgorithm.DDPG : RLAlgorithm.DQN);
+        algorithmEntry.setValue(algorithmValue);
+
+        addToScopeAndLinkWithNode(algorithmEntry, node);
+        configuration.getEntryMap().put(node.getName(), algorithmEntry);
+    }
+
@Override
public void visit(ASTNumEpisodesEntry node) {
EntrySymbol entry = new EntrySymbol(node.getName());
@@ -390,6 +419,16 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
configuration.getEntryMap().put(node.getName(), entry);
}
+    /**
+     * Registers the critic entry; the referenced component name is stored as a single
+     * dot-joined string.
+     *
+     * @param node the critic entry being visited
+     */
+    @Override
+    public void visit(ASTCriticNetworkEntry node) {
+        ValueSymbol criticName = getValueSymbolForComponentNameAsString(node.getValue());
+        EntrySymbol criticEntry = new EntrySymbol(node.getName());
+        criticEntry.setValue(criticName);
+
+        addToScopeAndLinkWithNode(criticEntry, node);
+        configuration.getEntryMap().put(node.getName(), criticEntry);
+    }
+
+
+
@Override
public void visit(ASTReplayMemoryEntry node) {
processMultiParamConfigVisit(node, node.getValue().getName());
diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java
index 76ccf68aedc05d160ec757bf610e5a38e98daf47..f137c47a28b4cc1f4b1332a733047c006818f816 100644
--- a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java
+++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java
@@ -21,6 +21,7 @@
package de.monticore.lang.monticar.cnntrain._symboltable;
import com.google.common.collect.Lists;
+import de.monticore.lang.monticar.cnntrain.annotations.TrainedArchitecture;
import de.monticore.symboltable.CommonScopeSpanningSymbol;
import java.util.*;
@@ -30,12 +31,14 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
private Map entryMap = new HashMap<>();
private OptimizerSymbol optimizer;
private RewardFunctionSymbol rlRewardFunctionSymbol;
+ private TrainedArchitecture trainedArchitecture;
public static final ConfigurationSymbolKind KIND = new ConfigurationSymbolKind();
public ConfigurationSymbol() {
super("", KIND);
rlRewardFunctionSymbol = null;
+ trainedArchitecture = null;
}
public OptimizerSymbol getOptimizer() {
@@ -54,6 +57,14 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
return Optional.ofNullable(this.rlRewardFunctionSymbol);
}
+    /**
+     * Returns the architecture attached to this configuration, if one has been set.
+     *
+     * @return the stored architecture, or an empty Optional when none was set
+     */
+    public Optional<TrainedArchitecture> getTrainedArchitecture() {
+        return Optional.ofNullable(trainedArchitecture);
+    }
+
+    /**
+     * Attaches an architecture to this configuration.
+     *
+     * @param trainedArchitecture the architecture to store; {@code null} makes
+     *        {@link #getTrainedArchitecture()} return an empty Optional again
+     */
+    public void setTrainedArchitecture(TrainedArchitecture trainedArchitecture) {
+        this.trainedArchitecture = trainedArchitecture;
+    }
+
public Map getEntryMap() {
return entryMap;
}
@@ -66,4 +77,4 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
return this.entryMap.containsKey("learning_method")
? (LearningMethod)this.entryMap.get("learning_method").getValue().getValue() : LearningMethod.SUPERVISED;
}
-}
+}
\ No newline at end of file
diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/RLAlgorithm.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/RLAlgorithm.java
new file mode 100644
index 0000000000000000000000000000000000000000..f0e38ffffc5f7cebde523b53daaeb0243bbc179b
--- /dev/null
+++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/RLAlgorithm.java
@@ -0,0 +1,36 @@
+/**
+ *
+ * ******************************************************************************
+ * MontiCAR Modeling Family, www.se-rwth.de
+ * Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
+ * All rights reserved.
+ *
+ * This project is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 3.0 of the License, or (at your option) any later version.
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this project. If not, see <http://www.gnu.org/licenses/>.
+ * *******************************************************************************
+ */
+package de.monticore.lang.monticar.cnntrain._symboltable;
+
+/**
+ * Reinforcement-learning algorithms a configuration can select.
+ * {@link #toString()} yields the lowercase identifier ("dqn" / "ddpg") rather than
+ * the constant name.
+ */
+public enum RLAlgorithm {
+    DQN("dqn"),
+    DDPG("ddpg");
+
+    /** Lowercase identifier returned by {@link #toString()}. */
+    private final String identifier;
+
+    RLAlgorithm(String identifier) {
+        this.identifier = identifier;
+    }
+
+    @Override
+    public String toString() {
+        return identifier;
+    }
+}
diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/annotations/Range.java b/src/main/java/de/monticore/lang/monticar/cnntrain/annotations/Range.java
new file mode 100644
index 0000000000000000000000000000000000000000..bbbb3fc7041722519c8f2bef8d99b3ceb2dc0767
--- /dev/null
+++ b/src/main/java/de/monticore/lang/monticar/cnntrain/annotations/Range.java
@@ -0,0 +1,69 @@
+/**
+ *
+ * ******************************************************************************
+ * MontiCAR Modeling Family, www.se-rwth.de
+ * Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
+ * All rights reserved.
+ *
+ * This project is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 3.0 of the License, or (at your option) any later version.
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this project. If not, see <http://www.gnu.org/licenses/>.
+ * *******************************************************************************
+ */
+package de.monticore.lang.monticar.cnntrain.annotations;
+
+import java.util.Optional;
+
+/**
+ * Immutable numeric interval whose lower and/or upper bound may be infinite.
+ * Instances are created through the static factory methods only.
+ */
+public class Range {
+    // A null limit encodes "infinite on this side"; the factories are the only
+    // writers, so the infinity queries can be derived from the limits directly.
+    private final Double lowerLimit;
+    private final Double upperLimit;
+
+    private Range(Double lowerLimit, Double upperLimit) {
+        this.lowerLimit = lowerLimit;
+        this.upperLimit = upperLimit;
+    }
+
+    /**
+     * @return the finite lower bound, or empty when the lower side is infinite
+     */
+    public Optional<Double> getLowerLimit() {
+        return Optional.ofNullable(lowerLimit);
+    }
+
+    /**
+     * @return the finite upper bound, or empty when the upper side is infinite
+     */
+    public Optional<Double> getUpperLimit() {
+        return Optional.ofNullable(upperLimit);
+    }
+
+    /**
+     * @return {@code true} when the range has no finite lower bound
+     */
+    public boolean isLowerLimitInfinity() {
+        return lowerLimit == null;
+    }
+
+    /**
+     * @return {@code true} when the range has no finite upper bound
+     */
+    public boolean isUpperLimitInfinity() {
+        return upperLimit == null;
+    }
+
+    /**
+     * Creates a range with finite bounds on both sides.
+     *
+     * @param lowerLimit the lower bound
+     * @param upperLimit the upper bound
+     * @return the bounded range
+     */
+    public static Range withLimits(double lowerLimit, double upperLimit) {
+        return new Range(lowerLimit, upperLimit);
+    }
+
+    /**
+     * Creates a range that is infinite on both sides.
+     *
+     * @return the unbounded range
+     */
+    public static Range withInfinityLimits() {
+        return new Range(null, null);
+    }
+
+    /**
+     * Creates a range with a finite lower bound and an infinite upper side.
+     *
+     * @param lowerLimit the finite lower bound
+     * @return the range that is open towards the upper side
+     */
+    public static Range withUpperInfinityLimit(double lowerLimit) {
+        return new Range(lowerLimit, null);
+    }
+
+    /**
+     * Creates a range with a finite upper bound and an infinite lower side.
+     *
+     * @param upperLimit the finite upper bound
+     * @return the range that is open towards the lower side
+     */
+    public static Range withLowerInfinityLimit(double upperLimit) {
+        return new Range(null, upperLimit);
+    }
+}
diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/annotations/TrainedArchitecture.java b/src/main/java/de/monticore/lang/monticar/cnntrain/annotations/TrainedArchitecture.java
new file mode 100644
index 0000000000000000000000000000000000000000..0308a504cc97238158d931cab6b85489a05c926c
--- /dev/null
+++ b/src/main/java/de/monticore/lang/monticar/cnntrain/annotations/TrainedArchitecture.java
@@ -0,0 +1,33 @@
+/**
+ *
+ * ******************************************************************************
+ * MontiCAR Modeling Family, www.se-rwth.de
+ * Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
+ * All rights reserved.
+ *
+ * This project is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 3.0 of the License, or (at your option) any later version.
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this project. If not, see <http://www.gnu.org/licenses/>.
+ * *******************************************************************************
+ */
+package de.monticore.lang.monticar.cnntrain.annotations;
+
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Read-only description of an architecture: its inputs, outputs, and per-name
+ * dimensions, ranges and types.
+ *
+ * NOTE(review): the generic type arguments were lost in this view of the patch and
+ * are reconstructed here; the map keys are presumably input/output names - confirm
+ * against the implementing classes.
+ */
+public interface TrainedArchitecture {
+    /** @return the architecture's inputs */
+    List<String> getInputs();
+
+    /** @return the architecture's outputs */
+    List<String> getOutputs();
+
+    /** @return the dimensions, one list of sizes per key */
+    Map<String, List<Integer>> getDimensions();
+
+    /** @return the value ranges, one {@link Range} per key */
+    Map<String, Range> getRanges();
+
+    /** @return the types, one type name per key */
+    Map<String, String> getTypes();
+}
\ No newline at end of file
diff --git a/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java b/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java
index ad6ee8d49a9deac0487c98cc76bcfd0bcfa40358..57cf32fe6072627657d1015a60a046bdf9405a6f 100644
--- a/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java
+++ b/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java
@@ -41,6 +41,7 @@ public class AllCoCoTest extends AbstractCoCoTest{
checkValid("valid_tests","FullConfig2");
checkValid("valid_tests", "ReinforcementConfig");
checkValid("valid_tests", "ReinforcementConfig2");
+ checkValid("valid_tests", "DdpgConfig");
}
@Test
diff --git a/src/test/resources/valid_tests/DdpgConfig.cnnt b/src/test/resources/valid_tests/DdpgConfig.cnnt
new file mode 100644
index 0000000000000000000000000000000000000000..45f3b28f40bdd1a2f4efff5e2eface013d2e0df0
--- /dev/null
+++ b/src/test/resources/valid_tests/DdpgConfig.cnnt
@@ -0,0 +1,6 @@
+configuration DdpgConfig {
+ learning_method : reinforcement
+ rl_algorithm : ddpg-algorithm
+ critic : path.to.component
+ environment : gym { name:"CartPole-v1" }
+}
\ No newline at end of file