diff --git a/README.md b/README.md
index f1e4b8749d515bad8b5b0d580ff9ce7ac101f6fa..f019ce9a9ef5462309bf861bfeb04bad427b320c 100644
--- a/README.md
+++ b/README.md
@@ -114,24 +114,31 @@ configuration ReinforcementConfig {
 ### Available Parameters for Reinforcement Learning
-| Parameter | Value | Default | Required | Description |
-|------------|--------|---------|----------|-------------|
-|learning_method| reinforcement,supervised | supervised | No | Determines that this CNNTrain configuration is a reinforcement or supervised learning configuration |
-| agent_name | String | "agent" | No | Names the agent (e.g. for logging output) |
-|environment | gym, ros_interface | Yes | / | If *ros_interface* is selected, then the agent and the environment communicates via [ROS](http://www.ros.org/). The gym environment comes with a set of environments which are listed [here](https://gym.openai.com/) |
-| context | cpu, gpu | cpu | No | Determines whether the GPU is used during training or the CPU |
-| num_episodes | Integer | 50 | No | Number of episodes the agent is trained. An episode is a full passing of a game from an initial state to a terminal state.|
-| num_max_steps | Integer | 99999 | No | Number of steps within an episodes before the environment is forced to reset the state (e.g. to avoid a state in which the agent is stuck) |
-|discount_factor | Float | 0.9 | No | Discount factor |
-| target_score | Float | None | No | If set, the agent stops the training when the average score of the last 100 episodes is greater than the target score. |
-| training_interval | Integer | 1 | No | Number of steps between two trainings |
-| loss | euclidean, l1, softmax_cross_entropy, sigmoid_cross_entropy, huber_loss | euclidean | No | Selects the loss function
-| use_fix_target_network | bool | false | No | If set, an extra network with fixed parameters is used to estimate the Q values |
-| target_network_update_interval | Integer | / | Yes, if fixed target network is true | If *use_fix_target_network* is set, it determines the number of steps after the target network is updated (Minh et. al. "Human Level Control through Deep Reinforcement Learning")|
+| Parameter | Value | Default | Required | Algorithm | Description |
+|------------|--------|---------|----------|-----------|-------------|
+| learning_method | reinforcement, supervised | supervised | No | All | Determines whether this CNNTrain configuration is a reinforcement or supervised learning configuration |
+| rl_algorithm | ddpg-algorithm, dqn-algorithm | dqn-algorithm | No | All | Determines the RL algorithm that is used to train the agent |
+| agent_name | String | "agent" | No | All | Names the agent (e.g. for logging output) |
+| environment | gym, ros_interface | / | Yes | All | If *ros_interface* is selected, the agent and the environment communicate via [ROS](http://www.ros.org/). The gym environment comes with a set of environments which are listed [here](https://gym.openai.com/) |
+| context | cpu, gpu | cpu | No | All | Determines whether the GPU or the CPU is used during training |
+| num_episodes | Integer | 50 | No | All | Number of episodes the agent is trained for. An episode is a full pass through the environment from an initial state to a terminal state. |
+| num_max_steps | Integer | 99999 | No | All | Number of steps within an episode before the environment is forced to reset the state (e.g. to avoid a state in which the agent is stuck) |
+| discount_factor | Float | 0.9 | No | All | Discount factor |
+| target_score | Float | None | No | All | If set, the agent stops training when the average score of the last 100 episodes is greater than the target score. |
+| training_interval | Integer | 1 | No | All | Number of steps between two training updates |
+| loss | euclidean, l1, softmax_cross_entropy, sigmoid_cross_entropy, huber_loss | euclidean | No | DQN | Selects the loss function |
+| use_fix_target_network | bool | false | No | DQN | If set, an extra network with fixed parameters is used to estimate the Q values |
+| target_network_update_interval | Integer | / | Yes, if *use_fix_target_network* is true | DQN | If *use_fix_target_network* is set, it determines the number of steps after which the target network is updated (Mnih et al., "Human-level control through deep reinforcement learning") |
 | use_double_dqn | bool | false | No | DQN | If set, two value functions are used to determine the action values (Hasselt et. al. "Deep Reinforcement Learning with Double Q Learning") |
-| replay_memory | buffer, online, combined | buffer | No | Determines the behaviour of the replay memory |
-| action_selection | epsgreedy | epsgreedy | No | Determines the action selection policy during the training |
-| reward_function | Full name of an EMAM component | / | Yes, if *ros_interface* is selected as the environment | The EMAM component that is used to calculate the reward. It must have two inputs, one for the current state and one boolean input that determines if the current state is terminal. It must also have exactly one output which represents the reward. |
+| replay_memory | buffer, online, combined | buffer | No | All | Determines the behaviour of the replay memory |
+| strategy | epsgreedy, ornstein_uhlenbeck | epsgreedy (discrete), ornstein_uhlenbeck (continuous) | No | All | Determines the action selection policy during training |
+| reward_function | Full name of an EMAM component | / | Yes, if *ros_interface* is selected as the environment and no reward topic is given | All | The EMAM component that is used to calculate the reward. It must have two inputs, one for the current state and one boolean input that determines if the current state is terminal. It must also have exactly one output which represents the reward. |
+| critic | Full name of an architecture definition | / | Yes, if DDPG is selected | DDPG | The architecture definition of the critic network |
+| soft_target_update_rate | Float | 0.001 | No | DDPG | Determines the update rate of the critic and actor target networks |
+| actor_optimizer | See supervised learning | adam with LR 0.0001 | No | DDPG | Determines the optimizer parameters of the actor network |
+| critic_optimizer | See supervised learning | adam with LR 0.001 | No | DDPG | Determines the optimizer parameters of the critic network |
+| start_training_at | Integer | 0 | No | All | Determines at which episode the training starts |
+| evaluation_samples | Integer | 100 | No | All | Determines how many episodes are run when evaluating the network |
 #### Environment
@@ -169,19 +176,39 @@
 No buffer is used. Only the current SARS tuple is used for training.
 Combination of *online* and *buffer*. Both the current SARS tuple as well as a sample from the buffer are used for each training step. Parameters are the same as *buffer*.
-### Action Selection
+### Strategy
-Determines the behaviour when selecting an action based on the values. (Currently, only epsilon greedy is available.)
+Determines the behaviour when selecting an action based on the values.
 #### Option: epsgreedy
-Selects an action based on Epsilon-Greedy-Policy. This means, based on epsilon, either a random action is choosen or an action with the highest value. Additional parameters:
+This strategy is only available for discrete problems. It selects an action based on an epsilon-greedy policy: depending on epsilon, either a random action is chosen or the action with the highest Q-value. Additional parameters:
 - **epsilon**: Probability of choosing an action randomly
 - **epsilon_decay_method**: Method which determines how epsilon decreases after each step. Can be *linear* for linear decrease or *no* for no decrease.
+- **epsilon_decay_start**: Number of episodes after which the decay of epsilon starts
 - **epsilon_decay**: The actual decay of epsilon after each step.
 - **min_epsilon**: After *min_epsilon* is reached, epsilon is not decreased further.
+
+#### Option: ornstein_uhlenbeck
+This strategy is only available for continuous problems. The action is selected based on the actor network. Depending on the current epsilon, noise generated by the [Ornstein-Uhlenbeck](https://en.wikipedia.org/wiki/Ornstein%E2%80%93Uhlenbeck_process) process is added to the action. Additional parameters:
+
+All epsilon parameters from the epsgreedy strategy can be used. Additionally, **mu**, **theta**, and **sigma** need to be specified. For each action output you can specify the corresponding value with a tuple-style notation: `(x,y,z)`
+
+Example: Given an actor network with an action output of shape (3,), we can write
+
+```EMADL
+    strategy: ornstein_uhlenbeck{
+        ...
+        mu: (0.0, 0.1, 0.3)
+        theta: (0.5, 0.0, 0.8)
+        sigma: (0.3, 0.6, -0.9)
+    }
+```
+
+to specify the parameters for each action dimension.
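+
+A complete DDPG configuration that combines the parameters described above might look like the following sketch. The critic component name, the environment name (`Pendulum-v0` is only an example of a continuous-action gym environment) and all numeric values are placeholders that have to be adapted to the concrete project:
+
+```EMADL
+configuration ExampleDdpgConfig {
+    learning_method: reinforcement
+    rl_algorithm: ddpg-algorithm
+
+    agent_name: "ddpg-agent"
+    environment: gym { name: "Pendulum-v0" }
+    context: cpu
+
+    num_episodes: 300
+    discount_factor: 0.99
+
+    critic: path.to.criticComponent
+    soft_target_update_rate: 0.001
+
+    actor_optimizer: adam {
+        learning_rate: 0.0001
+    }
+
+    critic_optimizer: adam {
+        learning_rate: 0.001
+    }
+
+    strategy: ornstein_uhlenbeck {
+        epsilon: 1.0
+        epsilon_decay_method: linear
+        epsilon_decay: 0.0001
+        min_epsilon: 0.01
+        mu: (0.0)
+        theta: (0.15)
+        sigma: (0.3)
+    }
+}
+```
+
+Note that DQN-only parameters such as `use_fix_target_network`, `target_network_update_interval`, `use_double_dqn` and `loss` must not be combined with `rl_algorithm: ddpg-algorithm`; the context conditions reject such configurations.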
+ ## Generation To execute generation in your project, use the following code to generate a separate Config file: diff --git a/pom.xml b/pom.xml index 037c85e99a570538570c8ea77f95eb4dcd1dafa6..bf1e94156d6d8103eed40ce7a4ff5edb4b3f4ea9 100644 --- a/pom.xml +++ b/pom.xml @@ -30,7 +30,7 @@ de.monticore.lang.monticar cnn-train - 0.3.1-SNAPSHOT + 0.3.2-SNAPSHOT @@ -40,7 +40,7 @@ 5.0.1 1.7.8 0.0.6 - 0.0.17-20180824.094114-1 + 0.0.17-SNAPSHOT 18.0 @@ -350,4 +350,4 @@ https://nexus.se.rwth-aachen.de/content/repositories/embeddedmontiarc-snapshots/ - \ No newline at end of file + diff --git a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 index 7409708969bc5bdbaac2f1500888a652be3abfe3..22c2e48bfc453f2b83ccb9b801abb781b7adbf3c 100644 --- a/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 +++ b/src/main/grammars/de/monticore/lang/monticar/CNNTrain.mc4 @@ -1,7 +1,6 @@ package de.monticore.lang.monticar; grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.NumberUnit{ - symbol scope CNNTrainCompilationUnit = "configuration" name:Name& Configuration; @@ -16,11 +15,20 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number interface VariableReference; ast VariableReference = method String getName(){}; + // General Values + DataVariable implements VariableReference = Name&; + IntegerValue implements ConfigValue = NumberWithUnit; + NumberValue implements ConfigValue = NumberWithUnit; + StringValue implements ConfigValue = StringLiteral; + BooleanValue implements ConfigValue = (TRUE:"true" | FALSE:"false"); + ComponentNameValue implements ConfigValue = Name ("."Name)*; + DoubleVectorValue implements ConfigValue = "(" number:NumberWithUnit ("," number:NumberWithUnit)* ")"; + NumEpochEntry implements ConfigEntry = name:"num_epoch" ":" value:IntegerValue; BatchSizeEntry implements ConfigEntry = name:"batch_size" ":" value:IntegerValue; LoadCheckpointEntry implements ConfigEntry = name:"load_checkpoint" ":" value:BooleanValue; NormalizeEntry implements ConfigEntry = name:"normalize" ":" value:BooleanValue; - OptimizerEntry implements ConfigEntry = name:"optimizer" ":" value:OptimizerValue; + OptimizerEntry implements ConfigEntry = (name:"optimizer" | name:"actor_optimizer") ":" value:OptimizerValue; TrainContextEntry implements ConfigEntry = name:"context" ":" value:TrainContextValue; EvalMetricEntry implements ConfigEntry = name:"eval_metric" ":" value:EvalMetricValue; LossEntry implements ConfigEntry = name:"loss" ":" value:LossValue; @@ -46,30 +54,23 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number | sigmoid:"sigmoid"); TrainContextValue implements ConfigValue = (cpu:"cpu" | gpu:"gpu"); - - DataVariable implements VariableReference = Name&; - IntegerValue implements ConfigValue = NumberWithUnit; - NumberValue implements ConfigValue = NumberWithUnit; - StringValue implements ConfigValue = StringLiteral; - BooleanValue implements ConfigValue = (TRUE:"true" | FALSE:"false"); - + interface OptimizerParamEntry extends Entry; interface OptimizerValue extends ConfigValue; - - interface SGDEntry extends Entry; + interface SGDEntry extends OptimizerParamEntry; SGDOptimizer implements OptimizerValue = name:"sgd" ("{" params:SGDEntry* "}")?; - interface AdamEntry extends Entry; + interface AdamEntry extends OptimizerParamEntry; AdamOptimizer implements OptimizerValue = name:"adam" ("{" params:AdamEntry* "}")?; - interface RmsPropEntry extends Entry; + 
interface RmsPropEntry extends OptimizerParamEntry; RmsPropOptimizer implements OptimizerValue = name:"rmsprop" ("{" params:RmsPropEntry* "}")?; - interface AdaGradEntry extends Entry; + interface AdaGradEntry extends OptimizerParamEntry; AdaGradOptimizer implements OptimizerValue = name:"adagrad" ("{" params:AdaGradEntry* "}")?; NesterovOptimizer implements OptimizerValue = name:"nag" ("{" params:SGDEntry* "}")?; - interface AdaDeltaEntry extends Entry; + interface AdaDeltaEntry extends OptimizerParamEntry; AdaDeltaOptimizer implements OptimizerValue = name:"adadelta" ("{" params:AdaDeltaEntry* "}")?; interface GeneralOptimizerEntry extends SGDEntry,AdamEntry,RmsPropEntry,AdaGradEntry,AdaDeltaEntry; @@ -104,15 +105,11 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number NumMaxStepsEntry implements ConfigEntry = name:"num_max_steps" ":" value:IntegerValue; TargetScoreEntry implements ConfigEntry = name:"target_score" ":" value:NumberValue; TrainingIntervalEntry implements ConfigEntry = name:"training_interval" ":" value:IntegerValue; - UseFixTargetNetworkEntry implements ConfigEntry = name:"use_fix_target_network" ":" value:BooleanValue; - TargetNetworkUpdateIntervalEntry implements ConfigEntry = name:"target_network_update_interval" ":" value:IntegerValue; SnapshotIntervalEntry implements ConfigEntry = name:"snapshot_interval" ":" value:IntegerValue; AgentNameEntry implements ConfigEntry = name:"agent_name" ":" value:StringValue; - UseDoubleDQNEntry implements ConfigEntry = name:"use_double_dqn" ":" value:BooleanValue; RewardFunctionEntry implements ConfigEntry = name:"reward_function" ":" value:ComponentNameValue; - CriticNetworkEntry implements ConfigEntry = name:"critic" ":" value:ComponentNameValue; - - ComponentNameValue implements ConfigValue = Name ("."Name)*; + StartTrainingAtEntry implements ConfigEntry = name:"start_training_at" ":" value:IntegerValue; + EvaluationSamplesEntry implements ConfigEntry = name:"evaluation_samples" ":" value:IntegerValue; LearningMethodValue implements ConfigValue = (supervisedLearning:"supervised" | reinforcement:"reinforcement"); @@ -137,18 +134,28 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number MemorySizeEntry implements GeneralReplayMemoryEntry = name:"memory_size" ":" value:IntegerValue; SampleSizeEntry implements GeneralReplayMemoryEntry = name:"sample_size" ":" value:IntegerValue; - // Action Selection - ActionSelectionEntry implements MultiParamConfigEntry = name:"action_selection" ":" value:ActionSelectionValue; - interface ActionSelectionValue extends MultiParamValue; + // Strategy + StrategyEntry implements MultiParamConfigEntry = name:"strategy" ":" value:StrategyValue; + interface StrategyValue extends MultiParamValue; + + interface StrategyEpsGreedyEntry extends Entry; + StrategyEpsGreedyValue implements StrategyValue = name:"epsgreedy" ("{" params:StrategyEpsGreedyEntry* "}")?; + + interface StrategyOrnsteinUhlenbeckEntry extends Entry; + StrategyOrnsteinUhlenbeckValue implements StrategyValue = name:"ornstein_uhlenbeck" ("{" params:StrategyOrnsteinUhlenbeckEntry* "}")?; - interface ActionSelectionEpsGreedyEntry extends Entry; - ActionSelectionEpsGreedyValue implements ActionSelectionValue = name:"epsgreedy" ("{" params:ActionSelectionEpsGreedyEntry* "}")?; + StrategyOUMu implements StrategyOrnsteinUhlenbeckEntry = name: "mu" ":" value:DoubleVectorValue; + StrategyOUTheta implements StrategyOrnsteinUhlenbeckEntry = name: "theta" ":" value:DoubleVectorValue; + 
StrategyOUSigma implements StrategyOrnsteinUhlenbeckEntry = name: "sigma" ":" value:DoubleVectorValue; - GreedyEpsilonEntry implements ActionSelectionEpsGreedyEntry = name:"epsilon" ":" value:NumberValue; - MinEpsilonEntry implements ActionSelectionEpsGreedyEntry = name:"min_epsilon" ":" value:NumberValue; - EpsilonDecayMethodEntry implements ActionSelectionEpsGreedyEntry = name:"epsilon_decay_method" ":" value:EpsilonDecayMethodValue; + interface GeneralStrategyEntry extends StrategyEpsGreedyEntry, StrategyOrnsteinUhlenbeckEntry; + + GreedyEpsilonEntry implements GeneralStrategyEntry = name:"epsilon" ":" value:NumberValue; + MinEpsilonEntry implements GeneralStrategyEntry = name:"min_epsilon" ":" value:NumberValue; + EpsilonDecayStartEntry implements GeneralStrategyEntry = name:"epsilon_decay_start" ":" value:IntegerValue; + EpsilonDecayMethodEntry implements GeneralStrategyEntry = name:"epsilon_decay_method" ":" value:EpsilonDecayMethodValue; EpsilonDecayMethodValue implements ConfigValue = (linear:"linear" | no:"no"); - EpsilonDecayEntry implements ActionSelectionEpsGreedyEntry = name:"epsilon_decay" ":" value:NumberValue; + EpsilonDecayEntry implements GeneralStrategyEntry = name:"epsilon_decay" ":" value:NumberValue; // Environment EnvironmentEntry implements MultiParamConfigEntry = name:"environment" ":" value:EnvironmentValue; @@ -163,7 +170,17 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number RosEnvironmentStateTopicEntry implements RosEnvironmentEntry = name:"state_topic" ":" value:StringValue; RosEnvironmentActionTopicEntry implements RosEnvironmentEntry = name:"action_topic" ":" value:StringValue; RosEnvironmentResetTopicEntry implements RosEnvironmentEntry = name:"reset_topic" ":" value:StringValue; - RosEnvironmentGreetingTopicEntry implements RosEnvironmentEntry = name:"greeting_topic" ":" value:StringValue; - RosEnvironmentMetaTopicEntry implements RosEnvironmentEntry = name:"meta_topic" ":" value:StringValue; RosEnvironmentTerminalStateTopicEntry implements RosEnvironmentEntry = name:"terminal_state_topic" ":" value:StringValue; + RosEnvironmentRewardTopicEntry implements RosEnvironmentEntry = name:"reward_topic" ":" value:StringValue; + + // DQN exclusive parameters + UseFixTargetNetworkEntry implements ConfigEntry = name:"use_fix_target_network" ":" value:BooleanValue; + TargetNetworkUpdateIntervalEntry implements ConfigEntry = name:"target_network_update_interval" ":" value:IntegerValue; + UseDoubleDQNEntry implements ConfigEntry = name:"use_double_dqn" ":" value:BooleanValue; + + + // DDPG exclusive parameters + CriticNetworkEntry implements ConfigEntry = name:"critic" ":" value:ComponentNameValue; + SoftTargetUpdateRateEntry implements ConfigEntry = name:"soft_target_update_rate" ":" value:NumberValue; + CriticOptimizerEntry implements ConfigEntry = name:"critic_optimizer" ":" value:OptimizerValue; } \ No newline at end of file diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ASTConfigurationUtils.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ASTConfigurationUtils.java index a0d4ed9b2726010e72787e83adfa9060334f831b..a955b73d59ba3836f7ff6e08e53ecbbb0f67773b 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ASTConfigurationUtils.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ASTConfigurationUtils.java @@ -20,9 +20,9 @@ */ package de.monticore.lang.monticar.cnntrain._cocos; -import de.monticore.lang.monticar.cnntrain._ast.ASTConfiguration; -import 
de.monticore.lang.monticar.cnntrain._ast.ASTEnvironmentEntry; -import de.monticore.lang.monticar.cnntrain._ast.ASTLearningMethodEntry; +import de.monticore.lang.monticar.cnntrain._ast.*; + +import java.util.Optional; class ASTConfigurationUtils { static boolean isReinforcementLearning(final ASTConfiguration configuration) { @@ -34,4 +34,54 @@ class ASTConfigurationUtils { static boolean hasEnvironment(final ASTConfiguration configuration) { return configuration.getEntriesList().stream().anyMatch(e -> e instanceof ASTEnvironmentEntry); } + + static boolean isDdpgAlgorithm(final ASTConfiguration configuration) { + return isReinforcementLearning(configuration) + && configuration.getEntriesList().stream().anyMatch( + e -> (e instanceof ASTRLAlgorithmEntry) && ((ASTRLAlgorithmEntry)e).getValue().isPresentDdpg()); + } + + static boolean isDqnAlgorithm(final ASTConfiguration configuration) { + return isReinforcementLearning(configuration) && !isDdpgAlgorithm(configuration); + } + + static boolean hasEntry(final ASTConfiguration configuration, final Class entryClazz) { + return configuration.getEntriesList().stream().anyMatch(entryClazz::isInstance); + } + + static boolean hasStrategy(final ASTConfiguration configuration) { + return configuration.getEntriesList().stream().anyMatch(e -> e instanceof ASTStrategyEntry); + } + + static Optional getStrategyMethod(final ASTConfiguration configuration) { + return configuration.getEntriesList().stream() + .filter(e -> e instanceof ASTStrategyEntry) + .map(e -> (ASTStrategyEntry)e) + .findFirst() + .map(astStrategyEntry -> astStrategyEntry.getValue().getName()); + } + + static boolean hasRewardFunction(final ASTConfiguration node) { + return node.getEntriesList().stream().anyMatch(e -> e instanceof ASTRewardFunctionEntry); + } + + static boolean hasRosEnvironment(final ASTConfiguration node) { + return ASTConfigurationUtils.hasEnvironment(node) + && node.getEntriesList().stream() + .anyMatch(e -> (e instanceof ASTEnvironmentEntry) + && ((ASTEnvironmentEntry)e).getValue().getName().equals("ros_interface")); + } + + static boolean hasRewardTopic(final ASTConfiguration node) { + if (ASTConfigurationUtils.isReinforcementLearning(node) && ASTConfigurationUtils.hasEnvironment(node)) { + return node.getEntriesList().stream() + .filter(ASTEnvironmentEntry.class::isInstance) + .map(e -> (ASTEnvironmentEntry)e) + .reduce((element, other) -> { throw new IllegalStateException("More than one entry");}) + .map(astEnvironmentEntry -> astEnvironmentEntry.getValue().getParamsList().stream() + .anyMatch(e -> e instanceof ASTRosEnvironmentRewardTopicEntry)).orElse(false); + + } + return false; + } } diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java index 3d0498079d3186fb7bcff99b17791580a4c3321b..b46504214ad1d24c2724196bf82b1e48d1fa5168 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CNNTrainCocos.java @@ -34,7 +34,11 @@ public class CNNTrainCocos { .addCoCo(new CheckReinforcementRequiresEnvironment()) .addCoCo(new CheckLearningParameterCombination()) .addCoCo(new CheckRosEnvironmentRequiresRewardFunction()) - .addCoCo(new CheckDdpgRequiresCriticNetwork()); + .addCoCo(new CheckDdpgRequiresCriticNetwork()) + .addCoCo(new CheckRlAlgorithmParameter()) + .addCoCo(new CheckDiscreteRLAlgorithmUsesDiscreteStrategy()) + .addCoCo(new 
CheckContinuousRLAlgorithmUsesContinuousStrategy()) + .addCoCo(new CheckRosEnvironmentHasOnlyOneRewardSpecification()); } public static void checkAll(CNNTrainCompilationUnitSymbol compilationUnit){ diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckContinuousRLAlgorithmUsesContinuousStrategy.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckContinuousRLAlgorithmUsesContinuousStrategy.java new file mode 100644 index 0000000000000000000000000000000000000000..5d0a68c42884d94ff26880a79a5700d7ef4f4cd9 --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckContinuousRLAlgorithmUsesContinuousStrategy.java @@ -0,0 +1,47 @@ +/** + * + * ****************************************************************************** + * MontiCAR Modeling Family, www.se-rwth.de + * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, + * All rights reserved. + * + * This project is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 3.0 of the License, or (at your option) any later version. + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this project. If not, see . + * ******************************************************************************* + */ +package de.monticore.lang.monticar.cnntrain._cocos; + +import com.google.common.collect.ImmutableSet; +import de.monticore.lang.monticar.cnntrain._ast.ASTConfiguration; +import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes; +import de.se_rwth.commons.logging.Log; + +import java.util.Set; + +public class CheckContinuousRLAlgorithmUsesContinuousStrategy implements CNNTrainASTConfigurationCoCo{ + private static final Set CONTINUOUS_STRATEGIES = ImmutableSet.builder() + .add("ornstein_uhlenbeck") + .build(); + + @Override + public void check(ASTConfiguration node) { + if (ASTConfigurationUtils.isDdpgAlgorithm(node) + && ASTConfigurationUtils.hasStrategy(node) + && ASTConfigurationUtils.getStrategyMethod(node).isPresent()) { + final String usedStrategy = ASTConfigurationUtils.getStrategyMethod(node).get(); + if (!CONTINUOUS_STRATEGIES.contains(usedStrategy)) { + Log.error("0" + ErrorCodes.STRATEGY_NOT_APPLICABLE + " Strategy " + usedStrategy + " used but" + + " continuous algorithm used.", node.get_SourcePositionStart()); + } + } + } +} \ No newline at end of file diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckDiscreteRLAlgorithmUsesDiscreteStrategy.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckDiscreteRLAlgorithmUsesDiscreteStrategy.java new file mode 100644 index 0000000000000000000000000000000000000000..64263bb35d62401abaf9c1d8e0d8179b28636e0d --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckDiscreteRLAlgorithmUsesDiscreteStrategy.java @@ -0,0 +1,47 @@ +/** + * + * ****************************************************************************** + * MontiCAR Modeling Family, www.se-rwth.de + * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, + * All rights reserved. 
+ * + * This project is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 3.0 of the License, or (at your option) any later version. + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this project. If not, see . + * ******************************************************************************* + */ +package de.monticore.lang.monticar.cnntrain._cocos; + +import com.google.common.collect.ImmutableSet; +import de.monticore.lang.monticar.cnntrain._ast.ASTConfiguration; +import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes; +import de.se_rwth.commons.logging.Log; + +import java.util.Set; + +public class CheckDiscreteRLAlgorithmUsesDiscreteStrategy implements CNNTrainASTConfigurationCoCo{ + private static final Set<String> DISCRETE_STRATEGIES = ImmutableSet.<String>builder() + .add("epsgreedy") + .build(); + + @Override + public void check(ASTConfiguration node) { + if (ASTConfigurationUtils.isDqnAlgorithm(node) + && ASTConfigurationUtils.hasStrategy(node) + && ASTConfigurationUtils.getStrategyMethod(node).isPresent()) { + final String usedStrategy = ASTConfigurationUtils.getStrategyMethod(node).get(); + if (!DISCRETE_STRATEGIES.contains(usedStrategy)) { + Log.error("0" + ErrorCodes.STRATEGY_NOT_APPLICABLE + " Strategy " + usedStrategy + " used but" + + " discrete algorithm used.", node.get_SourcePositionStart()); + } + } + } +} diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckEntryRepetition.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckEntryRepetition.java index 07337b8dd064efa502e5a8ca2d0296582419fd66..5a3ce20d390639c36af4f2af841010881df58adf 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckEntryRepetition.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckEntryRepetition.java @@ -20,8 +20,9 @@ */ package de.monticore.lang.monticar.cnntrain._cocos; -import de.monticore.lang.monticar.cnntrain._ast.ASTEntry; -import de.monticore.lang.monticar.cnntrain._ast.ASTGreedyEpsilonEntry; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Sets; +import de.monticore.lang.monticar.cnntrain._ast.*; import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes; import de.se_rwth.commons.logging.Log; @@ -29,23 +30,35 @@ import java.util.HashSet; import java.util.Set; public class CheckEntryRepetition implements CNNTrainASTEntryCoCo { + private final static Set<Class<? extends ASTEntry>> REPEATABLE_ENTRIES = ImmutableSet + .<Class<? extends ASTEntry>>builder() + .add(ASTOptimizerParamEntry.class) + .build(); + + private Set<String> entryNameSet = new HashSet<>(); @Override public void check(ASTEntry node) { - String parameterPrefix = ""; - if (node instanceof ASTGreedyEpsilonEntry) { - parameterPrefix = "greedy_"; - } - if (entryNameSet.contains(parameterPrefix + node.getName())){ - Log.error("0" + ErrorCodes.ENTRY_REPETITION_CODE +" The parameter '" + node.getName() + "' has multiple values. 
" + - "Multiple assignments of the same parameter are not allowed", - node.get_SourcePositionStart()); - } - else { - entryNameSet.add(parameterPrefix + node.getName()); + if (!isRepeatable(node)) { + String parameterPrefix = ""; + + if (node instanceof ASTGreedyEpsilonEntry) { + parameterPrefix = "greedy_"; + } + if (entryNameSet.contains(parameterPrefix + node.getName())) { + Log.error("0" + ErrorCodes.ENTRY_REPETITION_CODE + " The parameter '" + node.getName() + "' has multiple values. " + + "Multiple assignments of the same parameter are not allowed", + node.get_SourcePositionStart()); + } else { + entryNameSet.add(parameterPrefix + node.getName()); + } } } + private boolean isRepeatable(final ASTEntry node) { + return REPEATABLE_ENTRIES.stream().anyMatch(i -> i.isInstance(node)); + } + } diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckLearningParameterCombination.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckLearningParameterCombination.java index 90ea12a0c3900dcba2e40d4389f8ec30cc2cb097..63c11b46744dfe1345ed0aecd22cf293cfe69dda 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckLearningParameterCombination.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckLearningParameterCombination.java @@ -34,78 +34,7 @@ import java.util.Set; * */ public class CheckLearningParameterCombination implements CNNTrainASTEntryCoCo { - private final static List ALLOWED_SUPERVISED_LEARNING = Lists.newArrayList( - ASTTrainContextEntry.class, - ASTBatchSizeEntry.class, - ASTOptimizerEntry.class, - ASTLearningRateEntry.class, - ASTLoadCheckpointEntry.class, - ASTEvalMetricEntry.class, - ASTLossEntry.class, - ASTNormalizeEntry.class, - ASTMinimumLearningRateEntry.class, - ASTLRDecayEntry.class, - ASTWeightDecayEntry.class, - ASTLRPolicyEntry.class, - ASTStepSizeEntry.class, - ASTRescaleGradEntry.class, - ASTClipGradEntry.class, - ASTGamma1Entry.class, - ASTGamma2Entry.class, - ASTEpsilonEntry.class, - ASTCenteredEntry.class, - ASTClipWeightsEntry.class, - ASTBeta1Entry.class, - ASTBeta2Entry.class, - ASTNumEpochEntry.class - ); - private final static List ALLOWED_REINFORCEMENT_LEARNING = Lists.newArrayList( - ASTTrainContextEntry.class, - ASTRLAlgorithmEntry.class, - ASTCriticNetworkEntry.class, - ASTOptimizerEntry.class, - ASTRewardFunctionEntry.class, - ASTMinimumLearningRateEntry.class, - ASTLRDecayEntry.class, - ASTWeightDecayEntry.class, - ASTLRPolicyEntry.class, - ASTGamma1Entry.class, - ASTGamma2Entry.class, - ASTEpsilonEntry.class, - ASTClipGradEntry.class, - ASTRescaleGradEntry.class, - ASTStepSizeEntry.class, - ASTCenteredEntry.class, - ASTClipWeightsEntry.class, - ASTLearningRateEntry.class, - ASTDiscountFactorEntry.class, - ASTNumMaxStepsEntry.class, - ASTTargetScoreEntry.class, - ASTTrainingIntervalEntry.class, - ASTUseFixTargetNetworkEntry.class, - ASTTargetNetworkUpdateIntervalEntry.class, - ASTSnapshotIntervalEntry.class, - ASTAgentNameEntry.class, - ASTGymEnvironmentNameEntry.class, - ASTEnvironmentEntry.class, - ASTUseDoubleDQNEntry.class, - ASTLossEntry.class, - ASTReplayMemoryEntry.class, - ASTMemorySizeEntry.class, - ASTSampleSizeEntry.class, - ASTActionSelectionEntry.class, - ASTGreedyEpsilonEntry.class, - ASTMinEpsilonEntry.class, - ASTEpsilonDecayEntry.class, - ASTEpsilonDecayMethodEntry.class, - ASTNumEpisodesEntry.class, - ASTRosEnvironmentActionTopicEntry.class, - ASTRosEnvironmentStateTopicEntry.class, - ASTRosEnvironmentMetaTopicEntry.class, - ASTRosEnvironmentResetTopicEntry.class, - 
ASTRosEnvironmentTerminalStateTopicEntry.class, - ASTRosEnvironmentGreetingTopicEntry.class - ); + private final ParameterAlgorithmMapping parameterAlgorithmMapping; private Set allEntries; @@ -113,12 +42,13 @@ public class CheckLearningParameterCombination implements CNNTrainASTEntryCoCo { private LearningMethod learningMethod; public CheckLearningParameterCombination() { - this.allEntries = new HashSet<>(); - this.learningMethodKnown = false; + allEntries = new HashSet<>(); + learningMethodKnown = false; + parameterAlgorithmMapping = new ParameterAlgorithmMapping(); } private Boolean isLearningMethodKnown() { - return this.learningMethodKnown; + return learningMethodKnown; } @Override @@ -132,18 +62,20 @@ public class CheckLearningParameterCombination implements CNNTrainASTEntryCoCo { private void evaluateEntry(ASTEntry node) { allEntries.add(node); - final Boolean supervisedLearningParameter = ALLOWED_SUPERVISED_LEARNING.contains(node.getClass()); - final Boolean reinforcementLearningParameter = ALLOWED_REINFORCEMENT_LEARNING.contains(node.getClass()); + final boolean supervisedLearningParameter + = parameterAlgorithmMapping.isSupervisedLearningParameter(node.getClass()); + final boolean reinforcementLearningParameter + = parameterAlgorithmMapping.isReinforcementLearningParameter(node.getClass()); - assert (supervisedLearningParameter || reinforcementLearningParameter) : + assert (supervisedLearningParameter || reinforcementLearningParameter) : "Parameter " + node.getName() + " is not checkable, because it is unknown to Condition"; - if (supervisedLearningParameter && reinforcementLearningParameter) { - return; - } else if (supervisedLearningParameter && !reinforcementLearningParameter) { + + if (supervisedLearningParameter && !reinforcementLearningParameter) { setLearningMethodOrLogErrorIfActualLearningMethodIsNotSupervised(node); - } else if (!supervisedLearningParameter && reinforcementLearningParameter) { + } else if(!supervisedLearningParameter) { setLearningMethodOrLogErrorIfActualLearningMethodIsNotReinforcement(node); } + } private void setLearningMethodOrLogErrorIfActualLearningMethodIsNotReinforcement(ASTEntry node) { @@ -203,11 +135,12 @@ public class CheckLearningParameterCombination implements CNNTrainASTEntryCoCo { private List getAllowedParametersByLearningMethod(final LearningMethod learningMethod) { if (learningMethod.equals(LearningMethod.REINFORCEMENT)) { - return ALLOWED_REINFORCEMENT_LEARNING; + return parameterAlgorithmMapping.getAllReinforcementParameters(); } - return ALLOWED_SUPERVISED_LEARNING; + return parameterAlgorithmMapping.getAllSupervisedParameters(); } + private void setLearningMethod(final LearningMethod learningMethod) { if (learningMethod.equals(LearningMethod.REINFORCEMENT)) { setLearningMethodToReinforcement(); diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckRlAlgorithmParameter.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckRlAlgorithmParameter.java new file mode 100644 index 0000000000000000000000000000000000000000..85d0bd1f64d18616808a7103db1f4b5bc91a439b --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckRlAlgorithmParameter.java @@ -0,0 +1,90 @@ +/** + * + * ****************************************************************************** + * MontiCAR Modeling Family, www.se-rwth.de + * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, + * All rights reserved. 
+ * + * This project is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 3.0 of the License, or (at your option) any later version. + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this project. If not, see . + * ******************************************************************************* + */ +package de.monticore.lang.monticar.cnntrain._cocos; + +import de.monticore.lang.monticar.cnntrain._ast.ASTEntry; +import de.monticore.lang.monticar.cnntrain._ast.ASTRLAlgorithmEntry; +import de.monticore.lang.monticar.cnntrain._symboltable.RLAlgorithm; +import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes; +import de.se_rwth.commons.logging.Log; + +public class CheckRlAlgorithmParameter implements CNNTrainASTEntryCoCo { + private final ParameterAlgorithmMapping parameterAlgorithmMapping; + + boolean algorithmKnown; + RLAlgorithm algorithm; + + public CheckRlAlgorithmParameter() { + parameterAlgorithmMapping = new ParameterAlgorithmMapping(); + algorithmKnown = false; + algorithm = null; + } + + + @Override + public void check(ASTEntry node) { + final boolean isDdpgParameter = parameterAlgorithmMapping.isDdpgParameter(node.getClass()); + final boolean isDqnParameter = parameterAlgorithmMapping.isDqnParameter(node.getClass()); + + if (node instanceof ASTRLAlgorithmEntry) { + ASTRLAlgorithmEntry algorithmEntry = (ASTRLAlgorithmEntry)node; + if (algorithmEntry.getValue().isPresentDdpg()) { + setAlgorithmToDdpg(node); + } else { + setAlgorithmToDqn(node); + } + } else { + if (isDdpgParameter && !isDqnParameter) { + setAlgorithmToDdpg(node); + } else if (!isDdpgParameter && isDqnParameter) { + setAlgorithmToDqn(node); + } + } + } + + private void logErrorIfAlgorithmIsDqn(final ASTEntry node) { + if (algorithmKnown && algorithm.equals(RLAlgorithm.DQN)) { + Log.error("0" + ErrorCodes.UNSUPPORTED_PARAMETER + + " DDPG Parameter " + node.getName() + " used but algorithm is " + algorithm + ".", + node.get_SourcePositionStart()); + } + } + + private void setAlgorithmToDdpg(final ASTEntry node) { + logErrorIfAlgorithmIsDqn(node); + algorithmKnown = true; + algorithm = RLAlgorithm.DDPG; + } + + private void setAlgorithmToDqn(final ASTEntry node) { + logErrorIfAlgorithmIsDdpg(node); + algorithmKnown = true; + algorithm = RLAlgorithm.DQN; + } + + private void logErrorIfAlgorithmIsDdpg(final ASTEntry node) { + if (algorithmKnown && algorithm.equals(RLAlgorithm.DDPG)) { + Log.error("0" + ErrorCodes.UNSUPPORTED_PARAMETER + + " DQN Parameter " + node.getName() + " used but algorithm is " + algorithm + ".", + node.get_SourcePositionStart()); + } + } +} \ No newline at end of file diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckRosEnvironmentHasOnlyOneRewardSpecification.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckRosEnvironmentHasOnlyOneRewardSpecification.java new file mode 100644 index 0000000000000000000000000000000000000000..d3ef290751ce2a591279713ba342cf7c651e7edd --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckRosEnvironmentHasOnlyOneRewardSpecification.java @@ -0,0 +1,42 @@ +/** + * + * 
****************************************************************************** + * MontiCAR Modeling Family, www.se-rwth.de + * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, + * All rights reserved. + * + * This project is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 3.0 of the License, or (at your option) any later version. + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this project. If not, see . + * ******************************************************************************* + */ +package de.monticore.lang.monticar.cnntrain._cocos; + +import de.monticore.lang.monticar.cnntrain._ast.ASTConfiguration; +import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes; +import de.se_rwth.commons.logging.Log; + +import static de.monticore.lang.monticar.cnntrain._cocos.ASTConfigurationUtils.*; + +public class CheckRosEnvironmentHasOnlyOneRewardSpecification implements CNNTrainASTConfigurationCoCo { + @Override + public void check(final ASTConfiguration node) { + if (isReinforcementLearning(node) + && hasRosEnvironment(node) + && hasRewardFunction(node) + && hasRewardTopic(node)) { + Log.error("0" + ErrorCodes.CONTRADICTING_PARAMETERS + + " Multiple reward calculation methods specified. Either use a reward function component with " + + "parameter reward_function or use a ROS topic with parameter reward_topic. 
" + + "Both are not possible."); + } + } +} \ No newline at end of file diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckRosEnvironmentRequiresRewardFunction.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckRosEnvironmentRequiresRewardFunction.java index 9903507f23d91213d5a363847105c64a0c950ae3..1f6b3ec762be27c7a495c3c96103425b081196a7 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckRosEnvironmentRequiresRewardFunction.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/CheckRosEnvironmentRequiresRewardFunction.java @@ -21,34 +21,23 @@ package de.monticore.lang.monticar.cnntrain._cocos; import de.monticore.lang.monticar.cnntrain._ast.*; -import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol; -import de.monticore.lang.monticar.cnntrain._symboltable.EntrySymbol; -import de.monticore.lang.monticar.cnntrain._symboltable.Environment; -import de.monticore.lang.monticar.cnntrain._symboltable.LearningMethod; import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes; import de.se_rwth.commons.logging.Log; -import java.util.Map; +import static de.monticore.lang.monticar.cnntrain._cocos.ASTConfigurationUtils.*; public class CheckRosEnvironmentRequiresRewardFunction implements CNNTrainASTConfigurationCoCo { @Override public void check(final ASTConfiguration node) { - if (ASTConfigurationUtils.isReinforcementLearning(node) - && hasRosEnvironment(node) - && !hasRewardFunction(node)) { - Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING + " The required parameter reward_function" + - " is missing"); + // A reward specification is only required for reinforcement learning via ROS, since the OpenAI Gym + // environments define their own reward functions + if (isReinforcementLearning(node) && hasRosEnvironment(node)) { + // The reward either needs to be calculated by a custom reward component or provided via a ROS reward topic + if (!hasRewardFunction(node) && !hasRewardTopic(node)) { + Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING + + " Reward function is missing. Either add a reward function component with parameter " + + "reward_function or add a ROS topic with parameter reward_topic."); + } } } - - private boolean hasRewardFunction(final ASTConfiguration node) { - return node.getEntriesList().stream().anyMatch(e -> e instanceof ASTRewardFunctionEntry); - } - - private boolean hasRosEnvironment(final ASTConfiguration node) { - return ASTConfigurationUtils.hasEnvironment(node) - && node.getEntriesList().stream() - .anyMatch(e -> (e instanceof ASTEnvironmentEntry) - && ((ASTEnvironmentEntry)e).getValue().getName().equals("ros_interface")); - } } \ No newline at end of file diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java new file mode 100644 index 0000000000000000000000000000000000000000..6c930ff48e123498de1facbbbaa5cfed3bd5583f --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_cocos/ParameterAlgorithmMapping.java @@ -0,0 +1,148 @@ +/** + * + * ****************************************************************************** + * MontiCAR Modeling Family, www.se-rwth.de + * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, + * All rights reserved.
+ * + * This project is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 3.0 of the License, or (at your option) any later version. + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this project. If not, see . + * ******************************************************************************* + */ +package de.monticore.lang.monticar.cnntrain._cocos; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; +import de.monticore.lang.monticar.cnntrain._ast.*; + +import java.util.List; + +class ParameterAlgorithmMapping { + private static final List GENERAL_PARAMETERS = Lists.newArrayList( + ASTTrainContextEntry.class, + ASTOptimizerEntry.class, + ASTLearningRateEntry.class, + ASTMinimumLearningRateEntry.class, + ASTLRDecayEntry.class, + ASTWeightDecayEntry.class, + ASTLRPolicyEntry.class, + ASTStepSizeEntry.class, + ASTRescaleGradEntry.class, + ASTClipGradEntry.class, + ASTGamma1Entry.class, + ASTGamma2Entry.class, + ASTEpsilonEntry.class, + ASTCenteredEntry.class, + ASTClipWeightsEntry.class, + ASTBeta1Entry.class, + ASTBeta2Entry.class + ); + + private static final List EXCLUSIVE_SUPERVISED_PARAMETERS = Lists.newArrayList( + ASTBatchSizeEntry.class, + ASTLoadCheckpointEntry.class, + ASTEvalMetricEntry.class, + ASTNormalizeEntry.class, + ASTNumEpochEntry.class, + ASTLossEntry.class + ); + + private static final List GENERAL_REINFORCEMENT_PARAMETERS = Lists.newArrayList( + ASTRLAlgorithmEntry.class, + ASTRewardFunctionEntry.class, + ASTDiscountFactorEntry.class, + ASTNumMaxStepsEntry.class, + ASTTargetScoreEntry.class, + ASTTrainingIntervalEntry.class, + ASTAgentNameEntry.class, + ASTGymEnvironmentNameEntry.class, + ASTEnvironmentEntry.class, + ASTReplayMemoryEntry.class, + ASTMemorySizeEntry.class, + ASTSampleSizeEntry.class, + ASTStrategyEntry.class, + ASTGreedyEpsilonEntry.class, + ASTMinEpsilonEntry.class, + ASTEpsilonDecayEntry.class, + ASTEpsilonDecayMethodEntry.class, + ASTNumEpisodesEntry.class, + ASTRosEnvironmentActionTopicEntry.class, + ASTRosEnvironmentStateTopicEntry.class, + ASTRosEnvironmentResetTopicEntry.class, + ASTRosEnvironmentTerminalStateTopicEntry.class, + ASTRosEnvironmentRewardTopicEntry.class, + ASTStartTrainingAtEntry.class, + ASTEvaluationSamplesEntry.class, + ASTEpsilonDecayStartEntry.class, + ASTSnapshotIntervalEntry.class + ); + + private static final List EXCLUSIVE_DQN_PARAMETERS = Lists.newArrayList( + ASTUseFixTargetNetworkEntry.class, + ASTTargetNetworkUpdateIntervalEntry.class, + ASTUseDoubleDQNEntry.class, + ASTLossEntry.class + ); + + private static final List EXCLUSIVE_DDPG_PARAMETERS = Lists.newArrayList( + ASTCriticNetworkEntry.class, + ASTSoftTargetUpdateRateEntry.class, + ASTCriticOptimizerEntry.class, + ASTStrategyOUMu.class, + ASTStrategyOUTheta.class, + ASTStrategyOUSigma.class + ); + + ParameterAlgorithmMapping() { + + } + + boolean isReinforcementLearningParameter(Class entryClazz) { + return GENERAL_PARAMETERS.contains(entryClazz) + || GENERAL_REINFORCEMENT_PARAMETERS.contains(entryClazz) + || EXCLUSIVE_DQN_PARAMETERS.contains(entryClazz) + || 
EXCLUSIVE_DDPG_PARAMETERS.contains(entryClazz); + } + + boolean isSupervisedLearningParameter(Class entryClazz) { + return GENERAL_PARAMETERS.contains(entryClazz) + || EXCLUSIVE_SUPERVISED_PARAMETERS.contains(entryClazz); + } + + boolean isDqnParameter(Class entryClazz) { + return GENERAL_PARAMETERS.contains(entryClazz) + || GENERAL_REINFORCEMENT_PARAMETERS.contains(entryClazz) + || EXCLUSIVE_DQN_PARAMETERS.contains(entryClazz); + } + + boolean isDdpgParameter(Class entryClazz) { + return GENERAL_PARAMETERS.contains(entryClazz) + || GENERAL_REINFORCEMENT_PARAMETERS.contains(entryClazz) + || EXCLUSIVE_DDPG_PARAMETERS.contains(entryClazz); + } + + List getAllReinforcementParameters() { + return ImmutableList. builder() + .addAll(GENERAL_PARAMETERS) + .addAll(GENERAL_REINFORCEMENT_PARAMETERS) + .addAll(EXCLUSIVE_DQN_PARAMETERS) + .addAll(EXCLUSIVE_DDPG_PARAMETERS) + .build(); + } + + List getAllSupervisedParameters() { + return ImmutableList. builder() + .addAll(GENERAL_PARAMETERS) + .addAll(EXCLUSIVE_SUPERVISED_PARAMETERS) + .build(); + } +} \ No newline at end of file diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java index 2c45115a84fd4a03f90dcca3a8ec372eb03310dd..50b4ab43418f98c3b1c7e9b6d532578a03ad1b68 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/CNNTrainSymbolTableCreator.java @@ -29,6 +29,7 @@ import de.monticore.symboltable.ResolvingConfiguration; import de.se_rwth.commons.logging.Log; import java.util.*; +import java.util.stream.Collectors; public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { @@ -96,11 +97,29 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { OptimizerParamSymbol param = new OptimizerParamSymbol(); OptimizerParamValueSymbol valueSymbol = (OptimizerParamValueSymbol) nodeParam.getValue().getSymbolOpt().get(); param.setValue(valueSymbol); - configuration.getOptimizer().getOptimizerParamMap().put(nodeParam.getName(), param);; + configuration.getOptimizer().getOptimizerParamMap().put(nodeParam.getName(), param); } } + @Override + public void visit(ASTCriticOptimizerEntry node) { + OptimizerSymbol optimizerSymbol = new OptimizerSymbol(node.getValue().getName()); + configuration.setCriticOptimizer(optimizerSymbol); + addToScopeAndLinkWithNode(optimizerSymbol, node); + } + + @Override + public void endVisit(ASTCriticOptimizerEntry node) { + assert configuration.getCriticOptimizer().isPresent(): "Critic optimizer not present"; + for (ASTEntry paramNode : node.getValue().getParamsList()) { + OptimizerParamSymbol param = new OptimizerParamSymbol(); + OptimizerParamValueSymbol valueSymbol = (OptimizerParamValueSymbol)paramNode.getValue().getSymbolOpt().get(); + param.setValue(valueSymbol); + configuration.getCriticOptimizer().get().getOptimizerParamMap().put(paramNode.getName(), param); + } + } + @Override public void endVisit(ASTNumEpochEntry node) { EntrySymbol entry = new EntrySymbol(node.getName()); @@ -440,12 +459,12 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { } @Override - public void visit(ASTActionSelectionEntry node) { + public void visit(ASTStrategyEntry node) { processMultiParamConfigVisit(node, node.getValue().getName()); } @Override - public void endVisit(ASTActionSelectionEntry node) { + 
public void endVisit(ASTStrategyEntry node) { processMultiParamConfigEndVisit(node); } @@ -469,6 +488,30 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { addToScopeAndLinkWithNode(symbol, node); } + @Override + public void visit(ASTSoftTargetUpdateRateEntry node) { + EntrySymbol entry = new EntrySymbol(node.getName()); + entry.setValue(getValueSymbolForDouble(node.getValue())); + addToScopeAndLinkWithNode(entry, node); + configuration.getEntryMap().put(node.getName(), entry); + } + + @Override + public void visit(ASTStartTrainingAtEntry node) { + EntrySymbol entry = new EntrySymbol(node.getName()); + entry.setValue(getValueSymbolForInteger(node.getValue())); + addToScopeAndLinkWithNode(entry, node); + configuration.getEntryMap().put(node.getName(), entry); + } + + @Override + public void visit(ASTEvaluationSamplesEntry node) { + EntrySymbol entry = new EntrySymbol(node.getName()); + entry.setValue(getValueSymbolForInteger(node.getValue())); + addToScopeAndLinkWithNode(entry, node); + configuration.getEntryMap().put(node.getName(), entry); + } + private void processMultiParamConfigVisit(ASTMultiParamConfigEntry node, Object value) { EntrySymbol entry = new EntrySymbol(node.getName()); MultiParamValueSymbol valueSymbol = new MultiParamValueSymbol(); @@ -504,6 +547,12 @@ public class CNNTrainSymbolTableCreator extends CNNTrainSymbolTableCreatorTOP { } else { return EpsilonDecayMethod.NO; } + } else if (configValue instanceof ASTDoubleVectorValue) { + ASTDoubleVectorValue astDoubleVectorValue = (ASTDoubleVectorValue)configValue; + return astDoubleVectorValue.getNumberList().stream() + .filter(n -> n.getNumber().isPresent()) + .map(n -> n.getNumber().get()) + .collect(Collectors.toList()); } throw new UnsupportedOperationException("Unknown Value type: " + configValue.getClass()); } diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java index f137c47a28b4cc1f4b1332a733047c006818f816..a026fbd6f1b4c7de626d241d5e4bf6906e67f60f 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/_symboltable/ConfigurationSymbol.java @@ -24,12 +24,14 @@ import com.google.common.collect.Lists; import de.monticore.lang.monticar.cnntrain.annotations.TrainedArchitecture; import de.monticore.symboltable.CommonScopeSpanningSymbol; +import javax.swing.text.html.Option; import java.util.*; public class ConfigurationSymbol extends CommonScopeSpanningSymbol { private Map entryMap = new HashMap<>(); private OptimizerSymbol optimizer; + private OptimizerSymbol criticOptimizer; private RewardFunctionSymbol rlRewardFunctionSymbol; private TrainedArchitecture trainedArchitecture; @@ -49,6 +51,14 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol { this.optimizer = optimizer; } + public void setCriticOptimizer(OptimizerSymbol criticOptimizer) { + this.criticOptimizer = criticOptimizer; + } + + public Optional getCriticOptimizer() { + return Optional.ofNullable(criticOptimizer); + } + protected void setRlRewardFunction(RewardFunctionSymbol rlRewardFunctionSymbol) { this.rlRewardFunctionSymbol = rlRewardFunctionSymbol; } diff --git a/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ErrorCodes.java b/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ErrorCodes.java index 
f21894ea7de54dbf61f049bea403dff14150bf90..6aa9e5074f0324090812cb222078c1c1e110704e 100644 --- a/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ErrorCodes.java +++ b/src/main/java/de/monticore/lang/monticar/cnntrain/helper/ErrorCodes.java @@ -29,4 +29,6 @@ public class ErrorCodes { public static final String UNSUPPORTED_PARAMETER = "xC8854"; public static final String RANGE_ERROR = "xC8855"; public static final String REQUIRED_PARAMETER_MISSING = "xC8856"; + public static final String STRATEGY_NOT_APPLICABLE = "xC8857"; + public static final String CONTRADICTING_PARAMETERS = "xC8858"; } \ No newline at end of file diff --git a/src/test/java/de/monticore/lang/monticar/cnntrain/ParserTest.java b/src/test/java/de/monticore/lang/monticar/cnntrain/ParserTest.java index 6e10eca2554d830fe1c3a96a23633e64ee8611e3..8adfc8df3423f96ccf600cc2a7ef78bfea7805ce 100644 --- a/src/test/java/de/monticore/lang/monticar/cnntrain/ParserTest.java +++ b/src/test/java/de/monticore/lang/monticar/cnntrain/ParserTest.java @@ -46,7 +46,8 @@ public class ParserTest { "src/test/resources/invalid_parser_tests/WrongParameterName2.cnnt", "src/test/resources/invalid_parser_tests/InvalidType.cnnt", "src/test/resources/invalid_parser_tests/InvalidOptimizer.cnnt", - "src/test/resources/invalid_parser_tests/MissingColon.cnnt") + "src/test/resources/invalid_parser_tests/MissingColon.cnnt", + "src/test/resources/invalid_parser_tests/WrongStrategyParameter.cnnt") .map(s -> Paths.get(s).toString()) .collect(Collectors.toList()); diff --git a/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java b/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java index 57cf32fe6072627657d1015a60a046bdf9405a6f..cd7d8f63a3d25e560bc3847c47c6930741365212 100644 --- a/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java +++ b/src/test/java/de/monticore/lang/monticar/cnntrain/cocos/AllCoCoTest.java @@ -42,6 +42,7 @@ public class AllCoCoTest extends AbstractCoCoTest{ checkValid("valid_tests", "ReinforcementConfig"); checkValid("valid_tests", "ReinforcementConfig2"); checkValid("valid_tests", "DdpgConfig"); + checkValid("valid_tests", "ReinforcementWithRosReward"); } @Test @@ -76,5 +77,23 @@ public class AllCoCoTest extends AbstractCoCoTest{ checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckRosEnvironmentRequiresRewardFunction()), "invalid_cocos_tests", "CheckRosEnvironmentRequiresRewardFunction", new ExpectedErrorInfo(1, ErrorCodes.REQUIRED_PARAMETER_MISSING)); + checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckRlAlgorithmParameter()), + "invalid_cocos_tests", "CheckRLAlgorithmParameter1", + new ExpectedErrorInfo(1, ErrorCodes.UNSUPPORTED_PARAMETER)); + checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckRlAlgorithmParameter()), + "invalid_cocos_tests", "CheckRLAlgorithmParameter2", + new ExpectedErrorInfo(1, ErrorCodes.UNSUPPORTED_PARAMETER)); + checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckRlAlgorithmParameter()), + "invalid_cocos_tests", "CheckRLAlgorithmParameter3", + new ExpectedErrorInfo(1, ErrorCodes.UNSUPPORTED_PARAMETER)); + checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckDiscreteRLAlgorithmUsesDiscreteStrategy()), + "invalid_cocos_tests", "CheckDiscreteRLAlgorithmUsesDiscreteStrategy", + new ExpectedErrorInfo(1, ErrorCodes.STRATEGY_NOT_APPLICABLE)); + checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckContinuousRLAlgorithmUsesContinuousStrategy()), + "invalid_cocos_tests", "CheckContinuousRLAlgorithmUsesContinuousStrategy", + new 
+        checkInvalid(new CNNTrainCoCoChecker().addCoCo(new CheckRosEnvironmentHasOnlyOneRewardSpecification()),
+                "invalid_cocos_tests", "CheckRosEnvironmentHasOnlyOneRewardSpecification",
+                new ExpectedErrorInfo(1, ErrorCodes.CONTRADICTING_PARAMETERS));
     }
 }
diff --git a/src/test/resources/invalid_cocos_tests/CheckContinuousRLAlgorithmUsesContinuousStrategy.cnnt b/src/test/resources/invalid_cocos_tests/CheckContinuousRLAlgorithmUsesContinuousStrategy.cnnt
new file mode 100644
index 0000000000000000000000000000000000000000..9a14a32cf48a97e2a5803b9eead970c0823e3678
--- /dev/null
+++ b/src/test/resources/invalid_cocos_tests/CheckContinuousRLAlgorithmUsesContinuousStrategy.cnnt
@@ -0,0 +1,12 @@
+configuration CheckContinuousRLAlgorithmUsesContinuousStrategy {
+    learning_method : reinforcement
+    rl_algorithm: ddpg-algorithm
+
+    environment : gym { name:"CartPole-v1" }
+
+    strategy: epsgreedy {
+        epsilon: 1.0
+        epsilon_decay_method: linear
+        epsilon_decay: 0.01
+    }
+}
\ No newline at end of file
diff --git a/src/test/resources/invalid_cocos_tests/CheckDiscreteRLAlgorithmUsesDiscreteStrategy.cnnt b/src/test/resources/invalid_cocos_tests/CheckDiscreteRLAlgorithmUsesDiscreteStrategy.cnnt
new file mode 100644
index 0000000000000000000000000000000000000000..26fa2b41193aad988eb0b5771a01364d76e78ca2
--- /dev/null
+++ b/src/test/resources/invalid_cocos_tests/CheckDiscreteRLAlgorithmUsesDiscreteStrategy.cnnt
@@ -0,0 +1,15 @@
+configuration CheckDiscreteRLAlgorithmUsesDiscreteStrategy {
+    learning_method : reinforcement
+    rl_algorithm: dqn-algorithm
+
+    environment : gym { name:"CartPole-v1" }
+
+    strategy: ornstein_uhlenbeck {
+        epsilon: 1.0
+        epsilon_decay_method: linear
+        epsilon_decay: 0.01
+        mu: (0.0, 0.1, 0.3)
+        theta: (0.5, 0.0, 0.8)
+        sigma: (0.3, 0.6, -0.9)
+    }
+}
\ No newline at end of file
diff --git a/src/test/resources/invalid_cocos_tests/CheckLearningParameterCombination1.cnnt b/src/test/resources/invalid_cocos_tests/CheckLearningParameterCombination1.cnnt
index 1a2ddec561ca73418204fa6afee269bd59f0560b..b428c3849b580dd92e21e19ff17aba16389124e7 100644
--- a/src/test/resources/invalid_cocos_tests/CheckLearningParameterCombination1.cnnt
+++ b/src/test/resources/invalid_cocos_tests/CheckLearningParameterCombination1.cnnt
@@ -25,7 +25,7 @@ configuration CheckLearningParameterCombination1 {
         sample_size : 64
     }
 
-    action_selection : epsgreedy{
+    strategy : epsgreedy{
         epsilon : 1.0
         min_epsilon : 0.01
         epsilon_decay_method: linear
diff --git a/src/test/resources/invalid_cocos_tests/CheckRLAlgorithmParameter1.cnnt b/src/test/resources/invalid_cocos_tests/CheckRLAlgorithmParameter1.cnnt
new file mode 100644
index 0000000000000000000000000000000000000000..b38eb8efb8c7b6df8b95a05278ed316e8c494a74
--- /dev/null
+++ b/src/test/resources/invalid_cocos_tests/CheckRLAlgorithmParameter1.cnnt
@@ -0,0 +1,7 @@
+configuration CheckRLAlgorithmParameter1 {
+    learning_method : reinforcement
+    critic : path.to.component
+    environment : gym { name:"CartPole-v1" }
+    soft_target_update_rate: 0.001
+    use_double_dqn: true
+}
\ No newline at end of file
diff --git a/src/test/resources/invalid_cocos_tests/CheckRLAlgorithmParameter2.cnnt b/src/test/resources/invalid_cocos_tests/CheckRLAlgorithmParameter2.cnnt
new file mode 100644
index 0000000000000000000000000000000000000000..89a14c45afbac3e93cfd6f4af81fdc48652398df
--- /dev/null
+++ b/src/test/resources/invalid_cocos_tests/CheckRLAlgorithmParameter2.cnnt
@@ -0,0 +1,9 @@
+configuration CheckRLAlgorithmParameter2 {
+    learning_method : reinforcement
+    rl_algorithm : ddpg-algorithm
+    critic : path.to.component
+    environment : gym { name:"CartPole-v1" }
+    soft_target_update_rate: 0.001
+    target_network_update_interval: 400
+    use_fix_target_network: true
+}
\ No newline at end of file
diff --git a/src/test/resources/invalid_cocos_tests/CheckRLAlgorithmParameter3.cnnt b/src/test/resources/invalid_cocos_tests/CheckRLAlgorithmParameter3.cnnt
new file mode 100644
index 0000000000000000000000000000000000000000..711c895d4f091ad3834fb90caa40cf69d2024803
--- /dev/null
+++ b/src/test/resources/invalid_cocos_tests/CheckRLAlgorithmParameter3.cnnt
@@ -0,0 +1,19 @@
+configuration CheckRLAlgorithmParameter3 {
+    learning_method : reinforcement
+
+    rl_algorithm: dqn-algorithm
+
+    agent_name : "reinforcement-agent"
+
+    environment : gym { name:"CartPole-v1" }
+
+    context : cpu
+
+    num_episodes : 300
+    num_max_steps : 9999
+    discount_factor : 0.998
+    target_score : 1000
+    training_interval : 10
+
+    soft_target_update_rate: 0.001
+}
\ No newline at end of file
diff --git a/src/test/resources/invalid_cocos_tests/CheckReinforcementRequiresEnvironment.cnnt b/src/test/resources/invalid_cocos_tests/CheckReinforcementRequiresEnvironment.cnnt
index dc8861a5396db5ba6988127daa169f511730158b..0cb8e2390985684cba7f62c8689269951ed2ce7b 100644
--- a/src/test/resources/invalid_cocos_tests/CheckReinforcementRequiresEnvironment.cnnt
+++ b/src/test/resources/invalid_cocos_tests/CheckReinforcementRequiresEnvironment.cnnt
@@ -23,7 +23,7 @@ configuration CheckReinforcementRequiresEnvironment {
         sample_size : 64
     }
 
-    action_selection : epsgreedy{
+    strategy : epsgreedy{
         epsilon : 1.0
         min_epsilon : 0.01
         epsilon_decay_method: linear
diff --git a/src/test/resources/invalid_cocos_tests/CheckRosEnvironmentHasOnlyOneRewardSpecification.cnnt b/src/test/resources/invalid_cocos_tests/CheckRosEnvironmentHasOnlyOneRewardSpecification.cnnt
new file mode 100644
index 0000000000000000000000000000000000000000..2f114df649895ee0a5fd6ff41bc5967321dc637d
--- /dev/null
+++ b/src/test/resources/invalid_cocos_tests/CheckRosEnvironmentHasOnlyOneRewardSpecification.cnnt
@@ -0,0 +1,58 @@
+configuration CheckRosEnvironmentHasOnlyOneRewardSpecification {
+    learning_method : reinforcement
+
+    agent_name : "reinforcement-agent"
+
+    environment : ros_interface {
+        state_topic : "state"
+        action_topic : "action"
+        reset_topic : "reset"
+        terminal_state_topic : "is_terminal"
+        reward_topic: "reward"
+    }
+
+    reward_function : path.to.reward.component
+
+    context : cpu
+
+    num_episodes : 300
+    num_max_steps : 9999
+    discount_factor : 0.998
+    target_score : 1000
+    training_interval : 10
+
+    loss : huber_loss
+
+    use_fix_target_network : true
+    target_network_update_interval : 100
+
+    use_double_dqn : true
+
+    replay_memory : buffer{
+        memory_size : 1000000
+        sample_size : 64
+    }
+
+    strategy : epsgreedy{
+        epsilon : 1.0
+        min_epsilon : 0.01
+        epsilon_decay_method: linear
+        epsilon_decay : 0.0001
+    }
+
+    optimizer : rmsprop{
+        learning_rate : 0.001
+        learning_rate_minimum : 0.00001
+        weight_decay : 0.01
+        learning_rate_decay : 0.9
+        learning_rate_policy : step
+        step_size : 1000
+        rescale_grad : 1.1
+        clip_gradient : 10
+        gamma1 : 0.9
+        gamma2 : 0.9
+        epsilon : 0.000001
+        centered : true
+        clip_weights : 10
+    }
+}
\ No newline at end of file
diff --git a/src/test/resources/invalid_cocos_tests/CheckRosEnvironmentRequiresRewardFunction.cnnt b/src/test/resources/invalid_cocos_tests/CheckRosEnvironmentRequiresRewardFunction.cnnt
index 71753777bae0bfacf26b49c9e6747b2c1015674c..fdb28a75172a88ee289b359c41a1a1858b5717f8 100644
--- a/src/test/resources/invalid_cocos_tests/CheckRosEnvironmentRequiresRewardFunction.cnnt
+++ b/src/test/resources/invalid_cocos_tests/CheckRosEnvironmentRequiresRewardFunction.cnnt
@@ -8,8 +8,6 @@ configuration CheckRosEnvironmentRequiresRewardFunction {
         action_topic : "action"
         reset_topic : "reset"
         terminal_state_topic : "is_terminal"
-        greeting_topic : "greeting"
-        meta_topic : "meta"
     }
 
     context : cpu
@@ -32,7 +30,7 @@ configuration CheckRosEnvironmentRequiresRewardFunction {
         sample_size : 64
     }
 
-    action_selection : epsgreedy{
+    strategy : epsgreedy{
         epsilon : 1.0
         min_epsilon : 0.01
         epsilon_decay_method: linear
diff --git a/src/test/resources/invalid_parser_tests/WrongStrategyParameter.cnnt b/src/test/resources/invalid_parser_tests/WrongStrategyParameter.cnnt
new file mode 100644
index 0000000000000000000000000000000000000000..5b40e7b6035ad037dfc38e397e28f8b813bc82fa
--- /dev/null
+++ b/src/test/resources/invalid_parser_tests/WrongStrategyParameter.cnnt
@@ -0,0 +1,11 @@
+configuration WrongStrategyParameter {
+    learning_method : reinforcement
+    rl_algorithm: ddpg-algorithm
+    environment : gym { name:"CartPole-v1" }
+    strategy: epsgreedy {
+        epsilon: 1.0
+        epsilon_decay_method: linear
+        epsilon_decay: 0.01
+        mu: (0.01)
+    }
+}
\ No newline at end of file
diff --git a/src/test/resources/valid_tests/DdpgConfig.cnnt b/src/test/resources/valid_tests/DdpgConfig.cnnt
index 45f3b28f40bdd1a2f4efff5e2eface013d2e0df0..e75c6e541ee02e33c584b47c6d382df6dd9b42b7 100644
--- a/src/test/resources/valid_tests/DdpgConfig.cnnt
+++ b/src/test/resources/valid_tests/DdpgConfig.cnnt
@@ -3,4 +3,27 @@ configuration DdpgConfig {
     rl_algorithm : ddpg-algorithm
     critic : path.to.component
     environment : gym { name:"CartPole-v1" }
+    soft_target_update_rate: 0.001
+    actor_optimizer : adam{
+        learning_rate : 0.0001
+        learning_rate_minimum : 0.00005
+        learning_rate_decay : 0.9
+        learning_rate_policy : step
+    }
+    critic_optimizer : rmsprop{
+        learning_rate : 0.001
+        learning_rate_minimum : 0.0001
+        learning_rate_decay : 0.5
+        learning_rate_policy : step
+    }
+    strategy : ornstein_uhlenbeck{
+        epsilon: 1.0
+        min_epsilon: 0.001
+        epsilon_decay_method: linear
+        epsilon_decay : 0.0001
+        epsilon_decay_start: 50
+        mu: (0.0, 0.1, 0.3)
+        theta: (0.5, 0.0, 0.8)
+        sigma: (0.3, 0.6, -0.9)
+    }
 }
\ No newline at end of file
diff --git a/src/test/resources/valid_tests/ReinforcementConfig.cnnt b/src/test/resources/valid_tests/ReinforcementConfig.cnnt
index 3f03ff5ad1b1b7ac856c9f9a163c617fb45a8afa..a1b59db944660259c5bd62b065ae45d0202ce0b4 100644
--- a/src/test/resources/valid_tests/ReinforcementConfig.cnnt
+++ b/src/test/resources/valid_tests/ReinforcementConfig.cnnt
@@ -12,6 +12,11 @@ configuration ReinforcementConfig {
     discount_factor : 0.998
     target_score : 1000
     training_interval : 10
+    snapshot_interval: 100
+
+    start_training_at : 20
+
+    evaluation_samples: 100
 
     loss : huber_loss
 
@@ -25,11 +30,12 @@ configuration ReinforcementConfig {
         sample_size : 64
     }
 
-    action_selection : epsgreedy{
+    strategy : epsgreedy{
         epsilon : 1.0
         min_epsilon : 0.01
         epsilon_decay_method: linear
         epsilon_decay : 0.0001
+        epsilon_decay_start: 50
     }
 
     optimizer : rmsprop{
diff --git a/src/test/resources/valid_tests/ReinforcementConfig2.cnnt b/src/test/resources/valid_tests/ReinforcementConfig2.cnnt
index 537be2db278a6df2cd28f8ad9b552bf865e40291..0462aaa236a0d20ce802567f731ae39edc901a28 100644
--- a/src/test/resources/valid_tests/ReinforcementConfig2.cnnt
+++ b/src/test/resources/valid_tests/ReinforcementConfig2.cnnt
@@ -8,8 +8,6 @@ configuration ReinforcementConfig2 {
         action_topic : "action"
         reset_topic : "reset"
         terminal_state_topic : "is_terminal"
-        greeting_topic : "greeting"
-        meta_topic : "meta"
     }
 
     reward_function : path.to.reward.component
@@ -34,7 +32,7 @@ configuration ReinforcementConfig2 {
         sample_size : 64
     }
 
-    action_selection : epsgreedy{
+    strategy : epsgreedy{
         epsilon : 1.0
         min_epsilon : 0.01
         epsilon_decay_method: linear
diff --git a/src/test/resources/valid_tests/ReinforcementWithRosReward.cnnt b/src/test/resources/valid_tests/ReinforcementWithRosReward.cnnt
new file mode 100644
index 0000000000000000000000000000000000000000..30e0a4d3c6e9488ff27bfa4afb48684e218c6e5b
--- /dev/null
+++ b/src/test/resources/valid_tests/ReinforcementWithRosReward.cnnt
@@ -0,0 +1,56 @@
+configuration ReinforcementWithRosReward {
+    learning_method : reinforcement
+
+    agent_name : "reinforcement-agent"
+
+    environment : ros_interface {
+        state_topic : "state"
+        action_topic : "action"
+        reset_topic : "reset"
+        terminal_state_topic : "is_terminal"
+        reward_topic: "reward"
+    }
+
+    context : cpu
+
+    num_episodes : 300
+    num_max_steps : 9999
+    discount_factor : 0.998
+    target_score : 1000
+    training_interval : 10
+
+    loss : huber_loss
+
+    use_fix_target_network : true
+    target_network_update_interval : 100
+
+    use_double_dqn : true
+
+    replay_memory : buffer{
+        memory_size : 1000000
+        sample_size : 64
+    }
+
+    strategy : epsgreedy{
+        epsilon : 1.0
+        min_epsilon : 0.01
+        epsilon_decay_method: linear
+        epsilon_decay : 0.0001
+    }
+
+    optimizer : rmsprop{
+        learning_rate : 0.001
+        learning_rate_minimum : 0.00001
+        weight_decay : 0.01
+        learning_rate_decay : 0.9
+        learning_rate_policy : step
+        step_size : 1000
+        rescale_grad : 1.1
+        clip_gradient : 10
+        gamma1 : 0.9
+        gamma2 : 0.9
+        epsilon : 0.000001
+        centered : true
+        clip_weights : 10
+    }
+}
\ No newline at end of file