Aufgrund einer Wartung wird GitLab am 26.10. zwischen 8:00 und 9:00 Uhr kurzzeitig nicht zur Verfügung stehen. / Due to maintenance, GitLab will be temporarily unavailable on 26.10. between 8:00 and 9:00 am.

Commit d3a10207 authored by Nicola Gatto's avatar Nicola Gatto
Browse files

Add constants for parameter entries

parent 46333043
......@@ -21,6 +21,7 @@
package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._ast.*;
import static de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants.*;
import java.util.Optional;
......@@ -77,7 +78,7 @@ class ASTConfigurationUtils {
return ASTConfigurationUtils.hasEnvironment(node)
&& node.getEntriesList().stream()
.anyMatch(e -> (e instanceof ASTEnvironmentEntry)
&& ((ASTEnvironmentEntry)e).getValue().getName().equals("ros_interface"));
&& ((ASTEnvironmentEntry)e).getValue().getName().equals(ENVIRONMENT_ROS));
}
static boolean hasRewardTopic(final ASTConfiguration node) {
......
......@@ -22,15 +22,18 @@ package de.monticore.lang.monticar.cnntrain._cocos;
import com.google.common.collect.ImmutableSet;
import de.monticore.lang.monticar.cnntrain._ast.ASTConfiguration;
import de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import java.util.Set;
import static de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants.*;
public class CheckContinuousRLAlgorithmUsesContinuousStrategy implements CNNTrainASTConfigurationCoCo{
private static final Set<String> CONTINUOUS_STRATEGIES = ImmutableSet.<String>builder()
.add("ornstein_uhlenbeck")
.add("gaussian")
.add(STRATEGY_OU)
.add(STRATEGY_GAUSSIAN)
.build();
@Override
......
......@@ -22,6 +22,7 @@ package de.monticore.lang.monticar.cnntrain._cocos;
import com.google.common.collect.ImmutableSet;
import de.monticore.lang.monticar.cnntrain._ast.ASTConfiguration;
import de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
......@@ -29,7 +30,7 @@ import java.util.Set;
public class CheckDiscreteRLAlgorithmUsesDiscreteStrategy implements CNNTrainASTConfigurationCoCo{
private static final Set<String> DISCRETE_STRATEGIES = ImmutableSet.<String>builder()
.add("epsgreedy")
.add(ConfigEntryNameConstants.STRATEGY_EPSGREEDY)
.build();
@Override
......
......@@ -23,6 +23,7 @@ package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._ast.*;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.EntrySymbol;
import de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
......@@ -33,8 +34,6 @@ import java.util.Map;
*
*/
public class CheckFixTargetNetworkRequiresInterval implements CNNTrainASTConfigurationCoCo {
private static final String PARAMETER_USE_FIX_TARGET_NETWORK = "use_fix_target_network";
private static final String PARAMETER_TARGET_NETWORK_UPDATE_INTERVAL = "target_network_update_interval";
@Override
public void check(ASTConfiguration node) {
......@@ -50,8 +49,8 @@ public class CheckFixTargetNetworkRequiresInterval implements CNNTrainASTConfigu
.map(e -> (ASTUseFixTargetNetworkEntry)e)
.findFirst()
.orElseThrow(() -> new IllegalStateException("ASTUseFixTargetNetwork entry must be available"));
Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING + " Parameter " + Boolean.toString(useFixTargetNetwork)
+ " requires parameter " + PARAMETER_TARGET_NETWORK_UPDATE_INTERVAL,
Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING + " Parameter " + ConfigEntryNameConstants.USE_FIX_TARGET_NETWORK
+ " requires parameter " + ConfigEntryNameConstants.TARGET_NETWORK_UPDATE_INTERVAL,
useFixTargetNetworkEntry.get_SourcePositionStart());
} else if (!useFixTargetNetwork && hasTargetNetworkUpdateInterval) {
ASTTargetNetworkUpdateIntervalEntry targetNetworkUpdateIntervalEntry = node.getEntriesList().stream()
......@@ -62,7 +61,7 @@ public class CheckFixTargetNetworkRequiresInterval implements CNNTrainASTConfigu
() -> new IllegalStateException("ASTTargetNetworkUpdateInterval entry must be available"));
Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING + " Parameter "
+ targetNetworkUpdateIntervalEntry.getName() + " requires that parameter "
+ PARAMETER_USE_FIX_TARGET_NETWORK + " to be true.",
+ ConfigEntryNameConstants.USE_FIX_TARGET_NETWORK + " to be true.",
targetNetworkUpdateIntervalEntry.get_SourcePositionStart());
}
}
......
......@@ -25,6 +25,7 @@ import de.monticore.lang.monticar.cnntrain._ast.ASTEnvironmentEntry;
import de.monticore.lang.monticar.cnntrain._ast.ASTLearningMethodEntry;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.LearningMethod;
import de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
......@@ -32,8 +33,6 @@ import de.se_rwth.commons.logging.Log;
*
*/
public class CheckReinforcementRequiresEnvironment implements CNNTrainASTConfigurationCoCo {
private static final String PARAMETER_ENVIRONMENT = "environment";
@Override
public void check(ASTConfiguration node) {
boolean isReinforcementLearning = ASTConfigurationUtils.isReinforcementLearning(node);
......@@ -41,7 +40,7 @@ public class CheckReinforcementRequiresEnvironment implements CNNTrainASTConfigu
if (isReinforcementLearning && !hasEnvironment) {
Log.error("0" + ErrorCodes.REQUIRED_PARAMETER_MISSING + " The required parameter "
+ PARAMETER_ENVIRONMENT + " is missing");
+ ConfigEntryNameConstants.ENVIRONMENT + " is missing");
}
}
}
\ No newline at end of file
......@@ -2,6 +2,7 @@ package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.RLAlgorithm;
import de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
......@@ -16,7 +17,7 @@ public class CheckTrainedRlNetworkHasExactlyOneInput implements CNNTrainConfigur
final int numberOfInputs = configurationSymbol.getTrainedArchitecture().get().getInputs().size();
if (numberOfInputs != 1) {
final String networkName
= configurationSymbol.getEntry("rl_algorithm").getValue().getValue()
= configurationSymbol.getEntry(ConfigEntryNameConstants.RL_ALGORITHM).getValue().getValue()
.equals(RLAlgorithm.DQN) ? "Q-Network" : "Actor-Network";
Log.error("x0" + ErrorCodes.TRAINED_ARCHITECTURE_ERROR
+ networkName + " " +configurationSymbol.getTrainedArchitecture().get().getName()
......
......@@ -2,6 +2,7 @@ package de.monticore.lang.monticar.cnntrain._cocos;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.RLAlgorithm;
import de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants;
import de.monticore.lang.monticar.cnntrain.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
......@@ -16,7 +17,7 @@ public class CheckTrainedRlNetworkHasExactlyOneOutput implements CNNTrainConfigu
final int numberOfOutputs = configurationSymbol.getTrainedArchitecture().get().getOutputs().size();
if (numberOfOutputs != 1) {
final String networkName
= configurationSymbol.getEntry("rl_algorithm").getValue().getValue()
= configurationSymbol.getEntry(ConfigEntryNameConstants.RL_ALGORITHM).getValue().getValue()
.equals(RLAlgorithm.DQN) ? "Q-Network" : "Actor-Network";
Log.error("x0" + ErrorCodes.TRAINED_ARCHITECTURE_ERROR
+ networkName + " " +configurationSymbol.getTrainedArchitecture().get().getName()
......
......@@ -24,6 +24,8 @@ import de.monticore.symboltable.CommonScopeSpanningSymbol;
import java.util.*;
import static de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants.*;
public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
private Map<String, EntrySymbol> entryMap = new HashMap<>();
......@@ -99,8 +101,8 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
}
public LearningMethod getLearningMethod() {
return this.entryMap.containsKey("learning_method")
? (LearningMethod)this.entryMap.get("learning_method").getValue().getValue() : LearningMethod.SUPERVISED;
return this.entryMap.containsKey(LEARNING_METHOD)
? (LearningMethod)this.entryMap.get(LEARNING_METHOD).getValue().getValue() : LearningMethod.SUPERVISED;
}
public boolean isReinforcementLearningMethod() {
......@@ -108,7 +110,7 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
}
public boolean hasCritic() {
return getEntryMap().containsKey("critic");
return getEntryMap().containsKey(CRITIC);
}
public Optional<String> getCriticName() {
......@@ -116,7 +118,7 @@ public class ConfigurationSymbol extends CommonScopeSpanningSymbol {
return Optional.empty();
}
final Object criticNameValue = getEntry("critic").getValue().getValue();
final Object criticNameValue = getEntry(CRITIC).getValue().getValue();
assert criticNameValue instanceof String;
return Optional.of((String)criticNameValue);
}
......
package de.monticore.lang.monticar.cnntrain.helper;
/**
*
*/
public class ConfigEntryNameConstants {
public static final String LEARNING_METHOD = "learning_method";
public static final String NUM_EPISODES = "num_episodes";
public static final String DISCOUNT_FACTOR = "discount_factor";
public static final String NUM_MAX_STEPS = "num_max_steps";
public static final String TARGET_SCORE = "target_score";
public static final String TRAINING_INTERVAL = "training_interval";
public static final String USE_FIX_TARGET_NETWORK = "use_fix_target_network";
public static final String TARGET_NETWORK_UPDATE_INTERVAL = "target_network_update_interval";
public static final String SNAPSHOT_INTERVAL = "snapshot_interval";
public static final String AGENT_NAME = "agent_name";
public static final String USE_DOUBLE_DQN = "use_double_dqn";
public static final String LOSS = "loss";
public static final String RL_ALGORITHM = "rl_algorithm";
public static final String REPLAY_MEMORY = "replay_memory";
public static final String ENVIRONMENT = "environment";
public static final String START_TRAINING_AT = "start_training_at";
public static final String SOFT_TARGET_UPDATE_RATE = "soft_target_update_rate";
public static final String EVALUATION_SAMPLES = "evaluation_samples";
public static final String POLICY_NOISE = "policy_noise";
public static final String NOISE_CLIP = "noise_clip";
public static final String POLICY_DELAY = "policy_delay";
public static final String ENVIRONMENT_REWARD_TOPIC = "reward_topic";
public static final String ENVIRONMENT_ROS = "ros_interface";
public static final String ENVIRONMENT_GYM = "gym";
public static final String STRATEGY = "strategy";
public static final String STRATEGY_OU = "ornstein_uhlenbeck";
public static final String STRATEGY_GAUSSIAN = "gaussian";
public static final String STRATEGY_EPSGREEDY = "epsgreedy";
public static final String STRATEGY_EPSDECAY = "epsdecay";
public static final String CRITIC = "critic";
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment