Commit b4645437 authored by Nicola Gatto's avatar Nicola Gatto
Browse files

Use constants defined in CNNTrain

parent 53330fa7
......@@ -5,141 +5,114 @@ import de.monticore.lang.monticar.cnnarch.generator.ConfigurationData;
import de.monticore.lang.monticar.cnntrain._symboltable.*;
import de.monticore.lang.monticar.cnntrain.annotations.Range;
import static de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants.*;
import java.util.*;
public class ReinforcementConfigurationData extends ConfigurationData {
private static final String AST_ENTRY_LEARNING_METHOD = "learning_method";
private static final String AST_ENTRY_NUM_EPISODES = "num_episodes";
private static final String AST_ENTRY_DISCOUNT_FACTOR = "discount_factor";
private static final String AST_ENTRY_NUM_MAX_STEPS = "num_max_steps";
private static final String AST_ENTRY_TARGET_SCORE = "target_score";
private static final String AST_ENTRY_TRAINING_INTERVAL = "training_interval";
private static final String AST_ENTRY_USE_FIX_TARGET_NETWORK = "use_fix_target_network";
private static final String AST_ENTRY_TARGET_NETWORK_UPDATE_INTERVAL = "target_network_update_interval";
private static final String AST_ENTRY_SNAPSHOT_INTERVAL = "snapshot_interval";
private static final String AST_ENTRY_AGENT_NAME = "agent_name";
private static final String AST_ENTRY_USE_DOUBLE_DQN = "use_double_dqn";
private static final String AST_ENTRY_LOSS = "loss";
private static final String AST_ENTRY_RL_ALGORITHM = "rl_algorithm";
private static final String AST_ENTRY_REPLAY_MEMORY = "replay_memory";
private static final String AST_ENTRY_STRATEGY = "strategy";
private static final String AST_ENTRY_ENVIRONMENT = "environment";
private static final String AST_ENTRY_START_TRAINING_AT = "start_training_at";
private static final String AST_SOFT_TARGET_UPDATE_RATE = "soft_target_update_rate";
private static final String AST_EVALUATION_SAMPLES = "evaluation_samples";
private static final String AST_ENTRY_POLICY_NOISE = "policy_noise";
private static final String AST_ENTRY_NOISE_CLIP = "noise_clip";
private static final String AST_ENTRY_POLICY_DELAY = "policy_delay";
private static final String ENVIRONMENT_PARAM_REWARD_TOPIC = "reward_topic";
private static final String ENVIRONMENT_ROS = "ros_interface";
private static final String ENVIRONMENT_GYM = "gym";
private static final String STRATEGY_ORNSTEIN_UHLENBECK = "ornstein_uhlenbeck";
public ReinforcementConfigurationData(ConfigurationSymbol configuration, String instanceName) {
super(configuration, instanceName);
}
public Boolean isSupervisedLearning() {
if (configurationContainsKey(AST_ENTRY_LEARNING_METHOD)) {
return retrieveConfigurationEntryValueByKey(AST_ENTRY_LEARNING_METHOD)
if (configurationContainsKey(LEARNING_METHOD)) {
return retrieveConfigurationEntryValueByKey(LEARNING_METHOD)
.equals(LearningMethod.SUPERVISED);
}
return true;
}
public Boolean isReinforcementLearning() {
return configurationContainsKey(AST_ENTRY_LEARNING_METHOD)
&& retrieveConfigurationEntryValueByKey(AST_ENTRY_LEARNING_METHOD).equals(LearningMethod.REINFORCEMENT);
return configurationContainsKey(LEARNING_METHOD)
&& retrieveConfigurationEntryValueByKey(LEARNING_METHOD).equals(LearningMethod.REINFORCEMENT);
}
public Integer getNumEpisodes() {
return !configurationContainsKey(AST_ENTRY_NUM_EPISODES)
? null : (Integer)retrieveConfigurationEntryValueByKey(AST_ENTRY_NUM_EPISODES);
return !configurationContainsKey(NUM_EPISODES)
? null : (Integer)retrieveConfigurationEntryValueByKey(NUM_EPISODES);
}
public Double getDiscountFactor() {
return !configurationContainsKey(AST_ENTRY_DISCOUNT_FACTOR)
? null : (Double)retrieveConfigurationEntryValueByKey(AST_ENTRY_DISCOUNT_FACTOR);
return !configurationContainsKey(DISCOUNT_FACTOR)
? null : (Double)retrieveConfigurationEntryValueByKey(DISCOUNT_FACTOR);
}
public Integer getNumMaxSteps() {
return !configurationContainsKey(AST_ENTRY_NUM_MAX_STEPS)
? null : (Integer)retrieveConfigurationEntryValueByKey(AST_ENTRY_NUM_MAX_STEPS);
return !configurationContainsKey(NUM_MAX_STEPS)
? null : (Integer)retrieveConfigurationEntryValueByKey(NUM_MAX_STEPS);
}
public Double getTargetScore() {
return !configurationContainsKey(AST_ENTRY_TARGET_SCORE)
? null : (Double)retrieveConfigurationEntryValueByKey(AST_ENTRY_TARGET_SCORE);
return !configurationContainsKey(TARGET_SCORE)
? null : (Double)retrieveConfigurationEntryValueByKey(TARGET_SCORE);
}
public Integer getTrainingInterval() {
return !configurationContainsKey(AST_ENTRY_TRAINING_INTERVAL)
? null : (Integer)retrieveConfigurationEntryValueByKey(AST_ENTRY_TRAINING_INTERVAL);
return !configurationContainsKey(TRAINING_INTERVAL)
? null : (Integer)retrieveConfigurationEntryValueByKey(TRAINING_INTERVAL);
}
public Boolean getUseFixTargetNetwork() {
return !configurationContainsKey(AST_ENTRY_USE_FIX_TARGET_NETWORK)
? null : (Boolean)retrieveConfigurationEntryValueByKey(AST_ENTRY_USE_FIX_TARGET_NETWORK);
return !configurationContainsKey(USE_FIX_TARGET_NETWORK)
? null : (Boolean)retrieveConfigurationEntryValueByKey(USE_FIX_TARGET_NETWORK);
}
public Integer getTargetNetworkUpdateInterval() {
return !configurationContainsKey(AST_ENTRY_TARGET_NETWORK_UPDATE_INTERVAL)
? null : (Integer)retrieveConfigurationEntryValueByKey(AST_ENTRY_TARGET_NETWORK_UPDATE_INTERVAL);
return !configurationContainsKey(TARGET_NETWORK_UPDATE_INTERVAL)
? null : (Integer)retrieveConfigurationEntryValueByKey(TARGET_NETWORK_UPDATE_INTERVAL);
}
public Integer getSnapshotInterval() {
return !configurationContainsKey(AST_ENTRY_SNAPSHOT_INTERVAL)
? null : (Integer)retrieveConfigurationEntryValueByKey(AST_ENTRY_SNAPSHOT_INTERVAL);
return !configurationContainsKey(SNAPSHOT_INTERVAL)
? null : (Integer)retrieveConfigurationEntryValueByKey(SNAPSHOT_INTERVAL);
}
public String getAgentName() {
return !configurationContainsKey(AST_ENTRY_AGENT_NAME)
? null : (String)retrieveConfigurationEntryValueByKey(AST_ENTRY_AGENT_NAME);
return !configurationContainsKey(AGENT_NAME)
? null : (String)retrieveConfigurationEntryValueByKey(AGENT_NAME);
}
public Boolean getUseDoubleDqn() {
return !configurationContainsKey(AST_ENTRY_USE_DOUBLE_DQN)
? null : (Boolean)retrieveConfigurationEntryValueByKey(AST_ENTRY_USE_DOUBLE_DQN);
return !configurationContainsKey(USE_DOUBLE_DQN)
? null : (Boolean)retrieveConfigurationEntryValueByKey(USE_DOUBLE_DQN);
}
public Double getSoftTargetUpdateRate() {
return !configurationContainsKey(AST_SOFT_TARGET_UPDATE_RATE)
? null : (Double)retrieveConfigurationEntryValueByKey(AST_SOFT_TARGET_UPDATE_RATE);
return !configurationContainsKey(SOFT_TARGET_UPDATE_RATE)
? null : (Double)retrieveConfigurationEntryValueByKey(SOFT_TARGET_UPDATE_RATE);
}
public Integer getStartTrainingAt() {
return !configurationContainsKey(AST_ENTRY_START_TRAINING_AT)
? null : (Integer)retrieveConfigurationEntryValueByKey(AST_ENTRY_START_TRAINING_AT);
return !configurationContainsKey(START_TRAINING_AT)
? null : (Integer)retrieveConfigurationEntryValueByKey(START_TRAINING_AT);
}
public Integer getEvaluationSamples() {
return !configurationContainsKey(AST_EVALUATION_SAMPLES)
? null : (Integer)retrieveConfigurationEntryValueByKey(AST_EVALUATION_SAMPLES);
return !configurationContainsKey(EVALUATION_SAMPLES)
? null : (Integer)retrieveConfigurationEntryValueByKey(EVALUATION_SAMPLES);
}
public Double getPolicyNoise() {
return !configurationContainsKey(AST_ENTRY_POLICY_NOISE)
? null : (Double) retrieveConfigurationEntryValueByKey(AST_ENTRY_POLICY_NOISE);
return !configurationContainsKey(POLICY_NOISE)
? null : (Double) retrieveConfigurationEntryValueByKey(POLICY_NOISE);
}
public Double getNoiseClip() {
return !configurationContainsKey(AST_ENTRY_NOISE_CLIP)
? null : (Double) retrieveConfigurationEntryValueByKey(AST_ENTRY_NOISE_CLIP);
return !configurationContainsKey(NOISE_CLIP)
? null : (Double) retrieveConfigurationEntryValueByKey(NOISE_CLIP);
}
public Integer getPolicyDelay() {
return !configurationContainsKey(AST_ENTRY_POLICY_DELAY)
? null : (Integer) retrieveConfigurationEntryValueByKey(AST_ENTRY_POLICY_DELAY);
return !configurationContainsKey(POLICY_DELAY)
? null : (Integer) retrieveConfigurationEntryValueByKey(POLICY_DELAY);
}
public RLAlgorithm getRlAlgorithm() {
if (!isReinforcementLearning()) {
return null;
}
return !configurationContainsKey(AST_ENTRY_RL_ALGORITHM)
? RLAlgorithm.DQN : (RLAlgorithm)retrieveConfigurationEntryValueByKey(AST_ENTRY_RL_ALGORITHM);
return !configurationContainsKey(RL_ALGORITHM)
? RLAlgorithm.DQN : (RLAlgorithm)retrieveConfigurationEntryValueByKey(RL_ALGORITHM);
}
public String getInputNameOfTrainedArchitecture() {
......@@ -179,18 +152,18 @@ public class ReinforcementConfigurationData extends ConfigurationData {
}
public String getLoss() {
return !configurationContainsKey(AST_ENTRY_LOSS)
? null : retrieveConfigurationEntryValueByKey(AST_ENTRY_LOSS).toString();
return !configurationContainsKey(LOSS)
? null : retrieveConfigurationEntryValueByKey(LOSS).toString();
}
public Map<String, Object> getReplayMemory() {
return getMultiParamEntry(AST_ENTRY_REPLAY_MEMORY, "method");
return getMultiParamEntry(REPLAY_MEMORY, "method");
}
public Map<String, Object> getStrategy() {
assert isReinforcementLearning(): "Strategy parameter only for reinforcement learning but called in a " +
" non reinforcement learning context";
Map<String, Object> strategyParams = getMultiParamEntry(AST_ENTRY_STRATEGY, "method");
Map<String, Object> strategyParams = getMultiParamEntry(STRATEGY, "method");
assert getConfiguration().getTrainedArchitecture().isPresent(): "Architecture not present," +
" but reinforcement training";
NNArchitectureSymbol trainedArchitecture = getConfiguration().getTrainedArchitecture().get();
......@@ -218,7 +191,7 @@ public class ReinforcementConfigurationData extends ConfigurationData {
}
public Map<String, Object> getEnvironment() {
return getMultiParamEntry(AST_ENTRY_ENVIRONMENT, "environment");
return getMultiParamEntry(ENVIRONMENT, "environment");
}
public Boolean hasRewardFunction() {
......@@ -312,12 +285,12 @@ public class ReinforcementConfigurationData extends ConfigurationData {
}
public boolean hasRosRewardTopic() {
Map<String, Object> environmentParameters = getMultiParamEntry(AST_ENTRY_ENVIRONMENT, "environment");
Map<String, Object> environmentParameters = getMultiParamEntry(ENVIRONMENT, "environment");
if (environmentParameters == null
|| !environmentParameters.containsKey("environment")) {
return false;
}
return environmentParameters.containsKey(ENVIRONMENT_PARAM_REWARD_TOPIC);
return environmentParameters.containsKey(ENVIRONMENT_REWARD_TOPIC);
}
private Map<String, Object> getMultiParamEntry(final String key, final String valueName) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment