diff --git a/README.md b/README.md
index 91102da85b18f29e25f54f9e9bcdff42d75222ac..c8c4b38120a81d25c0b2b290f67f4856500a2560 100644
--- a/README.md
+++ b/README.md
@@ -117,7 +117,7 @@ configuration ReinforcementConfig {
| Parameter | Value | Default | Required | Algorithm | Description |
|------------|--------|---------|----------|-----------|-------------|
|learning_method| reinforcement,supervised | supervised | No | All | Determines that this CNNTrain configuration is a reinforcement or supervised learning configuration |
-| rl_algorithm | ddpg-algorithm, dqn-algorithm | dqn-algorithm | No | All | Determines the RL algorithm that is used to train the agent
+| rl_algorithm | ddpg-algorithm, dqn-algorithm, td3-algorithm | dqn-algorithm | No | All | Determines the RL algorithm that is used to train the agent |
| agent_name | String | "agent" | No | All | Names the agent (e.g. for logging output) |
|environment | gym, ros_interface | Yes | / | All | If *ros_interface* is selected, then the agent and the environment communicates via [ROS](http://www.ros.org/). The gym environment comes with a set of environments which are listed [here](https://gym.openai.com/) |
| context | cpu, gpu | cpu | No | All | Determines whether the GPU is used during training or the CPU |
@@ -133,12 +133,15 @@ configuration ReinforcementConfig {
| replay_memory | buffer, online, combined | buffer | No | All | Determines the behaviour of the replay memory |
| strategy | epsgreedy, ornstein_uhlenbeck | epsgreedy (discrete), ornstein_uhlenbeck (continuous) | No | All | Determines the action selection policy during the training |
| reward_function | Full name of an EMAM component | / | Yes, if *ros_interface* is selected as the environment and no reward topic is given | All | The EMAM component that is used to calculate the reward. It must have two inputs, one for the current state and one boolean input that determines if the current state is terminal. It must also have exactly one output which represents the reward. |
-critic | Full name of architecture definition | / | Yes, if DDPG is selected | DDPG | The architecture definition which specifies the architecture of the critic network |
-soft_target_update_rate | Float | 0.001 | No | DDPG | Determines the update rate of the critic and actor target network |
-actor_optimizer | See supervised learning | adam with LR .0001 | No | DDPG | Determines the optimizer parameters of the actor network |
-critic_optimizer | See supervised learning | adam with LR .001 | No | DDPG | Determines the optimizer parameters of the critic network |
+critic | Full name of architecture definition | / | Yes, if DDPG or TD3 is selected | DDPG, TD3 | The architecture definition which specifies the architecture of the critic network |
+soft_target_update_rate | Float | 0.001 | No | DDPG, TD3 | Determines the update rate of the critic and actor target network |
+actor_optimizer | See supervised learning | adam with LR .0001 | No | DDPG, TD3 | Determines the optimizer parameters of the actor network |
+critic_optimizer | See supervised learning | adam with LR .001 | No | DDPG, TD3 | Determines the optimizer parameters of the critic network |
| start_training_at | Integer | 0 | No | All | Determines at which episode the training starts |
| evaluation_samples | Integer | 100 | No | All | Determines how many epsiodes are run when evaluating the network |
+| policy_noise | Float | 0.1 | No | TD3 | Determines the standard deviation of the noise that is added to the actions predicted by the target actor network when calculating the targets. |
+| noise_clip | Float | 0.5 | No | TD3 | Sets the upper and lower limit of the policy noise. |
+| policy_delay | Integer | 2 | No | TD3 | The actor network and the target networks are updated only every *policy_delay* steps. |
#### Environment
@@ -189,6 +192,7 @@ This strategy is only available for discrete problems. It selects an action base
- **epsilon_decay_start**: Number of Episodes after the decay of epsilon starts
- **epsilon_decay**: The actual decay of epsilon after each step.
- **min_epsilon**: After *min_epsilon* is reached, epsilon is not decreased further.
+- **epsilon_decay_per_step**: Expects either true or false. If true, the decay is performed after each step the agent executes instead of after each episode. The default value is false.
#### Option: ornstein_uhlenbeck
@@ -209,6 +213,9 @@ Example: Given an actor network with action output of shape (3,), we can write
to specify the parameters for each place.
+#### Option: gaussian
+This strategy is also only available for continuous problems. If this strategy is selected, uncorrelated Gaussian noise with zero mean is added to the current policy action selection. This strategy provides the same parameters as the epsgreedy option and, in addition, the parameter **noise_variance** that determines the variance of the noise.
+
## Generation
To execute generation in your project, use the following code to generate a separate Config file:
diff --git a/pom.xml b/pom.xml
index 2e1be2079899bb2d8e76808ac122ed6b9ad7aed6..3797a1f056943a4d716b9d538da43b5396f5dead 100644
--- a/pom.xml
+++ b/pom.xml
@@ -30,7 +30,7 @@
de.monticore.lang.monticar
cnn-train
- 0.3.4-SNAPSHOT
+ 0.3.6-SNAPSHOT
diff --git a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4
index 3503f53b9e82c9d4f7007469d889c272d36a5be9..efc7507186030b80318f10650fdf2de00ea25626 100644
--- a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4
+++ b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4
@@ -145,7 +145,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
LearningMethodValue implements ConfigValue = (supervisedLearning:"supervised" | reinforcement:"reinforcement");
- RLAlgorithmValue implements ConfigValue = (dqn:"dqn-algorithm" | ddpg:"ddpg-algorithm");
+ RLAlgorithmValue implements ConfigValue = (dqn:"dqn-algorithm" | ddpg:"ddpg-algorithm" | tdThree:"td3-algorithm");
interface MultiParamConfigEntry extends ConfigEntry;
@@ -176,17 +176,23 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
interface StrategyOrnsteinUhlenbeckEntry extends Entry;
StrategyOrnsteinUhlenbeckValue implements StrategyValue = name:"ornstein_uhlenbeck" ("{" params:StrategyOrnsteinUhlenbeckEntry* "}")?;
+ interface StrategyGaussianEntry extends Entry;
+ StrategyGaussianValue implements StrategyValue = name:"gaussian" ("{" params:StrategyGaussianEntry* "}")?;
+
+ StrategyGaussianNoiseVarianceEntry implements StrategyGaussianEntry = name: "noise_variance" ":" value:NumberValue;
+
StrategyOUMu implements StrategyOrnsteinUhlenbeckEntry = name: "mu" ":" value:DoubleVectorValue;
StrategyOUTheta implements StrategyOrnsteinUhlenbeckEntry = name: "theta" ":" value:DoubleVectorValue;
StrategyOUSigma implements StrategyOrnsteinUhlenbeckEntry = name: "sigma" ":" value:DoubleVectorValue;
- interface GeneralStrategyEntry extends StrategyEpsGreedyEntry, StrategyOrnsteinUhlenbeckEntry;
+ interface GeneralStrategyEntry extends StrategyEpsGreedyEntry, StrategyOrnsteinUhlenbeckEntry, StrategyGaussianEntry;
GreedyEpsilonEntry implements GeneralStrategyEntry = name:"epsilon" ":" value:NumberValue;
MinEpsilonEntry implements GeneralStrategyEntry = name:"min_epsilon" ":" value:NumberValue;
EpsilonDecayStartEntry implements GeneralStrategyEntry = name:"epsilon_decay_start" ":" value:IntegerValue;
EpsilonDecayMethodEntry implements GeneralStrategyEntry = name:"epsilon_decay_method" ":" value:EpsilonDecayMethodValue;
EpsilonDecayMethodValue implements ConfigValue = (linear:"linear" | no:"no");
+ EpsilonDecayPerStepEntry implements GeneralStrategyEntry = name:"epsilon_decay_per_step" ":" value:BooleanValue;
EpsilonDecayEntry implements GeneralStrategyEntry = name:"epsilon_decay" ":" value:NumberValue;
// Environment
@@ -211,8 +217,13 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
UseDoubleDQNEntry implements ConfigEntry = name:"use_double_dqn" ":" value:BooleanValue;
- // DDPG exclusive parameters
+ // DDPG and TD3 exclusive parameters
CriticNetworkEntry implements ConfigEntry = name:"critic" ":" value:ComponentNameValue;
SoftTargetUpdateRateEntry implements ConfigEntry = name:"soft_target_update_rate" ":" value:NumberValue;
CriticOptimizerEntry implements ConfigEntry = name:"critic_optimizer" ":" value:OptimizerValue;
+
+ // TD3 exclusive parameters
+ PolicyNoiseEntry implements ConfigEntry = name:"policy_noise" ":" value:NumberValue;
+ NoiseClipEntry implements ConfigEntry = name:"noise_clip" ":" value:NumberValue;
+ PolicyDelayEntry implements ConfigEntry = name:"policy_delay" ":" value:IntegerValue;
}
\ No newline at end of file
diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ASTConfigurationUtils.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ASTConfigurationUtils.java
index a955b73d59ba3836f7ff6e08e53ecbbb0f67773b..a31e150ca8dbef14d2b29d458c759baa41f368c0 100644
--- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ASTConfigurationUtils.java
+++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ASTConfigurationUtils.java
@@ -21,6 +21,7 @@
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._ast.*;
+import static de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants.*;
import java.util.Optional;
@@ -41,8 +42,16 @@ class ASTConfigurationUtils {
e -> (e instanceof ASTRLAlgorithmEntry) && ((ASTRLAlgorithmEntry)e).getValue().isPresentDdpg());
}
+ static boolean isTd3Algorithm(final ASTConfiguration configuration) {
+ return isReinforcementLearning(configuration)
+ && configuration.getEntriesList().stream().anyMatch(
+ e -> (e instanceof ASTRLAlgorithmEntry) && ((ASTRLAlgorithmEntry)e).getValue().isPresentTdThree());
+ }
+
static boolean isDqnAlgorithm(final ASTConfiguration configuration) {
- return isReinforcementLearning(configuration) && !isDdpgAlgorithm(configuration);
+ return isReinforcementLearning(configuration)
+ && !isDdpgAlgorithm(configuration)
+ && !isTd3Algorithm(configuration);
}
static boolean hasEntry(final ASTConfiguration configuration, final Class extends ASTConfigEntry> entryClazz) {
@@ -69,7 +78,7 @@ class ASTConfigurationUtils {
return ASTConfigurationUtils.hasEnvironment(node)
&& node.getEntriesList().stream()
.anyMatch(e -> (e instanceof ASTEnvironmentEntry)
- && ((ASTEnvironmentEntry)e).getValue().getName().equals("ros_interface"));
+ && ((ASTEnvironmentEntry)e).getValue().getName().equals(ENVIRONMENT_ROS));
}
static boolean hasRewardTopic(final ASTConfiguration node) {
@@ -84,4 +93,18 @@ class ASTConfigurationUtils {
}
return false;
}
+
+ static boolean isActorCriticAlgorithm(final ASTConfiguration node) {
+ return isDdpgAlgorithm(node) || isTd3Algorithm(node);
+ }
+
+ static boolean hasCriticEntry(final ASTConfiguration node) {
+ return node.getEntriesList().stream()
+ .anyMatch(e -> ((e instanceof ASTCriticNetworkEntry)
+ && !((ASTCriticNetworkEntry)e).getValue().getNameList().isEmpty()));
+ }
+
+ public static boolean isContinuousAlgorithm(final ASTConfiguration node) {
+ return isDdpgAlgorithm(node) || isTd3Algorithm(node);
+ }
}
diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java
index b46504214ad1d24c2724196bf82b1e48d1fa5168..c994809a6986ba4b29a2621d8efe5556d3411f0d 100644
--- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java
+++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java
@@ -22,6 +22,7 @@ package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._ast.ASTCNNTrainNode;
import de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainCompilationUnitSymbol;
+import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.se_rwth.commons.logging.Log;
public class CNNTrainCocos {
@@ -34,7 +35,7 @@ public class CNNTrainCocos {
.addCoCo(new CheckReinforcementRequiresEnvironment())
.addCoCo(new CheckLearningParameterCombination())
.addCoCo(new CheckRosEnvironmentRequiresRewardFunction())
- .addCoCo(new CheckDdpgRequiresCriticNetwork())
+ .addCoCo(new CheckActorCriticRequiresCriticNetwork())
.addCoCo(new CheckRlAlgorithmParameter())
.addCoCo(new CheckDiscreteRLAlgorithmUsesDiscreteStrategy())
.addCoCo(new CheckContinuousRLAlgorithmUsesContinuousStrategy())
@@ -46,4 +47,19 @@ public class CNNTrainCocos {
int findings = Log.getFindings().size();
createChecker().checkAll(node);
}
+
+ public static void checkTrainedArchitectureCoCos(final ConfigurationSymbol configurationSymbol) {
+ CNNTrainConfigurationSymbolChecker checker = new CNNTrainConfigurationSymbolChecker()
+ .addCoCo(new CheckTrainedRlNetworkHasExactlyOneInput())
+ .addCoCo(new CheckTrainedRlNetworkHasExactlyOneOutput())
+ .addCoCo(new CheckOUParameterDimensionEqualsActionDimension());
+ checker.checkAll(configurationSymbol);
+ }
+
+ public static void checkCriticCocos(final ConfigurationSymbol configurationSymbol) {
+ CNNTrainConfigurationSymbolChecker checker = new CNNTrainConfigurationSymbolChecker()
+ .addCoCo(new CheckCriticNetworkHasExactlyAOneDimensionalOutput())
+ .addCoCo(new CheckCriticNetworkInputs());
+ checker.checkAll(configurationSymbol);
+ }
}
\ No newline at end of file
diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainConfigurationSymbolChecker.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainConfigurationSymbolChecker.java
new file mode 100644
index 0000000000000000000000000000000000000000..c59d5a023ec0c35bfe44d2bf79bac78993191b04
--- /dev/null
+++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainConfigurationSymbolChecker.java
@@ -0,0 +1,44 @@
+/**
+ *
+ * ******************************************************************************
+ * MontiCAR Modeling Family, www.se-rwth.de
+ * Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
+ * All rights reserved.
+ *
+ * This project is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 3.0 of the License, or (at your option) any later version.
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this project. If not, see .
+ * *******************************************************************************
+ */
+package de.monticore.lang.monticar.cnntrain._cocos;
+
+import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Collects {@link CNNTrainConfigurationSymbolCoCo}s and runs them, in the order they were added, on a configuration symbol.
+ */
+public class CNNTrainConfigurationSymbolChecker {
+    private final List<CNNTrainConfigurationSymbolCoCo> cocos = new ArrayList<>();
+
+    public CNNTrainConfigurationSymbolChecker addCoCo(CNNTrainConfigurationSymbolCoCo coco) {
+        cocos.add(coco);
+        return this;
+    }
+
+    public void checkAll(ConfigurationSymbol configurationSymbol) {
+        for (CNNTrainConfigurationSymbolCoCo coco : cocos) {
+            coco.check(configurationSymbol);
+        }
+    }
+}
\ No newline at end of file
diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/annotations/TrainedArchitecture.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainConfigurationSymbolCoCo.java
similarity index 73%
rename from src/main/java/de/monticore/lang/monticar/cnntrain/annotations/TrainedArchitecture.java
rename to src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainConfigurationSymbolCoCo.java
index 0308a504cc97238158d931cab6b85489a05c926c..320f2475eb553defa9998e112ea6047cd60dbbcd 100644
--- a/src/main/java/de/monticore/lang/monticar/cnntrain/annotations/TrainedArchitecture.java
+++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainConfigurationSymbolCoCo.java
@@ -18,16 +18,13 @@
* License along with this project. If not, see .
* *******************************************************************************
*/
-package de.monticore.lang.monticar.cnntrain.annotations;
+package de.monticore.lang.monticar.cnntrain._cocos;
-import java.util.List;
-import java.util.Map;
-
-public interface TrainedArchitecture {
- public List getInputs();
- public List getOutputs();
- public Map> getDimensions();
- public Map getRanges();
- public Map getTypes();
+import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
+/**
+ *
+ */
+public interface CNNTrainConfigurationSymbolCoCo {
+ void check(ConfigurationSymbol configurationSymbol);
}
\ No newline at end of file
diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckDdpgRequiresCriticNetwork.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckActorCriticRequiresCriticNetwork.java
similarity index 76%
rename from src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckDdpgRequiresCriticNetwork.java
rename to src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckActorCriticRequiresCriticNetwork.java
index 6ca5412ff0ff5c20441cb8284d4c67d848be26e1..821c6fb0f42712b862044981e5fef829369bb222 100644
--- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckDdpgRequiresCriticNetwork.java
+++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckActorCriticRequiresCriticNetwork.java
@@ -26,24 +26,23 @@ import de.monticore.lang.monticar.cnntrain._ast.ASTRLAlgorithmEntry;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
-public class CheckDdpgRequiresCriticNetwork implements CNNTrainASTConfigurationCoCo {
+import static de.monticore.lang.monticar.cnntrain._cocos.ASTConfigurationUtils.hasCriticEntry;
+import static de.monticore.lang.monticar.cnntrain._cocos.ASTConfigurationUtils.isActorCriticAlgorithm;
+
+public class CheckActorCriticRequiresCriticNetwork implements CNNTrainASTConfigurationCoCo {
@Override
public void check(ASTConfiguration node) {
- boolean isDdpg = node.getEntriesList().stream()
- .anyMatch(e -> e instanceof ASTRLAlgorithmEntry
- && ((ASTRLAlgorithmEntry)e).getValue().isPresentDdpg());
- boolean hasCriticEntry = node.getEntriesList().stream()
- .anyMatch(e -> ((e instanceof ASTCriticNetworkEntry)
- && !((ASTCriticNetworkEntry)e).getValue().getNameList().isEmpty()));
+ boolean isActorCritic = isActorCriticAlgorithm(node);
+ boolean hasCriticEntry = hasCriticEntry(node);
- if (isDdpg && !hasCriticEntry) {
+ if (isActorCritic && !hasCriticEntry) {
ASTRLAlgorithmEntry algorithmEntry = node.getEntriesList().stream()
.filter(e -> e instanceof ASTRLAlgorithmEntry)
.map(e -> (ASTRLAlgorithmEntry)e)
.findFirst()
.orElseThrow(() -> new IllegalStateException("ASTRLAlgorithmEntry entry must be available"));
- Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING + " DDPG learning algorithm requires critc" +
+ Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING + " DDPG and TD3 learning algorithms require critic" +
" network entry", algorithmEntry.get_SourcePositionStart());
}
}
diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckContinuousRLAlgorithmUsesContinuousStrategy.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckContinuousRLAlgorithmUsesContinuousStrategy.java
index 5d0a68c42884d94ff26880a79a5700d7ef4f4cd9..03c097998d1fcca74aa8751662e117bf647aebd7 100644
--- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckContinuousRLAlgorithmUsesContinuousStrategy.java
+++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckContinuousRLAlgorithmUsesContinuousStrategy.java
@@ -22,19 +22,23 @@ package de.monticore.lang.monticar.cnntrain._cocos;
import com.google.common.collect.ImmutableSet;
import de.monticore.lang.monticar.cnntrain._ast.ASTConfiguration;
+import de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import java.util.Set;
+import static de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants.*;
+
public class CheckContinuousRLAlgorithmUsesContinuousStrategy implements CNNTrainASTConfigurationCoCo{
private static final Set CONTINUOUS_STRATEGIES = ImmutableSet.builder()
- .add("ornstein_uhlenbeck")
+ .add(STRATEGY_OU)
+ .add(STRATEGY_GAUSSIAN)
.build();
@Override
public void check(ASTConfiguration node) {
- if (ASTConfigurationUtils.isDdpgAlgorithm(node)
+ if (ASTConfigurationUtils.isContinuousAlgorithm(node)
&& ASTConfigurationUtils.hasStrategy(node)
&& ASTConfigurationUtils.getStrategyMethod(node).isPresent()) {
final String usedStrategy = ASTConfigurationUtils.getStrategyMethod(node).get();
diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckCriticNetworkHasExactlyAOneDimensionalOutput.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckCriticNetworkHasExactlyAOneDimensionalOutput.java
new file mode 100644
index 0000000000000000000000000000000000000000..5cc3fbb1d905f6eda23891a1cd53821b2f6186ad
--- /dev/null
+++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckCriticNetworkHasExactlyAOneDimensionalOutput.java
@@ -0,0 +1,64 @@
+/**
+ *
+ * ******************************************************************************
+ * MontiCAR Modeling Family, www.se-rwth.de
+ * Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
+ * All rights reserved.
+ *
+ * This project is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 3.0 of the License, or (at your option) any later version.
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this project. If not, see <http://www.gnu.org/licenses/>.
+ * *******************************************************************************
+ */
+package de.monticore.lang.monticar.cnntrain._cocos;
+
+import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
+import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
+import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
+import de.se_rwth.commons.logging.Log;
+
+import java.util.List;
+
+/**
+ * Context condition that requires the critic network of a configuration to have
+ * exactly one output whose dimensions are {1}. Configurations without a critic
+ * network are not checked.
+ */
+public class CheckCriticNetworkHasExactlyAOneDimensionalOutput implements CNNTrainConfigurationSymbolCoCo {
+
+    @Override
+    public void check(ConfigurationSymbol configurationSymbol) {
+        if (!configurationSymbol.getCriticNetwork().isPresent()) {
+            return;
+        }
+        NNArchitectureSymbol criticNetwork = configurationSymbol.getCriticNetwork().get();
+
+        // An empty output list would make get(0) below throw, so it is reported and the check stops.
+        if (criticNetwork.getOutputs().isEmpty()) {
+            Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR
+                    + " The critic network has no output", criticNetwork.getSourcePosition());
+            return;
+        }
+        if (criticNetwork.getOutputs().size() > 1) {
+            Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR
+                    + " The critic network has more than one outputs", criticNetwork.getSourcePosition());
+            return;
+        }
+        final String outputName = criticNetwork.getOutputs().get(0);
+        List<Integer> dimensions = criticNetwork.getDimensions().get(outputName);
+
+        // A missing dimension entry is reported like a wrong dimension instead of raising a NullPointerException.
+        if (dimensions == null || dimensions.size() != 1 || dimensions.get(0) != 1) {
+            Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR + " The output " + outputName
+                    + " of critic network is not a one-dimensional vector", configurationSymbol.getSourcePosition());
+        }
+    }
+}
diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckCriticNetworkInputs.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckCriticNetworkInputs.java
new file mode 100644
index 0000000000000000000000000000000000000000..e2d745d189ad8a32e0c1c64c4e6b9718044099a8
--- /dev/null
+++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckCriticNetworkInputs.java
@@ -0,0 +1,139 @@
+/**
+ *
+ * ******************************************************************************
+ * MontiCAR Modeling Family, www.se-rwth.de
+ * Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
+ * All rights reserved.
+ *
+ * This project is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 3.0 of the License, or (at your option) any later version.
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this project. If not, see <http://www.gnu.org/licenses/>.
+ * *******************************************************************************
+ */
+package de.monticore.lang.monticar.cnntrain._cocos;
+
+import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
+import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
+import de.monticore.lang.monticar.cnntrain.annotations.Range;
+import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
+import de.se_rwth.commons.logging.Log;
+
+import java.util.List;
+import java.util.Objects;
+import java.util.stream.Collectors;
+
+/**
+ * Context condition that compares the two inputs of the critic network with the
+ * architecture trained by the configuration: the first critic input must agree
+ * with the trained architecture's single input (the state) and the second one with
+ * its single output (the action) in dimensions, range and type. Configurations
+ * without a critic network are not checked.
+ */
+public class CheckCriticNetworkInputs implements CNNTrainConfigurationSymbolCoCo {
+
+    @Override
+    public void check(ConfigurationSymbol configurationSymbol) {
+        if (!configurationSymbol.getCriticNetwork().isPresent()) {
+            return;
+        }
+        if (!configurationSymbol.getTrainedArchitecture().isPresent()) {
+            Log.error("0" + ErrorCodes.MISSING_TRAINED_ARCHITECTURE
+                    + " No architecture found that is trained by this configuration.",
+                    configurationSymbol.getSourcePosition());
+            // Nothing to compare against; calling get() on the empty Optional would throw.
+            return;
+        }
+        NNArchitectureSymbol trainedArchitecture = configurationSymbol.getTrainedArchitecture().get();
+        NNArchitectureSymbol criticNetwork = configurationSymbol.getCriticNetwork().get();
+
+        if (trainedArchitecture.getInputs().size() != 1 || trainedArchitecture.getOutputs().size() != 1) {
+            Log.error("Malformed trained architecture");
+            return;
+        }
+        // The arity is validated before any input is read so that a critic with fewer
+        // than two inputs is reported instead of raising an IndexOutOfBoundsException.
+        if (criticNetwork.getInputs().size() != 2) {
+            Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR
+                    + " Number of critic network inputs is wrong. Critic network has two inputs,"
+                    + " first needs to be a state input and second needs to be the action input.",
+                    configurationSymbol.getSourcePosition());
+            return;
+        }
+
+        final String stateInput = trainedArchitecture.getInputs().get(0);
+        final String actionOutput = trainedArchitecture.getOutputs().get(0);
+        final List<Integer> stateDimensions = trainedArchitecture.getDimensions().get(stateInput);
+        final List<Integer> actionDimensions = trainedArchitecture.getDimensions().get(actionOutput);
+        final Range stateRange = trainedArchitecture.getRanges().get(stateInput);
+        final Range actionRange = trainedArchitecture.getRanges().get(actionOutput);
+        final String stateType = trainedArchitecture.getTypes().get(stateInput);
+        final String actionType = trainedArchitecture.getTypes().get(actionOutput);
+
+        final String criticInput1 = criticNetwork.getInputs().get(0);
+        final String criticInput2 = criticNetwork.getInputs().get(1);
+
+        if (!Objects.equals(criticNetwork.getDimensions().get(criticInput1), stateDimensions)) {
+            Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR
+                    + " Declared critic network is not a critic: Dimensions of first input of critic architecture must be"
+                    + " equal to state's dimensions "
+                    + formatDimensions(stateDimensions)
+                    + ".", configurationSymbol.getSourcePosition());
+        }
+
+        if (!Objects.equals(criticNetwork.getDimensions().get(criticInput2), actionDimensions)) {
+            Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR
+                    + " Declared critic network is not a critic: Dimensions of second input of critic architecture must be"
+                    + " equal to action's dimensions "
+                    + formatDimensions(actionDimensions)
+                    + ".", configurationSymbol.getSourcePosition());
+        }
+
+        if (!Objects.equals(criticNetwork.getRanges().get(criticInput1), stateRange)) {
+            Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR
+                    + " Declared critic network is not a critic: Ranges of first input of critic architecture must be"
+                    + " equal to state's ranges "
+                    + stateRange
+                    + ".", configurationSymbol.getSourcePosition());
+        }
+
+        if (!Objects.equals(criticNetwork.getRanges().get(criticInput2), actionRange)) {
+            Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR
+                    + " Declared critic network is not a critic: Ranges of second input of critic architecture must be"
+                    + " equal to action's ranges "
+                    + actionRange
+                    + ".", configurationSymbol.getSourcePosition());
+        }
+
+        if (!Objects.equals(criticNetwork.getTypes().get(criticInput1), stateType)) {
+            Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR
+                    + " Declared critic network is not a critic: Type of first input of critic architecture must be"
+                    + " equal to state's types "
+                    + stateType
+                    + ".", configurationSymbol.getSourcePosition());
+        }
+
+        if (!Objects.equals(criticNetwork.getTypes().get(criticInput2), actionType)) {
+            Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR
+                    + " Declared critic network is not a critic: Type of second input of critic architecture must be"
+                    + " equal to action's types "
+                    + actionType
+                    + ".", configurationSymbol.getSourcePosition());
+        }
+    }
+
+    /** Renders a dimension list as comma-separated values in braces; a missing list is rendered as "null". */
+    private static String formatDimensions(List<Integer> dimensions) {
+        if (dimensions == null) {
+            return "null";
+        }
+        return dimensions.stream().map(Object::toString).collect(Collectors.joining(",", "{", "}"));
+    }
+}
diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckDiscreteRLAlgorithmUsesDiscreteStrategy.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckDiscreteRLAlgorithmUsesDiscreteStrategy.java
index 64263bb35d62401abaf9c1d8e0d8179b28636e0d..edf32c7aa3602d006dcda39c29b33b64fa1e21dc 100644
--- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckDiscreteRLAlgorithmUsesDiscreteStrategy.java
+++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckDiscreteRLAlgorithmUsesDiscreteStrategy.java
@@ -22,6 +22,7 @@ package de.monticore.lang.monticar.cnntrain._cocos;
import com.google.common.collect.ImmutableSet;
import de.monticore.lang.monticar.cnntrain._ast.ASTConfiguration;
+import de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
@@ -29,7 +30,7 @@ import java.util.Set;
public class CheckDiscreteRLAlgorithmUsesDiscreteStrategy implements CNNTrainASTConfigurationCoCo{
private static final Set DISCRETE_STRATEGIES = ImmutableSet.builder()
- .add("epsgreedy")
+ .add(ConfigEntryNameConstants.STRATEGY_EPSGREEDY)
.build();
@Override
diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckFixTargetNetworkRequiresInterval.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckFixTargetNetworkRequiresInterval.java
index 221c38ed7f15a296fbfebf4b9bdf8062ad6e74d8..435cf74496590866bdd4c9c0f5592507c168fc1f 100644
--- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckFixTargetNetworkRequiresInterval.java
+++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckFixTargetNetworkRequiresInterval.java
@@ -23,6 +23,7 @@ package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._ast.*;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.EntrySymbol;
+import de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
@@ -33,8 +34,6 @@ import java.util.Map;
*
*/
public class CheckFixTargetNetworkRequiresInterval implements CNNTrainASTConfigurationCoCo {
- private static final String PARAMETER_USE_FIX_TARGET_NETWORK = "use_fix_target_network";
- private static final String PARAMETER_TARGET_NETWORK_UPDATE_INTERVAL = "target_network_update_interval";
@Override
public void check(ASTConfiguration node) {
@@ -50,8 +49,8 @@ public class CheckFixTargetNetworkRequiresInterval implements CNNTrainASTConfigu
.map(e -> (ASTUseFixTargetNetworkEntry)e)
.findFirst()
.orElseThrow(() -> new IllegalStateException("ASTUseFixTargetNetwork entry must be available"));
- Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING + " Parameter " + Boolean.toString(useFixTargetNetwork)
- + " requires parameter " + PARAMETER_TARGET_NETWORK_UPDATE_INTERVAL,
+ Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING + " Parameter " + ConfigEntryNameConstants.USE_FIX_TARGET_NETWORK
+ + " requires parameter " + ConfigEntryNameConstants.TARGET_NETWORK_UPDATE_INTERVAL,
useFixTargetNetworkEntry.get_SourcePositionStart());
} else if (!useFixTargetNetwork && hasTargetNetworkUpdateInterval) {
ASTTargetNetworkUpdateIntervalEntry targetNetworkUpdateIntervalEntry = node.getEntriesList().stream()
@@ -62,7 +61,7 @@ public class CheckFixTargetNetworkRequiresInterval implements CNNTrainASTConfigu
() -> new IllegalStateException("ASTTargetNetworkUpdateInterval entry must be available"));
Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING + " Parameter "
+ targetNetworkUpdateIntervalEntry.getName() + " requires that parameter "
- + PARAMETER_USE_FIX_TARGET_NETWORK + " to be true.",
+ + ConfigEntryNameConstants.USE_FIX_TARGET_NETWORK + " to be true.",
targetNetworkUpdateIntervalEntry.get_SourcePositionStart());
}
}
diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckOUParameterDimensionEqualsActionDimension.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckOUParameterDimensionEqualsActionDimension.java
new file mode 100644
index 0000000000000000000000000000000000000000..c4661656c7ae999280c99d1bc524bd414d4e4c34
--- /dev/null
+++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckOUParameterDimensionEqualsActionDimension.java
@@ -0,0 +1,79 @@
+/**
+ *
+ * ******************************************************************************
+ * MontiCAR Modeling Family, www.se-rwth.de
+ * Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
+ * All rights reserved.
+ *
+ * This project is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 3.0 of the License, or (at your option) any later version.
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this project. If not, see <http://www.gnu.org/licenses/>.
+ * *******************************************************************************
+ */
+package de.monticore.lang.monticar.cnntrain._cocos;
+
+import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
+import de.monticore.lang.monticar.cnntrain._symboltable.MultiParamValueSymbol;
+import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
+import de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants;
+import de.se_rwth.commons.logging.Log;
+
+import java.util.Collection;
+import java.util.List;
+
+import static de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants.*;
+
+/**
+ * For reinforcement configurations using the STRATEGY_OU strategy: checks that each given STRATEGY_OU_MU / STRATEGY_OU_SIGMA / STRATEGY_OU_THETA vector has as many entries as the dimension list of the trained architecture's first output.
+ */
+public class CheckOUParameterDimensionEqualsActionDimension implements CNNTrainConfigurationSymbolCoCo {
+    @Override
+    public void check(final ConfigurationSymbol configurationSymbol) {
+        if (configurationSymbol.getTrainedArchitecture().isPresent()
+                && configurationSymbol.isReinforcementLearningMethod() && configurationSymbol.getEntryMap().containsKey(STRATEGY) // strategy is optional: skip when absent
+                && configurationSymbol.getEntry(STRATEGY).getValue().getValue().equals(STRATEGY_OU)) {
+            final MultiParamValueSymbol strategyParameters
+                    = (MultiParamValueSymbol)configurationSymbol.getEntry(STRATEGY).getValue();
+            final NNArchitectureSymbol architectureSymbol = configurationSymbol.getTrainedArchitecture().get();
+            final String outputNameOfTrainedArchitecture = architectureSymbol.getOutputs().get(0);
+            final int actionDimension = architectureSymbol.getDimensions().get(outputNameOfTrainedArchitecture).size(); // NOTE(review): this is the length of the dimension list, not a product of its entries - confirm intended
+
+            if (strategyParameters.hasParameter(STRATEGY_OU_MU)) {
+                logIfDimensionIsUnequal(configurationSymbol, strategyParameters, outputNameOfTrainedArchitecture,
+                        actionDimension, STRATEGY_OU_MU);
+            }
+
+            if (strategyParameters.hasParameter(STRATEGY_OU_SIGMA)) {
+                logIfDimensionIsUnequal(configurationSymbol, strategyParameters, outputNameOfTrainedArchitecture,
+                        actionDimension, STRATEGY_OU_SIGMA);
+            }
+
+            if (strategyParameters.hasParameter(STRATEGY_OU_THETA)) {
+                logIfDimensionIsUnequal(configurationSymbol, strategyParameters, outputNameOfTrainedArchitecture,
+                        actionDimension, STRATEGY_OU_THETA);
+            }
+        }
+    }
+
+    private void logIfDimensionIsUnequal(ConfigurationSymbol configurationSymbol, // logs an error when the named OU vector's size differs from actionDimension
+                                         MultiParamValueSymbol strategyParameters,
+                                         String outputNameOfTrainedArchitecture,
+                                         int actionDimension,
+                                         String ouParameterName) {
+        final int ouParameterDimension = ((Collection<?>) strategyParameters.getParameter(ouParameterName)).size();
+        if (ouParameterDimension != actionDimension) {
+            Log.error("Vector parameter " + ouParameterName + " of parameter " + STRATEGY_OU + " must have"
+                    + " the same dimensions as the action dimension of output "
+                    + outputNameOfTrainedArchitecture + " which is " + actionDimension,
+                    configurationSymbol.getSourcePosition());
+        }
+    }
+}
diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckReinforcementRequiresEnvironment.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckReinforcementRequiresEnvironment.java
index 19672e71afd1034c3819aaeb0dcfc213a8ccd007..29c5f0631b34c3c7eb34ff78029a5d066f740eb9 100644
--- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckReinforcementRequiresEnvironment.java
+++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckReinforcementRequiresEnvironment.java
@@ -25,6 +25,7 @@ import de.monticore.lang.monticar.cnntrain._ast.ASTEnvironmentEntry;
import de.monticore.lang.monticar.cnntrain._ast.ASTLearningMethodEntry;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.LearningMethod;
+import de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
@@ -32,8 +33,6 @@ import de.se_rwth.commons.logging.Log;
*
*/
public class CheckReinforcementRequiresEnvironment implements CNNTrainASTConfigurationCoCo {
- private static final String PARAMETER_ENVIRONMENT = "environment";
-
@Override
public void check(ASTConfiguration node) {
boolean isReinforcementLearning = ASTConfigurationUtils.isReinforcementLearning(node);
@@ -41,7 +40,7 @@ public class CheckReinforcementRequiresEnvironment implements CNNTrainASTConfigu
if (isReinforcementLearning && !hasEnvironment) {
Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING + " The required parameter "
- + PARAMETER_ENVIRONMENT + " is missing");
+ + ConfigEntryNameConstants.ENVIRONMENT + " is missing");
}
}
}
\ No newline at end of file
diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckRlAlgorithmParameter.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckRlAlgorithmParameter.java
index 85d0bd1f64d18616808a7103db1f4b5bc91a439b..4b70fe9388b79dd61f0f9c654c96feb1504355b5 100644
--- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckRlAlgorithmParameter.java
+++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckRlAlgorithmParameter.java
@@ -20,6 +20,7 @@
*/
package de.monticore.lang.monticar.cnntrain._cocos;
+import de.monticore.lang.monticar.cnntrain._ast.ASTConfiguration;
import de.monticore.lang.monticar.cnntrain._ast.ASTEntry;
import de.monticore.lang.monticar.cnntrain._ast.ASTRLAlgorithmEntry;
import de.monticore.lang.monticar.cnntrain._symboltable.RLAlgorithm;
@@ -29,61 +30,58 @@ import de.se_rwth.commons.logging.Log;
public class CheckRlAlgorithmParameter implements CNNTrainASTEntryCoCo {
private final ParameterAlgorithmMapping parameterAlgorithmMapping;
- boolean algorithmKnown;
+ private boolean isDqn = true;
+ private boolean isDdpg = true;
+ private boolean isTd3 = true;
+
RLAlgorithm algorithm;
public CheckRlAlgorithmParameter() {
parameterAlgorithmMapping = new ParameterAlgorithmMapping();
- algorithmKnown = false;
- algorithm = null;
}
@Override
public void check(ASTEntry node) {
- final boolean isDdpgParameter = parameterAlgorithmMapping.isDdpgParameter(node.getClass());
- final boolean isDqnParameter = parameterAlgorithmMapping.isDqnParameter(node.getClass());
-
+ if (!parameterAlgorithmMapping.isReinforcementLearningParameter(node.getClass())) {
+ return;
+ }
if (node instanceof ASTRLAlgorithmEntry) {
ASTRLAlgorithmEntry algorithmEntry = (ASTRLAlgorithmEntry)node;
if (algorithmEntry.getValue().isPresentDdpg()) {
- setAlgorithmToDdpg(node);
+ logWrongParameterIfCheckFails(isDdpg, node);
+ isTd3 = false;
+ isDqn = false;
+ } else if(algorithmEntry.getValue().isPresentTdThree()) {
+ logWrongParameterIfCheckFails(isTd3, node);
+ isDdpg = false;
+ isDqn = false;
} else {
- setAlgorithmToDqn(node);
+ logWrongParameterIfCheckFails(isDqn, node);
+ isDdpg = false;
+ isTd3 = false;
}
} else {
- if (isDdpgParameter && !isDqnParameter) {
- setAlgorithmToDdpg(node);
- } else if (!isDdpgParameter && isDqnParameter) {
- setAlgorithmToDqn(node);
+ final boolean isDdpgParameter = parameterAlgorithmMapping.isDdpgParameter(node.getClass());
+ final boolean isDqnParameter = parameterAlgorithmMapping.isDqnParameter(node.getClass());
+ final boolean isTd3Parameter = parameterAlgorithmMapping.isTd3Parameter(node.getClass());
+ if (!isDdpgParameter) {
+ isDdpg = false;
+ }
+ if (!isTd3Parameter) {
+ isTd3 = false;
+ }
+ if (!isDqnParameter) {
+ isDqn = false;
}
}
+ logWrongParameterIfCheckFails(isDqn || isTd3 || isDdpg, node);
}
- private void logErrorIfAlgorithmIsDqn(final ASTEntry node) {
- if (algorithmKnown && algorithm.equals(RLAlgorithm.DQN)) {
- Log.error("0" + ErrorCodes.UNSUPPORTED_PARAMETER
- + " DDPG Parameter " + node.getName() + " used but algorithm is " + algorithm + ".",
- node.get_SourcePositionStart());
- }
- }
-
- private void setAlgorithmToDdpg(final ASTEntry node) {
- logErrorIfAlgorithmIsDqn(node);
- algorithmKnown = true;
- algorithm = RLAlgorithm.DDPG;
- }
-
- private void setAlgorithmToDqn(final ASTEntry node) {
- logErrorIfAlgorithmIsDdpg(node);
- algorithmKnown = true;
- algorithm = RLAlgorithm.DQN;
- }
-
- private void logErrorIfAlgorithmIsDdpg(final ASTEntry node) {
- if (algorithmKnown && algorithm.equals(RLAlgorithm.DDPG)) {
+ private void logWrongParameterIfCheckFails(final boolean condition, final ASTEntry node) {
+ if (!condition) {
Log.error("0" + ErrorCodes.UNSUPPORTED_PARAMETER
- + " DQN Parameter " + node.getName() + " used but algorithm is " + algorithm + ".",
+ + "Parameter " + node.getName() + " used but parameter is not for chosen algorithm.",
node.get_SourcePositionStart());
}
}
diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckTrainedRlNetworkHasExactlyOneInput.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckTrainedRlNetworkHasExactlyOneInput.java
new file mode 100644
index 0000000000000000000000000000000000000000..a0e57261aeb48e932c08a12cc1357a264d327fda
--- /dev/null
+++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckTrainedRlNetworkHasExactlyOneInput.java
@@ -0,0 +1,48 @@
+/**
+ *
+ * ******************************************************************************
+ * MontiCAR Modeling Family, www.se-rwth.de
+ * Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
+ * All rights reserved.
+ *
+ * This project is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 3.0 of the License, or (at your option) any later version.
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this project. If not, see <http://www.gnu.org/licenses/>.
+ * *******************************************************************************
+ */
+package de.monticore.lang.monticar.cnntrain._cocos;
+
+import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
+import de.monticore.lang.monticar.cnntrain._symboltable.RLAlgorithm;
+import de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants;
+import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
+import de.se_rwth.commons.logging.Log;
+
+/**
+ *
+ */
+public class CheckTrainedRlNetworkHasExactlyOneInput implements CNNTrainConfigurationSymbolCoCo {
+ @Override
+ public void check(ConfigurationSymbol configurationSymbol) {
+ if (configurationSymbol.isReinforcementLearningMethod()
+ && configurationSymbol.getTrainedArchitecture().isPresent()) {
+ final int numberOfInputs = configurationSymbol.getTrainedArchitecture().get().getInputs().size();
+ if (numberOfInputs != 1) {
+ final String networkName
+ = configurationSymbol.getEntry(ConfigEntryNameConstants.RL_ALGORITHM).getValue().getValue()
+ .equals(RLAlgorithm.DQN) ? "Q-Network" : "Actor-Network";
+ Log.error("x0" + ErrorCodes.TRAINED_ARCHITECTURE_ERROR
+ + networkName + " " +configurationSymbol.getTrainedArchitecture().get().getName()
+ +" has " + numberOfInputs + " inputs but 1 is only allowed.", configurationSymbol.getSourcePosition());
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckTrainedRlNetworkHasExactlyOneOutput.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckTrainedRlNetworkHasExactlyOneOutput.java
new file mode 100644
index 0000000000000000000000000000000000000000..a0d7487abf254ad982398c5d53665f468f7c396b
--- /dev/null
+++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckTrainedRlNetworkHasExactlyOneOutput.java
@@ -0,0 +1,48 @@
+/**
+ *
+ * ******************************************************************************
+ * MontiCAR Modeling Family, www.se-rwth.de
+ * Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
+ * All rights reserved.
+ *
+ * This project is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 3.0 of the License, or (at your option) any later version.
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this project. If not, see <http://www.gnu.org/licenses/>.
+ * *******************************************************************************
+ */
+package de.monticore.lang.monticar.cnntrain._cocos;
+
+import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
+import de.monticore.lang.monticar.cnntrain._symboltable.RLAlgorithm;
+import de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants;
+import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
+import de.se_rwth.commons.logging.Log;
+
+/**
+ *
+ */
+public class CheckTrainedRlNetworkHasExactlyOneOutput implements CNNTrainConfigurationSymbolCoCo {
+ @Override
+ public void check(final ConfigurationSymbol configurationSymbol) {
+ if (configurationSymbol.isReinforcementLearningMethod()
+ && configurationSymbol.getTrainedArchitecture().isPresent()) {
+ final int numberOfOutputs = configurationSymbol.getTrainedArchitecture().get().getOutputs().size();
+ if (numberOfOutputs != 1) {
+ final String networkName
+ = configurationSymbol.getEntry(ConfigEntryNameConstants.RL_ALGORITHM).getValue().getValue()
+ .equals(RLAlgorithm.DQN) ? "Q-Network" : "Actor-Network";
+ Log.error("x0" + ErrorCodes.TRAINED_ARCHITECTURE_ERROR
+ + networkName + " " +configurationSymbol.getTrainedArchitecture().get().getName()
+ +" has " + numberOfOutputs + " outputs but 1 is only allowed.", configurationSymbol.getSourcePosition());
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java
index 6443cf45709f0efedf85c397aff9f6d3b3a4a6cc..e05aadbbdaed137b0ee1bf5d7b22f1c3d456203e 100644
--- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java
+++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java
@@ -79,6 +79,7 @@ class ParameterAlgorithmMapping {
ASTMinEpsilonEntry.class,
ASTEpsilonDecayEntry.class,
ASTEpsilonDecayMethodEntry.class,
+ ASTEpsilonDecayPerStepEntry.class,
ASTNumEpisodesEntry.class,
ASTRosEnvironmentActionTopicEntry.class,
ASTRosEnvironmentStateTopicEntry.class,
@@ -104,7 +105,21 @@ class ParameterAlgorithmMapping {
ASTCriticOptimizerEntry.class,
ASTStrategyOUMu.class,
ASTStrategyOUTheta.class,
- ASTStrategyOUSigma.class
+ ASTStrategyOUSigma.class,
+ ASTStrategyGaussianNoiseVarianceEntry.class
+ );
+
+ private static final List EXCLUSIVE_TD3_PARAMETERS = Lists.newArrayList(
+ ASTCriticNetworkEntry.class,
+ ASTSoftTargetUpdateRateEntry.class,
+ ASTCriticOptimizerEntry.class,
+ ASTStrategyOUMu.class,
+ ASTStrategyOUTheta.class,
+ ASTStrategyOUSigma.class,
+ ASTPolicyNoiseEntry.class,
+ ASTNoiseClipEntry.class,
+ ASTPolicyDelayEntry.class,
+ ASTStrategyGaussianNoiseVarianceEntry.class
);
ParameterAlgorithmMapping() {
@@ -115,7 +130,8 @@ class ParameterAlgorithmMapping {
return GENERAL_PARAMETERS.contains(entryClazz)
|| GENERAL_REINFORCEMENT_PARAMETERS.contains(entryClazz)
|| EXCLUSIVE_DQN_PARAMETERS.contains(entryClazz)
- || EXCLUSIVE_DDPG_PARAMETERS.contains(entryClazz);
+ || EXCLUSIVE_DDPG_PARAMETERS.contains(entryClazz)
+ || EXCLUSIVE_TD3_PARAMETERS.contains(entryClazz);
}
boolean isSupervisedLearningParameter(Class extends ASTEntry> entryClazz) {
@@ -135,12 +151,19 @@ class ParameterAlgorithmMapping {
|| EXCLUSIVE_DDPG_PARAMETERS.contains(entryClazz);
}
+ boolean isTd3Parameter(Class<? extends ASTEntry> entryClazz) { // true when the entry class is in the general, general-reinforcement or TD3 parameter list
+ return GENERAL_PARAMETERS.contains(entryClazz)
+ || GENERAL_REINFORCEMENT_PARAMETERS.contains(entryClazz)
+ || EXCLUSIVE_TD3_PARAMETERS.contains(entryClazz);
+ }
+
List getAllReinforcementParameters() {
return ImmutableList. builder()
.addAll(GENERAL_PARAMETERS)
.addAll(GENERAL_REINFORCEMENT_PARAMETERS)
.addAll(EXCLUSIVE_DQN_PARAMETERS)
.addAll(EXCLUSIVE_DDPG_PARAMETERS)
+ .addAll(EXCLUSIVE_TD3_PARAMETERS)
.build();
}
diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java
index f38c711dd9e4e815e7461131d9769ba1e82afc95..25bc4b9a5a8ecc0a11a37fbfaba9b2bd218b2e34 100644
--- a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java
+++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java
@@ -351,6 +351,8 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
if (node.getValue().isPresentDdpg()) {
value.setValue(RLAlgorithm.DDPG);
+ } else if(node.getValue().isPresentTdThree()) {
+ value.setValue(RLAlgorithm.TD3);
} else {
value.setValue(RLAlgorithm.DQN);
}
@@ -514,6 +516,38 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP {
configuration.getEntryMap().put(node.getName(), entry);
}
+ @Override
+ public void visit(ASTPolicyNoiseEntry node) {
+ EntrySymbol entry = new EntrySymbol(node.getName());
+ entry.setValue(getValueSymbolForDouble(node.getValue()));
+ addToScopeAndLinkWithNode(entry, node);
+ configuration.getEntryMap().put(node.getName(), entry);
+ }
+
+ @Override
+ public void visit(ASTNoiseClipEntry node) {
+ EntrySymbol entry = new EntrySymbol(node.getName());
+ entry.setValue(getValueSymbolForDouble(node.getValue()));
+ addToScopeAndLinkWithNode(entry, node);
+ configuration.getEntryMap().put(node.getName(), entry);
+ }
+
+ @Override
+ public void visit(ASTStrategyGaussianNoiseVarianceEntry node) {
+ EntrySymbol entry = new EntrySymbol(node.getName());
+ entry.setValue(getValueSymbolForDouble(node.getValue()));
+ addToScopeAndLinkWithNode(entry, node);
+ configuration.getEntryMap().put(node.getName(), entry);
+ }
+
+ @Override
+ public void visit(ASTPolicyDelayEntry node) {
+ EntrySymbol entry = new EntrySymbol(node.getName());
+ entry.setValue(getValueSymbolForInteger(node.getValue()));
+ addToScopeAndLinkWithNode(entry, node);
+ configuration.getEntryMap().put(node.getName(), entry);
+ }
+
private void processMultiParamConfigVisit(ASTMultiParamConfigEntry node, Object value) {
EntrySymbol entry = new EntrySymbol(node.getName());
MultiParamValueSymbol valueSymbol = new MultiParamValueSymbol();
diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java
index 1a7e8829bb62df21c92bcaee383601bf2bda5089..6e41d6f0143affda053ec1af6641398a92c52d67 100644
--- a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java
+++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java
@@ -20,13 +20,12 @@
*/
package de.monticore.lang.monticar.cnntrain._symboltable;
-import com.google.common.collect.Lists;
-import de.monticore.lang.monticar.cnntrain.annotations.TrainedArchitecture;
import de.monticore.symboltable.CommonScopeSpanningSymbol;
-import javax.swing.text.html.Option;
import java.util.*;
+import static de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants.*;
+
public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
private Map entryMap = new HashMap<>();
@@ -34,7 +33,8 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
private OptimizerSymbol criticOptimizer;
private LossSymbol loss;
private RewardFunctionSymbol rlRewardFunctionSymbol;
- private TrainedArchitecture trainedArchitecture;
+ private NNArchitectureSymbol trainedArchitecture;
+ private NNArchitectureSymbol criticNetwork;
public static final ConfigurationSymbolKind KIND = new ConfigurationSymbolKind();
@@ -76,14 +76,22 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
return Optional.ofNullable(this.rlRewardFunctionSymbol);
}
- public Optional<TrainedArchitecture> getTrainedArchitecture() {
+ public Optional<NNArchitectureSymbol> getTrainedArchitecture() { // empty when no trained architecture has been set
return Optional.ofNullable(trainedArchitecture);
}
- public void setTrainedArchitecture(TrainedArchitecture trainedArchitecture) {
+ public void setTrainedArchitecture(NNArchitectureSymbol trainedArchitecture) {
this.trainedArchitecture = trainedArchitecture;
}
+ public Optional getCriticNetwork() {
+ return Optional.ofNullable(criticNetwork);
+ }
+
+ public void setCriticNetwork(NNArchitectureSymbol criticNetwork) {
+ this.criticNetwork = criticNetwork;
+ }
+
public Map getEntryMap() {
return entryMap;
}
@@ -93,7 +101,25 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
}
public LearningMethod getLearningMethod() {
- return this.entryMap.containsKey("learning_method")
- ? (LearningMethod)this.entryMap.get("learning_method").getValue().getValue() : LearningMethod.SUPERVISED;
+ return this.entryMap.containsKey(LEARNING_METHOD)
+ ? (LearningMethod)this.entryMap.get(LEARNING_METHOD).getValue().getValue() : LearningMethod.SUPERVISED;
+ }
+
+ public boolean isReinforcementLearningMethod() {
+ return getLearningMethod().equals(LearningMethod.REINFORCEMENT);
+ }
+
+ public boolean hasCritic() {
+ return getEntryMap().containsKey(CRITIC);
+ }
+
+ public Optional getCriticName() {
+ if (!hasCritic()) {
+ return Optional.empty();
+ }
+
+ final Object criticNameValue = getEntry(CRITIC).getValue().getValue();
+ assert criticNameValue instanceof String;
+ return Optional.of((String)criticNameValue);
}
}
\ No newline at end of file
diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/MultiParamValueSymbol.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/MultiParamValueSymbol.java
index bff39c9627da7125979ef5e88a04cb1da898d0a1..5f152d820ae108e0971be6369c21e6ee75eb7901 100644
--- a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/MultiParamValueSymbol.java
+++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/MultiParamValueSymbol.java
@@ -44,6 +44,10 @@ public class MultiParamValueSymbol extends ValueSymbol {
return parameters.get(parameterName);
}
+ public boolean hasParameter(final String parameterName) {
+ return parameters.containsKey(parameterName);
+ }
+
public void addParameter(final String parameterName, final Object value) {
parameters.put(parameterName, value);
}
diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/NNArchitectureSymbol.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/NNArchitectureSymbol.java
new file mode 100644
index 0000000000000000000000000000000000000000..c606b90618169e331647a66c7a43bde8985a9bb8
--- /dev/null
+++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/NNArchitectureSymbol.java
@@ -0,0 +1,41 @@
+/**
+ *
+ * ******************************************************************************
+ * MontiCAR Modeling Family, www.se-rwth.de
+ * Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
+ * All rights reserved.
+ *
+ * This project is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 3.0 of the License, or (at your option) any later version.
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this project. If not, see .
+ * *******************************************************************************
+ */
+package de.monticore.lang.monticar.cnntrain._symboltable;
+
+import de.monticore.lang.monticar.cnntrain.annotations.Range;
+import de.monticore.symboltable.SymbolKind;
+
+import java.util.List;
+import java.util.Map;
+
+public abstract class NNArchitectureSymbol extends de.monticore.symboltable.CommonSymbol { // abstract symbol exposing a network architecture's inputs, outputs, dimensions, ranges and types
+    public static final NNArchitectureSymbolKind KIND = NNArchitectureSymbolKind.INSTANCE;
+
+    public NNArchitectureSymbol(String name) {
+        super(name, KIND);
+    }
+
+    abstract public List<String> getInputs(); // NOTE(review): element type reconstructed by analogy with getOutputs() - confirm
+    abstract public List<String> getOutputs(); // callers read the first element as an output name
+    abstract public Map<String, List<Integer>> getDimensions(); // looked up by output name; NOTE(review): List<Integer> value type is reconstructed - confirm
+    abstract public Map<String, Range> getRanges(); // NOTE(review): String key type is reconstructed - confirm
+    abstract public Map<String, String> getTypes(); // NOTE(review): both type arguments are reconstructed - confirm
+}
\ No newline at end of file
diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/NNArchitectureSymbolKind.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/NNArchitectureSymbolKind.java
new file mode 100644
index 0000000000000000000000000000000000000000..f06e869bce8a7c78aad1bec3f697c65864f61099
--- /dev/null
+++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/NNArchitectureSymbolKind.java
@@ -0,0 +1,38 @@
+/**
+ *
+ * ******************************************************************************
+ * MontiCAR Modeling Family, www.se-rwth.de
+ * Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
+ * All rights reserved.
+ *
+ * This project is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 3.0 of the License, or (at your option) any later version.
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this project. If not, see .
+ * *******************************************************************************
+ */
+package de.monticore.lang.monticar.cnntrain._symboltable;
+
+import de.monticore.symboltable.SymbolKind;
+
+public class NNArchitectureSymbolKind implements SymbolKind { // symbol kind used by NNArchitectureSymbol
+ public static final NNArchitectureSymbolKind INSTANCE = new NNArchitectureSymbolKind(); // shared instance, referenced as NNArchitectureSymbol.KIND
+ private static final String NAME = "de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbolKind"; // kind name returned by getName()
+
+ @Override
+ public String getName() {
+ return NAME;
+ }
+
+ @Override
+ public boolean isKindOf(SymbolKind kind) {
+ return NAME.equals(kind.getName()) || SymbolKind.super.isKindOf(kind); // same kind name, otherwise defer to SymbolKind's default check
+ }
+}
\ No newline at end of file
diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/RLAlgorithm.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/RLAlgorithm.java
index f0e38ffffc5f7cebde523b53daaeb0243bbc179b..014649bed5c9b9f6fb18387d1c334706c292e197 100644
--- a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/RLAlgorithm.java
+++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/RLAlgorithm.java
@@ -32,5 +32,11 @@ public enum RLAlgorithm {
public String toString() {
return "ddpg";
}
+ },
+ TD3 {
+ @Override
+ public String toString() {
+ return "td3";
+ }
}
-}
+}
\ No newline at end of file
diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/annotations/Range.java b/src/main/java/de/monticore/lang/monticar/cnntrain/annotations/Range.java
index bbbb3fc7041722519c8f2bef8d99b3ceb2dc0767..06a0582b63ba93eb8b8e30fe4244316dd67102ed 100644
--- a/src/main/java/de/monticore/lang/monticar/cnntrain/annotations/Range.java
+++ b/src/main/java/de/monticore/lang/monticar/cnntrain/annotations/Range.java
@@ -20,6 +20,7 @@
*/
package de.monticore.lang.monticar.cnntrain.annotations;
+import java.util.Objects;
import java.util.Optional;
public class Range {
@@ -66,4 +67,28 @@ public class Range {
public static Range withLowerInfinityLimit(double upperLimit) {
return new Range(true, false, null, upperLimit);
}
+
+ @Override
+ public String toString() {
+ final String lowerLimit = isLowerLimitInfinity() || !getLowerLimit().isPresent() ? "-oo" : getLowerLimit().get().toString();
+ final String upperLimit = isUpperLimitInfinity() || !getUpperLimit().isPresent() ? "oo" : getUpperLimit().get().toString();
+
+ return "[" + lowerLimit + ", " + upperLimit + "]";
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (!(o instanceof Range)) return false;
+ Range range = (Range) o;
+ return lowerLimitIsInfinity == range.lowerLimitIsInfinity &&
+ upperLimitIsInfinity == range.upperLimitIsInfinity &&
+ Objects.equals(lowerLimit, range.lowerLimit) &&
+ Objects.equals(upperLimit, range.upperLimit);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(lowerLimitIsInfinity, upperLimitIsInfinity, lowerLimit, upperLimit);
+ }
}
diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ConfigEntryNameConstants.java b/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ConfigEntryNameConstants.java
new file mode 100644
index 0000000000000000000000000000000000000000..58c2274904796ac0f3defc107e3f645d701a107b
--- /dev/null
+++ b/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ConfigEntryNameConstants.java
@@ -0,0 +1,63 @@
+/**
+ *
+ * ******************************************************************************
+ * MontiCAR Modeling Family, www.se-rwth.de
+ * Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
+ * All rights reserved.
+ *
+ * This project is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 3.0 of the License, or (at your option) any later version.
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this project. If not, see <http://www.gnu.org/licenses/>.
+ * *******************************************************************************
+ */
+package de.monticore.lang.monticar.cnntrain.helper;
+
+/**
+ * String constants for the names of CNNTrain configuration entries, plus some of the literal values those entries accept.
+ */
+public class ConfigEntryNameConstants {
+ public static final String LEARNING_METHOD = "learning_method";
+ public static final String NUM_EPISODES = "num_episodes";
+ public static final String DISCOUNT_FACTOR = "discount_factor";
+ public static final String NUM_MAX_STEPS = "num_max_steps";
+ public static final String TARGET_SCORE = "target_score";
+ public static final String TRAINING_INTERVAL = "training_interval";
+ public static final String USE_FIX_TARGET_NETWORK = "use_fix_target_network";
+ public static final String TARGET_NETWORK_UPDATE_INTERVAL = "target_network_update_interval";
+ public static final String SNAPSHOT_INTERVAL = "snapshot_interval";
+ public static final String AGENT_NAME = "agent_name";
+ public static final String USE_DOUBLE_DQN = "use_double_dqn";
+ public static final String LOSS = "loss";
+ public static final String RL_ALGORITHM = "rl_algorithm";
+ public static final String REPLAY_MEMORY = "replay_memory";
+ public static final String ENVIRONMENT = "environment";
+ public static final String START_TRAINING_AT = "start_training_at";
+ public static final String SOFT_TARGET_UPDATE_RATE = "soft_target_update_rate";
+ public static final String EVALUATION_SAMPLES = "evaluation_samples";
+ public static final String POLICY_NOISE = "policy_noise";
+ public static final String NOISE_CLIP = "noise_clip";
+ public static final String POLICY_DELAY = "policy_delay";
+
+ public static final String ENVIRONMENT_REWARD_TOPIC = "reward_topic"; // presumably a sub-entry of the ros_interface environment; confirm against the grammar
+ public static final String ENVIRONMENT_ROS = "ros_interface"; // a value of the "environment" entry
+ public static final String ENVIRONMENT_GYM = "gym"; // a value of the "environment" entry
+
+ public static final String STRATEGY = "strategy";
+ public static final String STRATEGY_OU = "ornstein_uhlenbeck"; // a value of the "strategy" entry
+ public static final String STRATEGY_OU_MU = "mu"; // mu/theta/sigma: presumably sub-entries of the ornstein_uhlenbeck strategy; TODO confirm
+ public static final String STRATEGY_OU_THETA = "theta";
+ public static final String STRATEGY_OU_SIGMA = "sigma";
+ public static final String STRATEGY_GAUSSIAN = "gaussian"; // a value of the "strategy" entry
+ public static final String STRATEGY_EPSGREEDY = "epsgreedy"; // a value of the "strategy" entry
+ public static final String STRATEGY_EPSDECAY = "epsdecay";
+
+ public static final String CRITIC = "critic";
+}
diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ErrorCodes.java b/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ErrorCodes.java
index 6aa9e5074f0324090812cb222078c1c1e110704e..fa7c21bb72b3d4183eed524a771f9cdd33dda60d 100644
--- a/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ErrorCodes.java
+++ b/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ErrorCodes.java
@@ -31,4 +31,7 @@ public class ErrorCodes {
public static final String REQUIRED_PARAMETER_MISSING = "xC8856";
public static final String STRATEGY_NOT_APPLICABLE = "xC8857";
public static final String CONTRADICTING_PARAMETERS = "xC8858";
+ public static final String CRITIC_NETWORK_ERROR = "xC7100";
+ public static final String MISSING_TRAINED_ARCHITECTURE = "xC7101";
+ public static final String TRAINED_ARCHITECTURE_ERROR = "xC7102";
}
\ No newline at end of file
diff --git a/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java b/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java
index cd7d8f63a3d25e560bc3847c47c6930741365212..059932d9bad59d84c4577a649723c21313a49732 100644
--- a/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java
+++ b/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java
@@ -34,64 +34,164 @@ public class AllCoCoTest extends AbstractCoCoTest{
}
@Test
- public void testValidCoCos() throws IOException {
+ public void testValidSimpleConfig1() {
checkValid("valid_tests","SimpleConfig1");
+ }
+
+ @Test
+ public void testValidSimpleConfig2() {
checkValid("valid_tests","SimpleConfig2");
+ }
+
+ @Test
+ public void testValidFullConfig() {
checkValid("valid_tests","FullConfig");
+ }
+
+ @Test
+ public void testValidFullConfig2() {
checkValid("valid_tests","FullConfig2");
- checkValid("valid_tests", "ReinforcementConfig");
- checkValid("valid_tests", "ReinforcementConfig2");
- checkValid("valid_tests", "DdpgConfig");
+ }
+
+ @Test
+ public void testValidReinforcementConfig() {
+ checkValid("valid_tests","ReinforcementConfig");
+ }
+
+ @Test
+ public void testValidReinforcementConfig2() {
+ checkValid("valid_tests","ReinforcementConfig2");
+ }
+
+ @Test
+ public void testValidDdpgConfig() {
+ checkValid("valid_tests","DdpgConfig");
+ }
+
+ @Test
+ public void testValidTD3Config() {
+ checkValid("valid_tests","TD3Config");
+ }
+
+ @Test
+ public void testValidReinforcementWithRosReward() { // fixture: valid_tests/ReinforcementWithRosReward
checkValid("valid_tests", "ReinforcementWithRosReward");
}
@Test
- public void testInvalidCoCos() throws IOException {
+ public void testInvalidEntryRepetition() {
checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckEntryRepetition()),
"invalid_cocos_tests", "EntryRepetition",
new ExpectedErrorInfo(1, ErrorCodes.ENTRY_REPETITION_CODE));
+ }
+
+ @Test
+ public void testInvalidIntegerTest() {
checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckInteger()),
"invalid_cocos_tests", "IntegerTest",
new ExpectedErrorInfo(1, ErrorCodes.NOT_INTEGER_CODE));
+ }
+
+ @Test
+ public void testInvalidFixTargetNetworkRequiresInterval1() {
checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckFixTargetNetworkRequiresInterval()),
"invalid_cocos_tests", "FixTargetNetworkRequiresInterval1",
new ExpectedErrorInfo(1, ErrorCodes.REQUIRED_PARAMETER_MISSING));
+ }
+
+ @Test
+ public void testInvalidFixTargetNetworkRequiresInterval2() {
checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckFixTargetNetworkRequiresInterval()),
"invalid_cocos_tests", "FixTargetNetworkRequiresInterval2",
new ExpectedErrorInfo(1, ErrorCodes.REQUIRED_PARAMETER_MISSING));
+ }
+
+ @Test
+ public void testInvalidCheckLearningParameterCombination1() {
checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckLearningParameterCombination()),
"invalid_cocos_tests", "CheckLearningParameterCombination1",
new ExpectedErrorInfo(1, ErrorCodes.UNSUPPORTED_PARAMETER));
+ }
+
+ @Test
+ public void testInvalidCheckLearningParameterCombination2() {
checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckLearningParameterCombination()),
"invalid_cocos_tests", "CheckLearningParameterCombination2",
new ExpectedErrorInfo(3, ErrorCodes.UNSUPPORTED_PARAMETER));
+ }
+
+ @Test
+ public void testInvalidCheckLearningParameterCombination3() {
checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckLearningParameterCombination()),
"invalid_cocos_tests", "CheckLearningParameterCombination3",
new ExpectedErrorInfo(2, ErrorCodes.UNSUPPORTED_PARAMETER));
+ }
+
+ @Test
+ public void testInvalidCheckLearningParameterCombination4() {
checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckLearningParameterCombination()),
"invalid_cocos_tests", "CheckLearningParameterCombination4",
new ExpectedErrorInfo(5, ErrorCodes.UNSUPPORTED_PARAMETER));
+ }
+
+ @Test
+ public void testInvalidCheckReinforcementRequiresEnvironment() {
checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckReinforcementRequiresEnvironment()),
"invalid_cocos_tests", "CheckReinforcementRequiresEnvironment",
new ExpectedErrorInfo(1, ErrorCodes.REQUIRED_PARAMETER_MISSING));
+ }
+
+ @Test
+ public void testInvalidCheckRosEnvironmentRequiresRewardFunction() {
checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckRosEnvironmentRequiresRewardFunction()),
"invalid_cocos_tests", "CheckRosEnvironmentRequiresRewardFunction",
new ExpectedErrorInfo(1, ErrorCodes.REQUIRED_PARAMETER_MISSING));
+ }
+
+ @Test
+ public void testInvalidCheckRLAlgorithmParameter1() {
checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckRlAlgorithmParameter()),
"invalid_cocos_tests", "CheckRLAlgorithmParameter1",
new ExpectedErrorInfo(1, ErrorCodes.UNSUPPORTED_PARAMETER));
+ }
+
+ @Test
+ public void testInvalidCheckRLAlgorithmParameter2() {
checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckRlAlgorithmParameter()),
"invalid_cocos_tests", "CheckRLAlgorithmParameter2",
- new ExpectedErrorInfo(1, ErrorCodes.UNSUPPORTED_PARAMETER));
+ new ExpectedErrorInfo(2, ErrorCodes.UNSUPPORTED_PARAMETER));
+ }
+
+ @Test
+ public void testInvalidCheckRLAlgorithmParameter3() {
checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckRlAlgorithmParameter()),
"invalid_cocos_tests", "CheckRLAlgorithmParameter3",
new ExpectedErrorInfo(1, ErrorCodes.UNSUPPORTED_PARAMETER));
+ }
+
+ @Test
+ public void testInvalidCheckRLAlgorithmParameter4() {
+ checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckRlAlgorithmParameter()),
+ "invalid_cocos_tests", "CheckRLAlgorithmParameter4",
+ new ExpectedErrorInfo(1, ErrorCodes.UNSUPPORTED_PARAMETER));
+ }
+
+ @Test
+ public void testInvalidCheckDiscreteRLAlgorithmUsesDiscreteStrategy() {
checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckDiscreteRLAlgorithmUsesDiscreteStrategy()),
"invalid_cocos_tests", "CheckDiscreteRLAlgorithmUsesDiscreteStrategy",
new ExpectedErrorInfo(1, ErrorCodes.STRATEGY_NOT_APPLICABLE));
+ }
+
+ @Test
+ public void testInvalidCheckContinuousRLAlgorithmUsesContinuousStrategy() {
checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckContinuousRLAlgorithmUsesContinuousStrategy()),
"invalid_cocos_tests", "CheckContinuousRLAlgorithmUsesContinuousStrategy",
new ExpectedErrorInfo(1, ErrorCodes.STRATEGY_NOT_APPLICABLE));
+ }
+
+ @Test
+ public void testInvalidCheckRosEnvironmentHasOnlyOneRewardSpecification() {
checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckRosEnvironmentHasOnlyOneRewardSpecification()),
"invalid_cocos_tests", "CheckRosEnvironmentHasOnlyOneRewardSpecification",
new ExpectedErrorInfo(1, ErrorCodes.CONTRADICTING_PARAMETERS));
diff --git a/src/test/resources/invalid_cocos_tests/CheckRLAlgorithmParameter4.cnnt b/src/test/resources/invalid_cocos_tests/CheckRLAlgorithmParameter4.cnnt
new file mode 100644
index 0000000000000000000000000000000000000000..a8b0c2797568bff9107a4c43e5189845cdcaf5c8
--- /dev/null
+++ b/src/test/resources/invalid_cocos_tests/CheckRLAlgorithmParameter4.cnnt
@@ -0,0 +1,19 @@
+configuration CheckRLAlgorithmParameter4 {
+ learning_method : reinforcement
+
+ rl_algorithm: td3-algorithm
+
+ agent_name : "reinforcement-agent"
+
+ environment : gym { name:"CartPole-v1" }
+
+ context : cpu
+
+ num_episodes : 300
+ num_max_steps : 9999
+ discount_factor : 0.998
+ target_score : 1000
+ training_interval : 10
+
+ use_double_dqn: true
+}
\ No newline at end of file
diff --git a/src/test/resources/valid_tests/TD3Config.cnnt b/src/test/resources/valid_tests/TD3Config.cnnt
new file mode 100644
index 0000000000000000000000000000000000000000..0ed09ce35025e801328fd1689eed1ce4686ad31f
--- /dev/null
+++ b/src/test/resources/valid_tests/TD3Config.cnnt
@@ -0,0 +1,32 @@
+configuration TD3Config {
+ learning_method : reinforcement
+ rl_algorithm : td3-algorithm
+ critic : path.to.component
+ environment : gym { name:"CartPole-v1" }
+ soft_target_update_rate: 0.001
+ policy_noise: 0.2
+ noise_clip: 0.5
+ policy_delay: 2
+
+ actor_optimizer : adam{
+ learning_rate : 0.0001
+ learning_rate_minimum : 0.00005
+ learning_rate_decay : 0.9
+ learning_rate_policy : step
+ }
+ critic_optimizer : rmsprop{
+ learning_rate : 0.001
+ learning_rate_minimum : 0.0001
+ learning_rate_decay : 0.5
+ learning_rate_policy : step
+ }
+ strategy : gaussian {
+ epsilon: 1.0
+ min_epsilon: 0.001
+ noise_variance: 0.5
+ epsilon_decay_per_step: true
+ epsilon_decay_method: linear
+ epsilon_decay : 0.0001
+ epsilon_decay_start: 50
+ }
+}
\ No newline at end of file