Commit 54a1d2c8 authored by Nicola Gatto's avatar Nicola Gatto Committed by Evgeny Kusmenko

Integrate TD3 Algorithm and Gaussian Noise

parent 92670da6
......@@ -117,7 +117,7 @@ configuration ReinforcementConfig {
| Parameter | Value | Default | Required | Algorithm | Description |
|------------|--------|---------|----------|-----------|-------------|
|learning_method| reinforcement,supervised | supervised | No | All | Determines that this CNNTrain configuration is a reinforcement or supervised learning configuration |
| rl_algorithm | ddpg-algorithm, dqn-algorithm | dqn-algorithm | No | All | Determines the RL algorithm that is used to train the agent
| rl_algorithm | ddpg-algorithm, dqn-algorithm, td3-algorithm | dqn-algorithm | No | All | Determines the RL algorithm that is used to train the agent
| agent_name | String | "agent" | No | All | Names the agent (e.g. for logging output) |
|environment | gym, ros_interface | Yes | / | All | If *ros_interface* is selected, then the agent and the environment communicates via [ROS](http://www.ros.org/). The gym environment comes with a set of environments which are listed [here](https://gym.openai.com/) |
| context | cpu, gpu | cpu | No | All | Determines whether the GPU is used during training or the CPU |
......@@ -133,12 +133,15 @@ configuration ReinforcementConfig {
| replay_memory | buffer, online, combined | buffer | No | All | Determines the behaviour of the replay memory |
| strategy | epsgreedy, ornstein_uhlenbeck | epsgreedy (discrete), ornstein_uhlenbeck (continuous) | No | All | Determines the action selection policy during the training |
| reward_function | Full name of an EMAM component | / | Yes, if *ros_interface* is selected as the environment and no reward topic is given | All | The EMAM component that is used to calculate the reward. It must have two inputs, one for the current state and one boolean input that determines if the current state is terminal. It must also have exactly one output which represents the reward. |
critic | Full name of architecture definition | / | Yes, if DDPG is selected | DDPG | The architecture definition which specifies the architecture of the critic network |
soft_target_update_rate | Float | 0.001 | No | DDPG | Determines the update rate of the critic and actor target network |
actor_optimizer | See supervised learning | adam with LR .0001 | No | DDPG | Determines the optimizer parameters of the actor network |
critic_optimizer | See supervised learning | adam with LR .001 | No | DDPG | Determines the optimizer parameters of the critic network |
critic | Full name of architecture definition | / | Yes, if DDPG or TD3 is selected | DDPG, TD3 | The architecture definition which specifies the architecture of the critic network |
soft_target_update_rate | Float | 0.001 | No | DDPG, TD3 | Determines the update rate of the critic and actor target network |
actor_optimizer | See supervised learning | adam with LR .0001 | No | DDPG, TD3 | Determines the optimizer parameters of the actor network |
critic_optimizer | See supervised learning | adam with LR .001 | No | DDPG, TD3 | Determines the optimizer parameters of the critic network |
| start_training_at | Integer | 0 | No | All | Determines at which episode the training starts |
| evaluation_samples | Integer | 100 | No | All | Determines how many epsiodes are run when evaluating the network |
| policy_noise | Float | 0.1 | No | TD3 | Determines the standard deviation of the noise that is added to the actions predicted by the target actor network when calculating the targets.
| noise_clip | Float | 0.5 | No | TD3 | Sets the upper and lower limit of the policy noise
policy_delay | Integer | 2 | No | TD3 | Every policy_delay of steps, the actor network and targets are updated.
#### Environment
......@@ -189,6 +192,7 @@ This strategy is only available for discrete problems. It selects an action base
- **epsilon_decay_start**: Number of Episodes after the decay of epsilon starts
- **epsilon_decay**: The actual decay of epsilon after each step.
- **min_epsilon**: After *min_epsilon* is reached, epsilon is not decreased further.
- **epsilon_decay_per_step**:Expects either true or false. If true, the decay will be performed for each step the agent executes instead of performing the decay after each episode. The default value is false
#### Option: ornstein_uhlenbeck
......@@ -209,6 +213,9 @@ Example: Given an actor network with action output of shape (3,), we can write
to specify the parameters for each place.
### Option: gaussian
This strategy is also only available for continuous problems. If this strat- egy is selected, uncorrelated Gaussian noise with zero mean is added to the current policy action selection. This strategy provides the same parameters as the epsgreedy option and the parameter **noise_variance** that determines the variance of the noise.
## Generation
To execute generation in your project, use the following code to generate a separate Config file:
......
......@@ -30,7 +30,7 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnn-train</artifactId>
<version>0.3.4-SNAPSHOT</version>
<version>0.3.6-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
......
......@@ -145,7 +145,7 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
LearningMethodValue implements ConfigValue = (supervisedLearning:"supervised" | reinforcement:"reinforcement");
RLAlgorithmValue implements ConfigValue = (dqn:"dqn-algorithm" | ddpg:"ddpg-algorithm");
RLAlgorithmValue implements ConfigValue = (dqn:"dqn-algorithm" | ddpg:"ddpg-algorithm" | tdThree:"td3-algorithm");
interface MultiParamConfigEntry extends ConfigEntry;
......@@ -176,17 +176,23 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
interface StrategyOrnsteinUhlenbeckEntry extends Entry;
StrategyOrnsteinUhlenbeckValue implements StrategyValue = name:"ornstein_uhlenbeck" ("{" params:StrategyOrnsteinUhlenbeckEntry* "}")?;
interface StrategyGaussianEntry extends Entry;
StrategyGaussianValue implements StrategyValue = name:"gaussian" ("{" params:StrategyGaussianEntry* "}")?;
StrategyGaussianNoiseVarianceEntry implements StrategyGaussianEntry = name: "noise_variance" ":" value:NumberValue;
StrategyOUMu implements StrategyOrnsteinUhlenbeckEntry = name: "mu" ":" value:DoubleVectorValue;
StrategyOUTheta implements StrategyOrnsteinUhlenbeckEntry = name: "theta" ":" value:DoubleVectorValue;
StrategyOUSigma implements StrategyOrnsteinUhlenbeckEntry = name: "sigma" ":" value:DoubleVectorValue;
interface GeneralStrategyEntry extends StrategyEpsGreedyEntry, StrategyOrnsteinUhlenbeckEntry;
interface GeneralStrategyEntry extends StrategyEpsGreedyEntry, StrategyOrnsteinUhlenbeckEntry, StrategyGaussianEntry;
GreedyEpsilonEntry implements GeneralStrategyEntry = name:"epsilon" ":" value:NumberValue;
MinEpsilonEntry implements GeneralStrategyEntry = name:"min_epsilon" ":" value:NumberValue;
EpsilonDecayStartEntry implements GeneralStrategyEntry = name:"epsilon_decay_start" ":" value:IntegerValue;
EpsilonDecayMethodEntry implements GeneralStrategyEntry = name:"epsilon_decay_method" ":" value:EpsilonDecayMethodValue;
EpsilonDecayMethodValue implements ConfigValue = (linear:"linear" | no:"no");
EpsilonDecayPerStepEntry implements GeneralStrategyEntry = name:"epsilon_decay_per_step" ":" value:BooleanValue;
EpsilonDecayEntry implements GeneralStrategyEntry = name:"epsilon_decay" ":" value:NumberValue;
// Environment
......@@ -211,8 +217,13 @@ grammar CNNTrain extends de.monticore.lang.monticar.Common2, de.monticore.Number
UseDoubleDQNEntry implements ConfigEntry = name:"use_double_dqn" ":" value:BooleanValue;
// DDPG exclusive parameters
// DDPG and TD3 exclusive parameters
CriticNetworkEntry implements ConfigEntry = name:"critic" ":" value:ComponentNameValue;
SoftTargetUpdateRateEntry implements ConfigEntry = name:"soft_target_update_rate" ":" value:NumberValue;
CriticOptimizerEntry implements ConfigEntry = name:"critic_optimizer" ":" value:OptimizerValue;
// TD3 exclusive parameters
PolicyNoiseEntry implements ConfigEntry = name:"policy_noise" ":" value:NumberValue;
NoiseClipEntry implements ConfigEntry = name:"noise_clip" ":" value:NumberValue;
PolicyDelayEntry implements ConfigEntry = name:"policy_delay" ":" value:IntegerValue;
}
\ No newline at end of file
......@@ -21,6 +21,7 @@
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._ast.*;
import static de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants.*;
import java.util.Optional;
......@@ -41,8 +42,16 @@ class ASTConfigurationUtils {
e -> (e instanceof ASTRLAlgorithmEntry) && ((ASTRLAlgorithmEntry)e).getValue().isPresentDdpg());
}
static boolean isTd3Algorithm(final ASTConfiguration configuration) {
return isReinforcementLearning(configuration)
&& configuration.getEntriesList().stream().anyMatch(
e -> (e instanceof ASTRLAlgorithmEntry) && ((ASTRLAlgorithmEntry)e).getValue().isPresentTdThree());
}
static boolean isDqnAlgorithm(final ASTConfiguration configuration) {
return isReinforcementLearning(configuration) && !isDdpgAlgorithm(configuration);
return isReinforcementLearning(configuration)
&& !isDdpgAlgorithm(configuration)
&& !isTd3Algorithm(configuration);
}
static boolean hasEntry(final ASTConfiguration configuration, final Class<? extends ASTConfigEntry> entryClazz) {
......@@ -69,7 +78,7 @@ class ASTConfigurationUtils {
return ASTConfigurationUtils.hasEnvironment(node)
&& node.getEntriesList().stream()
.anyMatch(e -> (e instanceof ASTEnvironmentEntry)
&& ((ASTEnvironmentEntry)e).getValue().getName().equals("ros_interface"));
&& ((ASTEnvironmentEntry)e).getValue().getName().equals(ENVIRONMENT_ROS));
}
static boolean hasRewardTopic(final ASTConfiguration node) {
......@@ -84,4 +93,18 @@ class ASTConfigurationUtils {
}
return false;
}
static boolean isActorCriticAlgorithm(final ASTConfiguration node) {
return isDdpgAlgorithm(node) || isTd3Algorithm(node);
}
static boolean hasCriticEntry(final ASTConfiguration node) {
return node.getEntriesList().stream()
.anyMatch(e -> ((e instanceof ASTCriticNetworkEntry)
&& !((ASTCriticNetworkEntry)e).getValue().getNameList().isEmpty()));
}
public static boolean isContinuousAlgorithm(final ASTConfiguration node) {
return isDdpgAlgorithm(node) || isTd3Algorithm(node);
}
}
......@@ -22,6 +22,7 @@ package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._ast.ASTCNNTrainNode;
import de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainCompilationUnitSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.se_rwth.commons.logging.Log;
public class CNNTrainCocos {
......@@ -34,7 +35,7 @@ public class CNNTrainCocos {
.addCoCo(new CheckReinforcementRequiresEnvironment())
.addCoCo(new CheckLearningParameterCombination())
.addCoCo(new CheckRosEnvironmentRequiresRewardFunction())
.addCoCo(new CheckDdpgRequiresCriticNetwork())
.addCoCo(new CheckActorCriticRequiresCriticNetwork())
.addCoCo(new CheckRlAlgorithmParameter())
.addCoCo(new CheckDiscreteRLAlgorithmUsesDiscreteStrategy())
.addCoCo(new CheckContinuousRLAlgorithmUsesContinuousStrategy())
......@@ -46,4 +47,19 @@ public class CNNTrainCocos {
int findings = Log.getFindings().size();
createChecker().checkAll(node);
}
public static void checkTrainedArchitectureCoCos(final ConfigurationSymbol configurationSymbol) {
CNNTrainConfigurationSymbolChecker checker = new CNNTrainConfigurationSymbolChecker()
.addCoCo(new CheckTrainedRlNetworkHasExactlyOneInput())
.addCoCo(new CheckTrainedRlNetworkHasExactlyOneOutput())
.addCoCo(new CheckOUParameterDimensionEqualsActionDimension());
checker.checkAll(configurationSymbol);
}
public static void checkCriticCocos(final ConfigurationSymbol configurationSymbol) {
CNNTrainConfigurationSymbolChecker checker = new CNNTrainConfigurationSymbolChecker()
.addCoCo(new CheckCriticNetworkHasExactlyAOneDimensionalOutput())
.addCoCo(new CheckCriticNetworkInputs());
checker.checkAll(configurationSymbol);
}
}
\ No newline at end of file
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import java.util.ArrayList;
import java.util.List;
/**
*
*/
public class CNNTrainConfigurationSymbolChecker {
private List<CNNTrainConfigurationSymbolCoCo> cocos = new ArrayList<>();
public CNNTrainConfigurationSymbolChecker addCoCo(CNNTrainConfigurationSymbolCoCo coco) {
cocos.add(coco);
return this;
}
public void checkAll(ConfigurationSymbol configurationSymbol) {
for (CNNTrainConfigurationSymbolCoCo coco : cocos) {
coco.check(configurationSymbol);
}
}
}
\ No newline at end of file
......@@ -18,16 +18,13 @@
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain.annotations;
package de.monticore.lang.monticar.cnntrain._cocos;
import java.util.List;
import java.util.Map;
public interface TrainedArchitecture {
public List<String> getInputs();
public List<String> getOutputs();
public Map<String, List<Integer>> getDimensions();
public Map<String, Range> getRanges();
public Map<String, String> getTypes();
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
/**
*
*/
public interface CNNTrainConfigurationSymbolCoCo {
void check(ConfigurationSymbol configurationSymbol);
}
\ No newline at end of file
......@@ -26,24 +26,23 @@ import de.monticore.lang.monticar.cnntrain._ast.ASTRLAlgorithmEntry;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
public class CheckDdpgRequiresCriticNetwork implements CNNTrainASTConfigurationCoCo {
import static de.monticore.lang.monticar.cnntrain._cocos.ASTConfigurationUtils.hasCriticEntry;
import static de.monticore.lang.monticar.cnntrain._cocos.ASTConfigurationUtils.isActorCriticAlgorithm;
public class CheckActorCriticRequiresCriticNetwork implements CNNTrainASTConfigurationCoCo {
@Override
public void check(ASTConfiguration node) {
boolean isDdpg = node.getEntriesList().stream()
.anyMatch(e -> e instanceof ASTRLAlgorithmEntry
&& ((ASTRLAlgorithmEntry)e).getValue().isPresentDdpg());
boolean hasCriticEntry = node.getEntriesList().stream()
.anyMatch(e -> ((e instanceof ASTCriticNetworkEntry)
&& !((ASTCriticNetworkEntry)e).getValue().getNameList().isEmpty()));
boolean isActorCritic = isActorCriticAlgorithm(node);
boolean hasCriticEntry = hasCriticEntry(node);
if (isDdpg && !hasCriticEntry) {
if (isActorCritic && !hasCriticEntry) {
ASTRLAlgorithmEntry algorithmEntry = node.getEntriesList().stream()
.filter(e -> e instanceof ASTRLAlgorithmEntry)
.map(e -> (ASTRLAlgorithmEntry)e)
.findFirst()
.orElseThrow(() -> new IllegalStateException("ASTRLAlgorithmEntry entry must be available"));
Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING + " DDPG learning algorithm requires critc" +
Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING + " DDPG learning algorithm requires critic" +
" network entry", algorithmEntry.get_SourcePositionStart());
}
}
......
......@@ -22,19 +22,23 @@ package de.monticore.lang.monticar.cnntrain._cocos;
import com.google.common.collect.ImmutableSet;
import de.monticore.lang.monticar.cnntrain._ast.ASTConfiguration;
import de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import java.util.Set;
import static de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants.*;
public class CheckContinuousRLAlgorithmUsesContinuousStrategy implements CNNTrainASTConfigurationCoCo{
private static final Set<String> CONTINUOUS_STRATEGIES = ImmutableSet.<String>builder()
.add("ornstein_uhlenbeck")
.add(STRATEGY_OU)
.add(STRATEGY_GAUSSIAN)
.build();
@Override
public void check(ASTConfiguration node) {
if (ASTConfigurationUtils.isDdpgAlgorithm(node)
if (ASTConfigurationUtils.isContinuousAlgorithm(node)
&& ASTConfigurationUtils.hasStrategy(node)
&& ASTConfigurationUtils.getStrategyMethod(node).isPresent()) {
final String usedStrategy = ASTConfigurationUtils.getStrategyMethod(node).get();
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import java.util.List;
/**
*
*/
public class CheckCriticNetworkHasExactlyAOneDimensionalOutput implements CNNTrainConfigurationSymbolCoCo {
@Override
public void check(ConfigurationSymbol configurationSymbol) {
if (configurationSymbol.getCriticNetwork().isPresent()) {
NNArchitectureSymbol criticNetwork = configurationSymbol.getCriticNetwork().get();
if (criticNetwork.getOutputs().size() > 1) {
Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR
+ " The critic network has more than one outputs", criticNetwork.getSourcePosition());
}
final String outputName = criticNetwork.getOutputs().get(0);
List<Integer> dimensions = criticNetwork.getDimensions().get(outputName);
if (dimensions.size() != 1 || dimensions.get(0) != 1) {
Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR + " The output " + outputName
+ " of critic network is not a one-dimensional vector", configurationSymbol.getSourcePosition());
}
}
}
}
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
import de.monticore.lang.monticar.cnntrain.annotations.Range;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import java.util.List;
import java.util.stream.Collectors;
/**
*
*/
public class CheckCriticNetworkInputs implements CNNTrainConfigurationSymbolCoCo {
@Override
public void check(ConfigurationSymbol configurationSymbol) {
if (configurationSymbol.getCriticNetwork().isPresent()) {
if (!configurationSymbol.getTrainedArchitecture().isPresent()) {
Log.error("0" + ErrorCodes.MISSING_TRAINED_ARCHITECTURE +
"No architecture found that is trained by this configuration.", configurationSymbol.getSourcePosition());
}
NNArchitectureSymbol trainedArchitecture = configurationSymbol.getTrainedArchitecture().get();
NNArchitectureSymbol criticNetwork = configurationSymbol.getCriticNetwork().get();
if (trainedArchitecture.getInputs().size() != 1 || trainedArchitecture.getOutputs().size() != 1) {
Log.error("Malformed trained architecture");
}
final String stateInput = trainedArchitecture.getInputs().get(0);
final String actionOutput = trainedArchitecture.getOutputs().get(0);
final List<Integer> stateDimensions = trainedArchitecture.getDimensions().get(stateInput);
final List<Integer> actionDimensions = trainedArchitecture.getDimensions().get(actionOutput);
final Range stateRange = trainedArchitecture.getRanges().get(stateInput);
final Range actionRange = trainedArchitecture.getRanges().get(actionOutput);
final String stateType = trainedArchitecture.getTypes().get(stateInput);
final String actionType = trainedArchitecture.getTypes().get(actionOutput);
String criticInput1 = criticNetwork.getInputs().get(0);
String criticInput2 = criticNetwork.getInputs().get(1);
if (criticNetwork.getInputs().size() != 2) {
Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR
+ "Number of critic network inputs is wrong. Critic network has two inputs," +
"first needs to be a state input and second needs to be the action input.");
}
if (!criticNetwork.getDimensions().get(criticInput1).equals(stateDimensions)) {
Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR
+ " Declared critic network is not a critic: Dimensions of first input of critic architecture must be" +
" equal to state's dimensions "
+ stateDimensions.stream().map(Object::toString).collect(Collectors.joining(",", "{", "}"))
+ ".", configurationSymbol.getSourcePosition());
}
if (!criticNetwork.getDimensions().get(criticInput2).equals(actionDimensions)) {
Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR
+ " Declared critic network is not a critic: Dimensions of second input of critic architecture must be" +
" equal to action's dimensions "
+ actionDimensions.stream().map(Object::toString).collect(Collectors.joining(",", "{", "}"))
+ ".", configurationSymbol.getSourcePosition());
}
if (!criticNetwork.getRanges().get(criticInput1).equals(stateRange)) {
Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR
+ " Declared critic network is not a critic: Ranges of first input of critic architecture must be" +
" equal to state's ranges "
+ stateRange.toString()
+ ".", configurationSymbol.getSourcePosition());
}
if (!criticNetwork.getRanges().get(criticInput2).equals(actionRange)) {
Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR
+ " Declared critic network is not a critic: Ranges of second input of critic architecture must be" +
" equal to action's ranges "
+ actionRange.toString()
+ ".", configurationSymbol.getSourcePosition());
}
if (!criticNetwork.getTypes().get(criticInput1).equals(stateType)) {
Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR
+ " Declared critic network is not a critic: Type of first input of critic architecture must be" +
" equal to state's types "
+ stateType
+ ".", configurationSymbol.getSourcePosition());
}
if (!criticNetwork.getTypes().get(criticInput2).equals(actionType)) {
Log.error("0" + ErrorCodes.CRITIC_NETWORK_ERROR
+ " Declared critic network is not a critic: Type of second input of critic architecture must be" +
" equal to action's types "
+ stateType
+ ".", configurationSymbol.getSourcePosition());
}
}
}
}
......@@ -22,6 +22,7 @@ package de.monticore.lang.monticar.cnntrain._cocos;
import com.google.common.collect.ImmutableSet;
import de.monticore.lang.monticar.cnntrain._ast.ASTConfiguration;
import de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
......@@ -29,7 +30,7 @@ import java.util.Set;
public class CheckDiscreteRLAlgorithmUsesDiscreteStrategy implements CNNTrainASTConfigurationCoCo{
private static final Set<String> DISCRETE_STRATEGIES = ImmutableSet.<String>builder()
.add("epsgreedy")
.add(ConfigEntryNameConstants.STRATEGY_EPSGREEDY)
.build();
@Override
......
......@@ -23,6 +23,7 @@ package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._ast.*;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.EntrySymbol;
import de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
......@@ -33,8 +34,6 @@ import java.util.Map;
*
*/
public class CheckFixTargetNetworkRequiresInterval implements CNNTrainASTConfigurationCoCo {
private static final String PARAMETER_USE_FIX_TARGET_NETWORK = "use_fix_target_network";
private static final String PARAMETER_TARGET_NETWORK_UPDATE_INTERVAL = "target_network_update_interval";
@Override
public void check(ASTConfiguration node) {
......@@ -50,8 +49,8 @@ public class CheckFixTargetNetworkRequiresInterval implements CNNTrainASTConfigu
.map(e -> (ASTUseFixTargetNetworkEntry)e)
.findFirst()
.orElseThrow(() -> new IllegalStateException("ASTUseFixTargetNetwork entry must be available"));
Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING + " Parameter " + Boolean.toString(useFixTargetNetwork)
+ " requires parameter " + PARAMETER_TARGET_NETWORK_UPDATE_INTERVAL,
Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING + " Parameter " + ConfigEntryNameConstants.USE_FIX_TARGET_NETWORK
+ " requires parameter " + ConfigEntryNameConstants.TARGET_NETWORK_UPDATE_INTERVAL,
useFixTargetNetworkEntry.get_SourcePositionStart());
} else if (!useFixTargetNetwork && hasTargetNetworkUpdateInterval) {
ASTTargetNetworkUpdateIntervalEntry targetNetworkUpdateIntervalEntry = node.getEntriesList().stream()
......@@ -62,7 +61,7 @@ public class CheckFixTargetNetworkRequiresInterval implements CNNTrainASTConfigu
() -> new IllegalStateException("ASTTargetNetworkUpdateInterval entry must be available"));
Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING + " Parameter "
+ targetNetworkUpdateIntervalEntry.getName() + " requires that parameter "
+ PARAMETER_USE_FIX_TARGET_NETWORK + " to be true.",
+ ConfigEntryNameConstants.USE_FIX_TARGET_NETWORK + " to be true.",
targetNetworkUpdateIntervalEntry.get_SourcePositionStart());
}
}
......