Skip to content
Snippets Groups Projects
Commit b4645437 authored by Nicola Gatto's avatar Nicola Gatto
Browse files

Use constants defined in CNNTrain

parent 53330fa7
No related branches found
No related tags found
3 merge requests!20Implemented layer variables and RNN layer,!19Integrate TD3 Algorithm and Gaussian Noise,!18Integrate TD3 Algorithm and Gaussian Noise
......@@ -5,141 +5,114 @@ import de.monticore.lang.monticar.cnnarch.generator.ConfigurationData;
import de.monticore.lang.monticar.cnntrain._symboltable.*;
import de.monticore.lang.monticar.cnntrain.annotations.Range;
import static de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants.*;
import java.util.*;
public class ReinforcementConfigurationData extends ConfigurationData {
private static final String AST_ENTRY_LEARNING_METHOD = "learning_method";
private static final String AST_ENTRY_NUM_EPISODES = "num_episodes";
private static final String AST_ENTRY_DISCOUNT_FACTOR = "discount_factor";
private static final String AST_ENTRY_NUM_MAX_STEPS = "num_max_steps";
private static final String AST_ENTRY_TARGET_SCORE = "target_score";
private static final String AST_ENTRY_TRAINING_INTERVAL = "training_interval";
private static final String AST_ENTRY_USE_FIX_TARGET_NETWORK = "use_fix_target_network";
private static final String AST_ENTRY_TARGET_NETWORK_UPDATE_INTERVAL = "target_network_update_interval";
private static final String AST_ENTRY_SNAPSHOT_INTERVAL = "snapshot_interval";
private static final String AST_ENTRY_AGENT_NAME = "agent_name";
private static final String AST_ENTRY_USE_DOUBLE_DQN = "use_double_dqn";
private static final String AST_ENTRY_LOSS = "loss";
private static final String AST_ENTRY_RL_ALGORITHM = "rl_algorithm";
private static final String AST_ENTRY_REPLAY_MEMORY = "replay_memory";
private static final String AST_ENTRY_STRATEGY = "strategy";
private static final String AST_ENTRY_ENVIRONMENT = "environment";
private static final String AST_ENTRY_START_TRAINING_AT = "start_training_at";
private static final String AST_SOFT_TARGET_UPDATE_RATE = "soft_target_update_rate";
private static final String AST_EVALUATION_SAMPLES = "evaluation_samples";
private static final String AST_ENTRY_POLICY_NOISE = "policy_noise";
private static final String AST_ENTRY_NOISE_CLIP = "noise_clip";
private static final String AST_ENTRY_POLICY_DELAY = "policy_delay";
private static final String ENVIRONMENT_PARAM_REWARD_TOPIC = "reward_topic";
private static final String ENVIRONMENT_ROS = "ros_interface";
private static final String ENVIRONMENT_GYM = "gym";
private static final String STRATEGY_ORNSTEIN_UHLENBECK = "ornstein_uhlenbeck";
public ReinforcementConfigurationData(ConfigurationSymbol configuration, String instanceName) {
super(configuration, instanceName);
}
public Boolean isSupervisedLearning() {
if (configurationContainsKey(AST_ENTRY_LEARNING_METHOD)) {
return retrieveConfigurationEntryValueByKey(AST_ENTRY_LEARNING_METHOD)
if (configurationContainsKey(LEARNING_METHOD)) {
return retrieveConfigurationEntryValueByKey(LEARNING_METHOD)
.equals(LearningMethod.SUPERVISED);
}
return true;
}
public Boolean isReinforcementLearning() {
return configurationContainsKey(AST_ENTRY_LEARNING_METHOD)
&& retrieveConfigurationEntryValueByKey(AST_ENTRY_LEARNING_METHOD).equals(LearningMethod.REINFORCEMENT);
return configurationContainsKey(LEARNING_METHOD)
&& retrieveConfigurationEntryValueByKey(LEARNING_METHOD).equals(LearningMethod.REINFORCEMENT);
}
public Integer getNumEpisodes() {
return !configurationContainsKey(AST_ENTRY_NUM_EPISODES)
? null : (Integer)retrieveConfigurationEntryValueByKey(AST_ENTRY_NUM_EPISODES);
return !configurationContainsKey(NUM_EPISODES)
? null : (Integer)retrieveConfigurationEntryValueByKey(NUM_EPISODES);
}
public Double getDiscountFactor() {
return !configurationContainsKey(AST_ENTRY_DISCOUNT_FACTOR)
? null : (Double)retrieveConfigurationEntryValueByKey(AST_ENTRY_DISCOUNT_FACTOR);
return !configurationContainsKey(DISCOUNT_FACTOR)
? null : (Double)retrieveConfigurationEntryValueByKey(DISCOUNT_FACTOR);
}
public Integer getNumMaxSteps() {
return !configurationContainsKey(AST_ENTRY_NUM_MAX_STEPS)
? null : (Integer)retrieveConfigurationEntryValueByKey(AST_ENTRY_NUM_MAX_STEPS);
return !configurationContainsKey(NUM_MAX_STEPS)
? null : (Integer)retrieveConfigurationEntryValueByKey(NUM_MAX_STEPS);
}
public Double getTargetScore() {
return !configurationContainsKey(AST_ENTRY_TARGET_SCORE)
? null : (Double)retrieveConfigurationEntryValueByKey(AST_ENTRY_TARGET_SCORE);
return !configurationContainsKey(TARGET_SCORE)
? null : (Double)retrieveConfigurationEntryValueByKey(TARGET_SCORE);
}
public Integer getTrainingInterval() {
return !configurationContainsKey(AST_ENTRY_TRAINING_INTERVAL)
? null : (Integer)retrieveConfigurationEntryValueByKey(AST_ENTRY_TRAINING_INTERVAL);
return !configurationContainsKey(TRAINING_INTERVAL)
? null : (Integer)retrieveConfigurationEntryValueByKey(TRAINING_INTERVAL);
}
public Boolean getUseFixTargetNetwork() {
return !configurationContainsKey(AST_ENTRY_USE_FIX_TARGET_NETWORK)
? null : (Boolean)retrieveConfigurationEntryValueByKey(AST_ENTRY_USE_FIX_TARGET_NETWORK);
return !configurationContainsKey(USE_FIX_TARGET_NETWORK)
? null : (Boolean)retrieveConfigurationEntryValueByKey(USE_FIX_TARGET_NETWORK);
}
public Integer getTargetNetworkUpdateInterval() {
return !configurationContainsKey(AST_ENTRY_TARGET_NETWORK_UPDATE_INTERVAL)
? null : (Integer)retrieveConfigurationEntryValueByKey(AST_ENTRY_TARGET_NETWORK_UPDATE_INTERVAL);
return !configurationContainsKey(TARGET_NETWORK_UPDATE_INTERVAL)
? null : (Integer)retrieveConfigurationEntryValueByKey(TARGET_NETWORK_UPDATE_INTERVAL);
}
public Integer getSnapshotInterval() {
return !configurationContainsKey(AST_ENTRY_SNAPSHOT_INTERVAL)
? null : (Integer)retrieveConfigurationEntryValueByKey(AST_ENTRY_SNAPSHOT_INTERVAL);
return !configurationContainsKey(SNAPSHOT_INTERVAL)
? null : (Integer)retrieveConfigurationEntryValueByKey(SNAPSHOT_INTERVAL);
}
public String getAgentName() {
return !configurationContainsKey(AST_ENTRY_AGENT_NAME)
? null : (String)retrieveConfigurationEntryValueByKey(AST_ENTRY_AGENT_NAME);
return !configurationContainsKey(AGENT_NAME)
? null : (String)retrieveConfigurationEntryValueByKey(AGENT_NAME);
}
public Boolean getUseDoubleDqn() {
return !configurationContainsKey(AST_ENTRY_USE_DOUBLE_DQN)
? null : (Boolean)retrieveConfigurationEntryValueByKey(AST_ENTRY_USE_DOUBLE_DQN);
return !configurationContainsKey(USE_DOUBLE_DQN)
? null : (Boolean)retrieveConfigurationEntryValueByKey(USE_DOUBLE_DQN);
}
public Double getSoftTargetUpdateRate() {
return !configurationContainsKey(AST_SOFT_TARGET_UPDATE_RATE)
? null : (Double)retrieveConfigurationEntryValueByKey(AST_SOFT_TARGET_UPDATE_RATE);
return !configurationContainsKey(SOFT_TARGET_UPDATE_RATE)
? null : (Double)retrieveConfigurationEntryValueByKey(SOFT_TARGET_UPDATE_RATE);
}
public Integer getStartTrainingAt() {
return !configurationContainsKey(AST_ENTRY_START_TRAINING_AT)
? null : (Integer)retrieveConfigurationEntryValueByKey(AST_ENTRY_START_TRAINING_AT);
return !configurationContainsKey(START_TRAINING_AT)
? null : (Integer)retrieveConfigurationEntryValueByKey(START_TRAINING_AT);
}
public Integer getEvaluationSamples() {
return !configurationContainsKey(AST_EVALUATION_SAMPLES)
? null : (Integer)retrieveConfigurationEntryValueByKey(AST_EVALUATION_SAMPLES);
return !configurationContainsKey(EVALUATION_SAMPLES)
? null : (Integer)retrieveConfigurationEntryValueByKey(EVALUATION_SAMPLES);
}
public Double getPolicyNoise() {
return !configurationContainsKey(AST_ENTRY_POLICY_NOISE)
? null : (Double) retrieveConfigurationEntryValueByKey(AST_ENTRY_POLICY_NOISE);
return !configurationContainsKey(POLICY_NOISE)
? null : (Double) retrieveConfigurationEntryValueByKey(POLICY_NOISE);
}
public Double getNoiseClip() {
return !configurationContainsKey(AST_ENTRY_NOISE_CLIP)
? null : (Double) retrieveConfigurationEntryValueByKey(AST_ENTRY_NOISE_CLIP);
return !configurationContainsKey(NOISE_CLIP)
? null : (Double) retrieveConfigurationEntryValueByKey(NOISE_CLIP);
}
public Integer getPolicyDelay() {
return !configurationContainsKey(AST_ENTRY_POLICY_DELAY)
? null : (Integer) retrieveConfigurationEntryValueByKey(AST_ENTRY_POLICY_DELAY);
return !configurationContainsKey(POLICY_DELAY)
? null : (Integer) retrieveConfigurationEntryValueByKey(POLICY_DELAY);
}
public RLAlgorithm getRlAlgorithm() {
if (!isReinforcementLearning()) {
return null;
}
return !configurationContainsKey(AST_ENTRY_RL_ALGORITHM)
? RLAlgorithm.DQN : (RLAlgorithm)retrieveConfigurationEntryValueByKey(AST_ENTRY_RL_ALGORITHM);
return !configurationContainsKey(RL_ALGORITHM)
? RLAlgorithm.DQN : (RLAlgorithm)retrieveConfigurationEntryValueByKey(RL_ALGORITHM);
}
public String getInputNameOfTrainedArchitecture() {
......@@ -179,18 +152,18 @@ public class ReinforcementConfigurationData extends ConfigurationData {
}
public String getLoss() {
return !configurationContainsKey(AST_ENTRY_LOSS)
? null : retrieveConfigurationEntryValueByKey(AST_ENTRY_LOSS).toString();
return !configurationContainsKey(LOSS)
? null : retrieveConfigurationEntryValueByKey(LOSS).toString();
}
public Map<String, Object> getReplayMemory() {
return getMultiParamEntry(AST_ENTRY_REPLAY_MEMORY, "method");
return getMultiParamEntry(REPLAY_MEMORY, "method");
}
public Map<String, Object> getStrategy() {
assert isReinforcementLearning(): "Strategy parameter only for reinforcement learning but called in a " +
" non reinforcement learning context";
Map<String, Object> strategyParams = getMultiParamEntry(AST_ENTRY_STRATEGY, "method");
Map<String, Object> strategyParams = getMultiParamEntry(STRATEGY, "method");
assert getConfiguration().getTrainedArchitecture().isPresent(): "Architecture not present," +
" but reinforcement training";
NNArchitectureSymbol trainedArchitecture = getConfiguration().getTrainedArchitecture().get();
......@@ -218,7 +191,7 @@ public class ReinforcementConfigurationData extends ConfigurationData {
}
public Map<String, Object> getEnvironment() {
return getMultiParamEntry(AST_ENTRY_ENVIRONMENT, "environment");
return getMultiParamEntry(ENVIRONMENT, "environment");
}
public Boolean hasRewardFunction() {
......@@ -312,12 +285,12 @@ public class ReinforcementConfigurationData extends ConfigurationData {
}
public boolean hasRosRewardTopic() {
Map<String, Object> environmentParameters = getMultiParamEntry(AST_ENTRY_ENVIRONMENT, "environment");
Map<String, Object> environmentParameters = getMultiParamEntry(ENVIRONMENT, "environment");
if (environmentParameters == null
|| !environmentParameters.containsKey("environment")) {
return false;
}
return environmentParameters.containsKey(ENVIRONMENT_PARAM_REWARD_TOPIC);
return environmentParameters.containsKey(ENVIRONMENT_REWARD_TOPIC);
}
private Map<String, Object> getMultiParamEntry(final String key, final String valueName) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment