Commit 33bf5036 authored by Nicola Gatto, committed by Evgeny Kusmenko

Add DDPG support

parent 907f1be8
......@@ -114,24 +114,31 @@ configuration ReinforcementConfig {
### Available Parameters for Reinforcement Learning
| Parameter | Value | Default | Required | Description |
|------------|--------|---------|----------|-------------|
|learning_method| reinforcement,supervised | supervised | No | Determines that this CNNTrain configuration is a reinforcement or supervised learning configuration |
| agent_name | String | "agent" | No | Names the agent (e.g. for logging output) |
| environment | gym, ros_interface | / | Yes | If *ros_interface* is selected, then the agent and the environment communicate via [ROS](http://www.ros.org/). The gym environment comes with a set of environments which are listed [here](https://gym.openai.com/) |
| context | cpu, gpu | cpu | No | Determines whether the GPU is used during training or the CPU |
| num_episodes | Integer | 50 | No | Number of episodes the agent is trained. An episode is a full pass of a game from an initial state to a terminal state.|
| num_max_steps | Integer | 99999 | No | Number of steps within an episode before the environment is forced to reset the state (e.g. to avoid a state in which the agent is stuck) |
|discount_factor | Float | 0.9 | No | Discount factor |
| target_score | Float | None | No | If set, the agent stops the training when the average score of the last 100 episodes is greater than the target score. |
| training_interval | Integer | 1 | No | Number of steps between two trainings |
| loss | euclidean, l1, softmax_cross_entropy, sigmoid_cross_entropy, huber_loss | euclidean | No | Selects the loss function
| use_fix_target_network | bool | false | No | If set, an extra network with fixed parameters is used to estimate the Q values |
| target_network_update_interval | Integer | / | Yes, if *use_fix_target_network* is true | If *use_fix_target_network* is set, it determines the number of steps after which the target network is updated (Mnih et al. "Human-Level Control through Deep Reinforcement Learning")|
| Parameter | Value | Default | Required | Algorithm | Description |
|------------|--------|---------|----------|-----------|-------------|
| learning_method | reinforcement, supervised | supervised | No | All | Determines whether this CNNTrain configuration is a reinforcement or supervised learning configuration |
| rl_algorithm | ddpg-algorithm, dqn-algorithm | dqn-algorithm | No | All | Determines the RL algorithm that is used to train the agent |
| agent_name | String | "agent" | No | All | Names the agent (e.g. for logging output) |
| environment | gym, ros_interface | / | Yes | All | If *ros_interface* is selected, then the agent and the environment communicate via [ROS](http://www.ros.org/). The gym environment comes with a set of environments which are listed [here](https://gym.openai.com/) |
| context | cpu, gpu | cpu | No | All | Determines whether the GPU is used during training or the CPU |
| num_episodes | Integer | 50 | No | All | Number of episodes the agent is trained. An episode is a full pass of a game from an initial state to a terminal state.|
| num_max_steps | Integer | 99999 | No | All | Number of steps within an episode before the environment is forced to reset the state (e.g. to avoid a state in which the agent is stuck) |
|discount_factor | Float | 0.9 | No | All | Discount factor |
| target_score | Float | None | No | All | If set, the agent stops the training when the average score of the last 100 episodes is greater than the target score. |
| training_interval | Integer | 1 | No | All | Number of steps between two training iterations |
| loss | euclidean, l1, softmax_cross_entropy, sigmoid_cross_entropy, huber_loss | euclidean | No | DQN | Selects the loss function |
| use_fix_target_network | bool | false | No | DQN | If set, an extra network with fixed parameters is used to estimate the Q values |
| target_network_update_interval | Integer | / | Yes, if *use_fix_target_network* is true | DQN | If *use_fix_target_network* is set, it determines the number of steps after which the target network is updated (Mnih et al. "Human-Level Control through Deep Reinforcement Learning")|
| use_double_dqn | bool | false | No | If set, two value functions are used to determine the action values (van Hasselt et al. "Deep Reinforcement Learning with Double Q-Learning") |
| replay_memory | buffer, online, combined | buffer | No | Determines the behaviour of the replay memory |
| action_selection | epsgreedy | epsgreedy | No | Determines the action selection policy during the training |
| reward_function | Full name of an EMAM component | / | Yes, if *ros_interface* is selected as the environment | The EMAM component that is used to calculate the reward. It must have two inputs, one for the current state and one boolean input that determines if the current state is terminal. It must also have exactly one output which represents the reward. |
| replay_memory | buffer, online, combined | buffer | No | All | Determines the behaviour of the replay memory |
| strategy | epsgreedy, ornstein_uhlenbeck | epsgreedy (discrete), ornstein_uhlenbeck (continuous) | No | All | Determines the action selection policy during the training |
| reward_function | Full name of an EMAM component | / | Yes, if *ros_interface* is selected as the environment and no reward topic is given | All | The EMAM component that is used to calculate the reward. It must have two inputs, one for the current state and one boolean input that determines if the current state is terminal. It must also have exactly one output which represents the reward. |
| critic | Full name of architecture definition | / | Yes, if DDPG is selected | DDPG | The architecture definition that specifies the architecture of the critic network |
| soft_target_update_rate | Float | 0.001 | No | DDPG | Determines the update rate of the critic and actor target networks |
| actor_optimizer | See supervised learning | adam with learning rate 0.0001 | No | DDPG | Determines the optimizer parameters of the actor network |
| critic_optimizer | See supervised learning | adam with learning rate 0.001 | No | DDPG | Determines the optimizer parameters of the critic network |
| start_training_at | Integer | 0 | No | All | Determines at which episode the training starts |
| evaluation_samples | Integer | 100 | No | All | Determines how many episodes are run when evaluating the network |
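As an illustration of the parameters listed above, a minimal DDPG configuration could look like the following sketch. The critic component name, the gym environment name, and all numeric values are placeholders chosen for illustration (they are not documented defaults), and the optimizer blocks reuse the supervised learning optimizer syntax:
```EMADL
configuration ExampleDdpgConfig {
    learning_method: reinforcement
    rl_algorithm: ddpg-algorithm
    // placeholder environment and component names, illustrative values
    environment: gym { name: "Pendulum-v0" }
    critic: comp.ExampleCriticNetwork
    num_episodes: 300
    discount_factor: 0.99
    soft_target_update_rate: 0.001
    actor_optimizer: adam{
        learning_rate: 0.0001
    }
    critic_optimizer: adam{
        learning_rate: 0.001
    }
}
```
The *reward_function* entry is omitted here because it is only required when *ros_interface* is used as the environment and no reward topic is given.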
#### Environment
......@@ -169,19 +176,39 @@ No buffer is used. Only the current SARS tuple is used for training.
Combination of *online* and *buffer*. Both the current SARS tuple as well as a sample from the buffer are used for each training step. Parameters are the same as *buffer*.
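For example (the values are illustrative; *memory_size* and *sample_size* are the parameter names defined by the grammar), a combined memory can be configured as:
```EMADL
replay_memory: combined{
    memory_size: 1000000
    sample_size: 64
}
```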
### Action Selection
### Strategy
Determines the behaviour when selecting an action based on the values. (Currently, only epsilon greedy is available.)
Determines the behaviour when selecting an action based on the values.
#### Option: epsgreedy
Selects an action based on an epsilon-greedy policy. This means, based on epsilon, either a random action is chosen or an action with the highest value. Additional parameters:
This strategy is only available for discrete problems. It selects an action based on an epsilon-greedy policy: based on epsilon, either a random action is chosen or the action with the highest Q-value. Additional parameters (see the sketch after this list):
- **epsilon**: Probability of choosing an action randomly
- **epsilon_decay_method**: Method which determines how epsilon decreases after each step. Can be *linear* for linear decrease or *no* for no decrease.
- **epsilon_decay_start**: Number of episodes after which the decay of epsilon starts
- **epsilon_decay**: The actual decay of epsilon after each step.
- **min_epsilon**: After *min_epsilon* is reached, epsilon is not decreased further.
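The sketch below uses illustrative values; the parameter names are exactly those listed above:
```EMADL
strategy: epsgreedy{
    epsilon: 1.0
    min_epsilon: 0.01
    epsilon_decay_method: linear
    epsilon_decay_start: 20
    epsilon_decay: 0.0001
}
```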
#### Option: ornstein_uhlenbeck
This strategy is only available for continuous problems. The action is selected by the actor network; based on the current epsilon, noise generated by the [Ornstein-Uhlenbeck](https://en.wikipedia.org/wiki/Ornstein%E2%80%93Uhlenbeck_process) process is added. Additional parameters:
All epsilon parameters from the epsgreedy strategy can be used. Additionally, **mu**, **theta**, and **sigma** need to be specified. For each action output, the corresponding value is given in tuple-style notation: `(x,y,z)`
Example: Given an actor network with an action output of shape (3,), we can write
```EMADL
strategy: ornstein_uhlenbeck{
...
mu: (0.0, 0.1, 0.3)
theta: (0.5, 0.0, 0.8)
sigma: (0.3, 0.6, -0.9)
}
```
to specify the parameters for each action dimension.
## Generation
To execute generation in your project, use the following code to generate a separate Config file:
......
......@@ -30,7 +30,7 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnn-train</artifactId>
<version>0.3.1-SNAPSHOT</version>
<version>0.3.2-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
......@@ -40,7 +40,7 @@
<monticore.version>5.0.1</monticore.version>
<se-commons.version>1.7.8</se-commons.version>
<mc.grammars.assembly.version>0.0.6</mc.grammars.assembly.version>
<Common-MontiCar.version>0.0.17-20180824.094114-1</Common-MontiCar.version>
<Common-MontiCar.version>0.0.17-SNAPSHOT</Common-MontiCar.version>
<!-- .. Libraries .................................................. -->
<guava.version>18.0</guava.version>
......@@ -350,4 +350,4 @@
<url>https://nexus.se.rwth-aachen.de/content/repositories/embeddedmontiarc-snapshots/</url>
</snapshotRepository>
</distributionManagement>
</project>
\ No newline at end of file
</project>
package de.monticore.lang.monticar;
grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.NumberUnit{
symbol scope CNNTrainCompilationUnit = "configuration"
name:Name&
Configuration;
......@@ -16,11 +15,20 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
interface VariableReference;
ast VariableReference = method String getName(){};
// General Values
DataVariable implements VariableReference = Name&;
IntegerValue implements ConfigValue = NumberWithUnit;
NumberValue implements ConfigValue = NumberWithUnit;
StringValue implements ConfigValue = StringLiteral;
BooleanValue implements ConfigValue = (TRUE:"true" | FALSE:"false");
ComponentNameValue implements ConfigValue = Name ("."Name)*;
DoubleVectorValue implements ConfigValue = "(" number:NumberWithUnit ("," number:NumberWithUnit)* ")";
NumEpochEntry implements ConfigEntry = name:"num_epoch" ":" value:IntegerValue;
BatchSizeEntry implements ConfigEntry = name:"batch_size" ":" value:IntegerValue;
LoadCheckpointEntry implements ConfigEntry = name:"load_checkpoint" ":" value:BooleanValue;
NormalizeEntry implements ConfigEntry = name:"normalize" ":" value:BooleanValue;
OptimizerEntry implements ConfigEntry = name:"optimizer" ":" value:OptimizerValue;
OptimizerEntry implements ConfigEntry = (name:"optimizer" | name:"actor_optimizer") ":" value:OptimizerValue;
TrainContextEntry implements ConfigEntry = name:"context" ":" value:TrainContextValue;
EvalMetricEntry implements ConfigEntry = name:"eval_metric" ":" value:EvalMetricValue;
LossEntry implements ConfigEntry = name:"loss" ":" value:LossValue;
......@@ -46,30 +54,23 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
| sigmoid:"sigmoid");
TrainContextValue implements ConfigValue = (cpu:"cpu" | gpu:"gpu");
DataVariable implements VariableReference = Name&;
IntegerValue implements ConfigValue = NumberWithUnit;
NumberValue implements ConfigValue = NumberWithUnit;
StringValue implements ConfigValue = StringLiteral;
BooleanValue implements ConfigValue = (TRUE:"true" | FALSE:"false");
interface OptimizerParamEntry extends Entry;
interface OptimizerValue extends ConfigValue;
interface SGDEntry extends Entry;
interface SGDEntry extends OptimizerParamEntry;
SGDOptimizer implements OptimizerValue = name:"sgd" ("{" params:SGDEntry* "}")?;
interface AdamEntry extends Entry;
interface AdamEntry extends OptimizerParamEntry;
AdamOptimizer implements OptimizerValue = name:"adam" ("{" params:AdamEntry* "}")?;
interface RmsPropEntry extends Entry;
interface RmsPropEntry extends OptimizerParamEntry;
RmsPropOptimizer implements OptimizerValue = name:"rmsprop" ("{" params:RmsPropEntry* "}")?;
interface AdaGradEntry extends Entry;
interface AdaGradEntry extends OptimizerParamEntry;
AdaGradOptimizer implements OptimizerValue = name:"adagrad" ("{" params:AdaGradEntry* "}")?;
NesterovOptimizer implements OptimizerValue = name:"nag" ("{" params:SGDEntry* "}")?;
interface AdaDeltaEntry extends Entry;
interface AdaDeltaEntry extends OptimizerParamEntry;
AdaDeltaOptimizer implements OptimizerValue = name:"adadelta" ("{" params:AdaDeltaEntry* "}")?;
interface GeneralOptimizerEntry extends SGDEntry,AdamEntry,RmsPropEntry,AdaGradEntry,AdaDeltaEntry;
......@@ -104,15 +105,11 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
NumMaxStepsEntry implements ConfigEntry = name:"num_max_steps" ":" value:IntegerValue;
TargetScoreEntry implements ConfigEntry = name:"target_score" ":" value:NumberValue;
TrainingIntervalEntry implements ConfigEntry = name:"training_interval" ":" value:IntegerValue;
UseFixTargetNetworkEntry implements ConfigEntry = name:"use_fix_target_network" ":" value:BooleanValue;
TargetNetworkUpdateIntervalEntry implements ConfigEntry = name:"target_network_update_interval" ":" value:IntegerValue;
SnapshotIntervalEntry implements ConfigEntry = name:"snapshot_interval" ":" value:IntegerValue;
AgentNameEntry implements ConfigEntry = name:"agent_name" ":" value:StringValue;
UseDoubleDQNEntry implements ConfigEntry = name:"use_double_dqn" ":" value:BooleanValue;
RewardFunctionEntry implements ConfigEntry = name:"reward_function" ":" value:ComponentNameValue;
CriticNetworkEntry implements ConfigEntry = name:"critic" ":" value:ComponentNameValue;
ComponentNameValue implements ConfigValue = Name ("."Name)*;
StartTrainingAtEntry implements ConfigEntry = name:"start_training_at" ":" value:IntegerValue;
EvaluationSamplesEntry implements ConfigEntry = name:"evaluation_samples" ":" value:IntegerValue;
LearningMethodValue implements ConfigValue = (supervisedLearning:"supervised" | reinforcement:"reinforcement");
......@@ -137,18 +134,28 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
MemorySizeEntry implements GeneralReplayMemoryEntry = name:"memory_size" ":" value:IntegerValue;
SampleSizeEntry implements GeneralReplayMemoryEntry = name:"sample_size" ":" value:IntegerValue;
// Action Selection
ActionSelectionEntry implements MultiParamConfigEntry = name:"action_selection" ":" value:ActionSelectionValue;
interface ActionSelectionValue extends MultiParamValue;
// Strategy
StrategyEntry implements MultiParamConfigEntry = name:"strategy" ":" value:StrategyValue;
interface StrategyValue extends MultiParamValue;
interface StrategyEpsGreedyEntry extends Entry;
StrategyEpsGreedyValue implements StrategyValue = name:"epsgreedy" ("{" params:StrategyEpsGreedyEntry* "}")?;
interface StrategyOrnsteinUhlenbeckEntry extends Entry;
StrategyOrnsteinUhlenbeckValue implements StrategyValue = name:"ornstein_uhlenbeck" ("{" params:StrategyOrnsteinUhlenbeckEntry* "}")?;
interface ActionSelectionEpsGreedyEntry extends Entry;
ActionSelectionEpsGreedyValue implements ActionSelectionValue = name:"epsgreedy" ("{" params:ActionSelectionEpsGreedyEntry* "}")?;
StrategyOUMu implements StrategyOrnsteinUhlenbeckEntry = name: "mu" ":" value:DoubleVectorValue;
StrategyOUTheta implements StrategyOrnsteinUhlenbeckEntry = name: "theta" ":" value:DoubleVectorValue;
StrategyOUSigma implements StrategyOrnsteinUhlenbeckEntry = name: "sigma" ":" value:DoubleVectorValue;
GreedyEpsilonEntry implements ActionSelectionEpsGreedyEntry = name:"epsilon" ":" value:NumberValue;
MinEpsilonEntry implements ActionSelectionEpsGreedyEntry = name:"min_epsilon" ":" value:NumberValue;
EpsilonDecayMethodEntry implements ActionSelectionEpsGreedyEntry = name:"epsilon_decay_method" ":" value:EpsilonDecayMethodValue;
interface GeneralStrategyEntry extends StrategyEpsGreedyEntry, StrategyOrnsteinUhlenbeckEntry;
GreedyEpsilonEntry implements GeneralStrategyEntry = name:"epsilon" ":" value:NumberValue;
MinEpsilonEntry implements GeneralStrategyEntry = name:"min_epsilon" ":" value:NumberValue;
EpsilonDecayStartEntry implements GeneralStrategyEntry = name:"epsilon_decay_start" ":" value:IntegerValue;
EpsilonDecayMethodEntry implements GeneralStrategyEntry = name:"epsilon_decay_method" ":" value:EpsilonDecayMethodValue;
EpsilonDecayMethodValue implements ConfigValue = (linear:"linear" | no:"no");
EpsilonDecayEntry implements ActionSelectionEpsGreedyEntry = name:"epsilon_decay" ":" value:NumberValue;
EpsilonDecayEntry implements GeneralStrategyEntry = name:"epsilon_decay" ":" value:NumberValue;
// Environment
EnvironmentEntry implements MultiParamConfigEntry = name:"environment" ":" value:EnvironmentValue;
......@@ -163,7 +170,17 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
RosEnvironmentStateTopicEntry implements RosEnvironmentEntry = name:"state_topic" ":" value:StringValue;
RosEnvironmentActionTopicEntry implements RosEnvironmentEntry = name:"action_topic" ":" value:StringValue;
RosEnvironmentResetTopicEntry implements RosEnvironmentEntry = name:"reset_topic" ":" value:StringValue;
RosEnvironmentGreetingTopicEntry implements RosEnvironmentEntry = name:"greeting_topic" ":" value:StringValue;
RosEnvironmentMetaTopicEntry implements RosEnvironmentEntry = name:"meta_topic" ":" value:StringValue;
RosEnvironmentTerminalStateTopicEntry implements RosEnvironmentEntry = name:"terminal_state_topic" ":" value:StringValue;
RosEnvironmentRewardTopicEntry implements RosEnvironmentEntry = name:"reward_topic" ":" value:StringValue;
// DQN exclusive parameters
UseFixTargetNetworkEntry implements ConfigEntry = name:"use_fix_target_network" ":" value:BooleanValue;
TargetNetworkUpdateIntervalEntry implements ConfigEntry = name:"target_network_update_interval" ":" value:IntegerValue;
UseDoubleDQNEntry implements ConfigEntry = name:"use_double_dqn" ":" value:BooleanValue;
// DDPG exclusive parameters
CriticNetworkEntry implements ConfigEntry = name:"critic" ":" value:ComponentNameValue;
SoftTargetUpdateRateEntry implements ConfigEntry = name:"soft_target_update_rate" ":" value:NumberValue;
CriticOptimizerEntry implements ConfigEntry = name:"critic_optimizer" ":" value:OptimizerValue;
}
\ No newline at end of file
......@@ -20,9 +20,9 @@
*/
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._ast.ASTConfiguration;
import de.monticore.lang.monticar.cnntrain._ast.ASTEnvironmentEntry;
import de.monticore.lang.monticar.cnntrain._ast.ASTLearningMethodEntry;
import de.monticore.lang.monticar.cnntrain._ast.*;
import java.util.Optional;
class ASTConfigurationUtils {
static boolean isReinforcementLearning(final ASTConfiguration configuration) {
......@@ -34,4 +34,54 @@ class ASTConfigurationUtils {
static boolean hasEnvironment(final ASTConfiguration configuration) {
return configuration.getEntriesList().stream().anyMatch(e -> e instanceof ASTEnvironmentEntry);
}
static boolean isDdpgAlgorithm(final ASTConfiguration configuration) {
return isReinforcementLearning(configuration)
&& configuration.getEntriesList().stream().anyMatch(
e -> (e instanceof ASTRLAlgorithmEntry) && ((ASTRLAlgorithmEntry)e).getValue().isPresentDdpg());
}
static boolean isDqnAlgorithm(final ASTConfiguration configuration) {
return isReinforcementLearning(configuration) && !isDdpgAlgorithm(configuration);
}
static boolean hasEntry(final ASTConfiguration configuration, final Class<? extends ASTConfigEntry> entryClazz) {
return configuration.getEntriesList().stream().anyMatch(entryClazz::isInstance);
}
static boolean hasStrategy(final ASTConfiguration configuration) {
return configuration.getEntriesList().stream().anyMatch(e -> e instanceof ASTStrategyEntry);
}
static Optional<String> getStrategyMethod(final ASTConfiguration configuration) {
return configuration.getEntriesList().stream()
.filter(e -> e instanceof ASTStrategyEntry)
.map(e -> (ASTStrategyEntry)e)
.findFirst()
.map(astStrategyEntry -> astStrategyEntry.getValue().getName());
}
static boolean hasRewardFunction(final ASTConfiguration node) {
return node.getEntriesList().stream().anyMatch(e -> e instanceof ASTRewardFunctionEntry);
}
static boolean hasRosEnvironment(final ASTConfiguration node) {
return ASTConfigurationUtils.hasEnvironment(node)
&& node.getEntriesList().stream()
.anyMatch(e -> (e instanceof ASTEnvironmentEntry)
&& ((ASTEnvironmentEntry)e).getValue().getName().equals("ros_interface"));
}
static boolean hasRewardTopic(final ASTConfiguration node) {
if (ASTConfigurationUtils.isReinforcementLearning(node) && ASTConfigurationUtils.hasEnvironment(node)) {
return node.getEntriesList().stream()
.filter(ASTEnvironmentEntry.class::isInstance)
.map(e -> (ASTEnvironmentEntry)e)
.reduce((element, other) -> { throw new IllegalStateException("More than one entry");})
.map(astEnvironmentEntry -> astEnvironmentEntry.getValue().getParamsList().stream()
.anyMatch(e -> e instanceof ASTRosEnvironmentRewardTopicEntry)).orElse(false);
}
return false;
}
}
......@@ -34,7 +34,11 @@ public class CNNTrainCocos {
.addCoCo(new CheckReinforcementRequiresEnvironment())
.addCoCo(new CheckLearningParameterCombination())
.addCoCo(new CheckRosEnvironmentRequiresRewardFunction())
.addCoCo(new CheckDdpgRequiresCriticNetwork());
.addCoCo(new CheckDdpgRequiresCriticNetwork())
.addCoCo(new CheckRlAlgorithmParameter())
.addCoCo(new CheckDiscreteRLAlgorithmUsesDiscreteStrategy())
.addCoCo(new CheckContinuousRLAlgorithmUsesContinuousStrategy())
.addCoCo(new CheckRosEnvironmentHasOnlyOneRewardSpecification());
}
public static void checkAll(CNNTrainCompilationUnitSymbol compilationUnit){
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain._cocos;
import com.google.common.collect.ImmutableSet;
import de.monticore.lang.monticar.cnntrain._ast.ASTConfiguration;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import java.util.Set;
public class CheckContinuousRLAlgorithmUsesContinuousStrategy implements CNNTrainASTConfigurationCoCo{
private static final Set<String> CONTINUOUS_STRATEGIES = ImmutableSet.<String>builder()
.add("ornstein_uhlenbeck")
.build();
@Override
public void check(ASTConfiguration node) {
if (ASTConfigurationUtils.isDdpgAlgorithm(node)
&& ASTConfigurationUtils.hasStrategy(node)
&& ASTConfigurationUtils.getStrategyMethod(node).isPresent()) {
final String usedStrategy = ASTConfigurationUtils.getStrategyMethod(node).get();
if (!CONTINUOUS_STRATEGIES.contains(usedStrategy)) {
Log.error("0" + ErrorCodes.STRATEGY_NOT_APPLICABLE + " Strategy " + usedStrategy + " used but" +
" continuous algorithm used.", node.get_SourcePositionStart());
}
}
}
}
\ No newline at end of file
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain._cocos;
import com.google.common.collect.ImmutableSet;
import de.monticore.lang.monticar.cnntrain._ast.ASTConfiguration;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import java.util.Set;
public class CheckDiscreteRLAlgorithmUsesDiscreteStrategy implements CNNTrainASTConfigurationCoCo{
private static final Set<String> DISCRETE_STRATEGIES = ImmutableSet.<String>builder()
.add("epsgreedy")
.build();
@Override
public void check(ASTConfiguration node) {
if (ASTConfigurationUtils.isDqnAlgorithm(node)
&& ASTConfigurationUtils.hasStrategy(node)
&& ASTConfigurationUtils.getStrategyMethod(node).isPresent()) {
final String usedStrategy = ASTConfigurationUtils.getStrategyMethod(node).get();
if (!DISCRETE_STRATEGIES.contains(usedStrategy)) {
Log.error("0" + ErrorCodes.STRATEGY_NOT_APPLICABLE + " Strategy " + usedStrategy + " used but" +
" discrete algorithm used.", node.get_SourcePositionStart());
}
}
}
}
......@@ -20,8 +20,9 @@
*/
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._ast.ASTEntry;
import de.monticore.lang.monticar.cnntrain._ast.ASTGreedyEpsilonEntry;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import de.monticore.lang.monticar.cnntrain._ast.*;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
......@@ -29,23 +30,35 @@ import java.util.HashSet;
import java.util.Set;
public class CheckEntryRepetition implements CNNTrainASTEntryCoCo {
private final static Set<Class<? extends ASTEntry>> REPEATABLE_ENTRIES = ImmutableSet
.<Class<? extends ASTEntry>>builder()
.add(ASTOptimizerParamEntry.class)
.build();
private Set<String> entryNameSet = new HashSet<>();
@Override
public void check(ASTEntry node) {
String parameterPrefix = "";
if (node instanceof ASTGreedyEpsilonEntry) {
parameterPrefix = "greedy_";
}
if (entryNameSet.contains(parameterPrefix + node.getName())){
Log.error("0" + ErrorCodes.ENTRY_REPETITION_CODE +" The parameter '" + node.getName() + "' has multiple values. " +
"Multiple assignments of the same parameter are not allowed",
node.get_SourcePositionStart());
}
else {
entryNameSet.add(parameterPrefix + node.getName());
if (!isRepeatable(node)) {
String parameterPrefix = "";
if (node instanceof ASTGreedyEpsilonEntry) {
parameterPrefix = "greedy_";
}
if (entryNameSet.contains(parameterPrefix + node.getName())) {
Log.error("0" + ErrorCodes.ENTRY_REPETITION_CODE + " The parameter '" + node.getName() + "' has multiple values. " +
"Multiple assignments of the same parameter are not allowed",
node.get_SourcePositionStart());
} else {
entryNameSet.add(parameterPrefix + node.getName());
}
}
}
private boolean isRepeatable(final ASTEntry node) {
return REPEATABLE_ENTRIES.stream().anyMatch(i -> i.isInstance(node));
}
}
......@@ -34,78 +34,7 @@ import java.util.Set;
*
*/
public class CheckLearningParameterCombination implements CNNTrainASTEntryCoCo {
private final static List<Class> ALLOWED_SUPERVISED_LEARNING = Lists.newArrayList(
ASTTrainContextEntry.class,
ASTBatchSizeEntry.class,
ASTOptimizerEntry.class,
ASTLearningRateEntry.class,
ASTLoadCheckpointEntry.class,
ASTEvalMetricEntry.class,
ASTLossEntry.class,
ASTNormalizeEntry.class,
ASTMinimumLearningRateEntry.class,
ASTLRDecayEntry.class,
ASTWeightDecayEntry.class,
ASTLRPolicyEntry.class,
ASTStepSizeEntry.class,
ASTRescaleGradEntry.class,
ASTClipGradEntry.class,
ASTGamma1Entry.class,
ASTGamma2Entry.class,
ASTEpsilonEntry.class,
ASTCenteredEntry.class,
ASTClipWeightsEntry.class,
ASTBeta1Entry.class,
ASTBeta2Entry.class,
ASTNumEpochEntry.class
);
private final static List<Class> ALLOWED_REINFORCEMENT_LEARNING = Lists.newArrayList(
ASTTrainContextEntry.class,
ASTRLAlgorithmEntry.class,
ASTCriticNetworkEntry.class,
ASTOptimizerEntry.class,
ASTRewardFunctionEntry.class,
ASTMinimumLearningRateEntry.class,
ASTLRDecayEntry.class,
ASTWeightDecayEntry.class,
ASTLRPolicyEntry.class,
ASTGamma1Entry.class,
ASTGamma2Entry.class,
ASTEpsilonEntry.class,
ASTClipGradEntry.class,
ASTRescaleGradEntry.class,
ASTStepSizeEntry.class,
ASTCenteredEntry.class,
ASTClipWeightsEntry.class,
ASTLearningRateEntry.class,
ASTDiscountFactorEntry.class,
ASTNumMaxStepsEntry.class,
ASTTargetScoreEntry.class,
ASTTrainingIntervalEntry.class,
ASTUseFixTargetNetworkEntry.class,
ASTTargetNetworkUpdateIntervalEntry.class,
ASTSnapshotIntervalEntry.class,
ASTAgentNameEntry.class,
ASTGymEnvironmentNameEntry.class,
ASTEnvironmentEntry.class,
ASTUseDoubleDQNEntry.class,
ASTLossEntry.class,
ASTReplayMemoryEntry.class,
ASTMemorySizeEntry.class,
ASTSampleSizeEntry.class,
ASTActionSelectionEntry.class,
ASTGreedyEpsilonEntry.class,