Commit b12468cb authored by Sebastian Nickels's avatar Sebastian Nickels
Browse files

Merge

parents 2d4fd64e f71a6aa5
Pipeline #170397 failed with stages
in 2 minutes and 23 seconds
......@@ -27,7 +27,7 @@ masterJobLinux:
stage: linux
image: maven:3-jdk-8
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean deploy --settings settings.xml -Dtest=\!Integration*
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean deploy --settings settings.xml
- cat target/site/jacoco/index.html
- mvn package sonar:sonar -s settings.xml
only:
......@@ -36,7 +36,7 @@ masterJobLinux:
masterJobWindows:
stage: windows
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml -Dtest=\!Integration*
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml
tags:
- Windows10
......@@ -44,13 +44,7 @@ BranchJobLinux:
stage: linux
image: maven:3-jdk-8
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml -Dtest=\!Integration*
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml
- cat target/site/jacoco/index.html
except:
- master
PythonWrapperIntegrationTest:
stage: linux
image: registry.git.rwth-aachen.de/monticore/embeddedmontiarc/generators/emadl2pythonwrapper/tests/mvn-swig:latest
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml -Dtest=IntegrationPythonWrapperTest
......@@ -8,7 +8,7 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-gluon-generator</artifactId>
<version>0.2.2-SNAPSHOT</version>
<version>0.2.7-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
......@@ -16,10 +16,10 @@
<!-- .. SE-Libraries .................................................. -->
<CNNArch.version>0.3.2-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.3.4-SNAPSHOT</CNNTrain.version>
<CNNTrain.version>0.3.6-SNAPSHOT</CNNTrain.version>
<CNNArch2X.version>0.0.3-SNAPSHOT</CNNArch2X.version>
<embedded-montiarc-math-opt-generator>0.1.4</embedded-montiarc-math-opt-generator>
<EMADL2PythonWrapper.version>0.0.1</EMADL2PythonWrapper.version>
<EMADL2PythonWrapper.version>0.0.2-SNAPSHOT</EMADL2PythonWrapper.version>
<!-- .. Libraries .................................................. -->
<guava.version>18.0</guava.version>
......
......@@ -29,4 +29,4 @@ public class CNNArch2GluonCli {
GenericCNNArchCli cli = new GenericCNNArchCli(generator);
cli.run(args);
}
}
}
\ No newline at end of file
package de.monticore.lang.monticar.cnnarch.gluongenerator;
import com.google.common.collect.Maps;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.critic.CriticNetworkGenerationPair;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.critic.CriticNetworkGenerator;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.instanceStructure.EMAComponentInstanceSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch.gluongenerator.annotations.ArchitectureAdapter;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.FunctionParameterChecker;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionParameterAdapter;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionSourceGenerator;
......@@ -10,15 +11,12 @@ import de.monticore.lang.monticar.cnnarch.generator.ConfigurationData;
import de.monticore.lang.monticar.cnnarch.generator.CNNTrainGenerator;
import de.monticore.lang.monticar.cnnarch.generator.TemplateConfiguration;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.LearningMethod;
import de.monticore.lang.monticar.cnntrain._symboltable.RLAlgorithm;
import de.monticore.lang.monticar.cnntrain._symboltable.RewardFunctionSymbol;
import de.monticore.lang.monticar.cnntrain.annotations.TrainedArchitecture;
import de.monticore.lang.monticar.cnntrain._symboltable.*;
import de.monticore.lang.monticar.generator.FileContent;
import de.monticore.lang.monticar.generator.cpp.GeneratorCPP;
import de.monticore.lang.monticar.generator.pythonwrapper.GeneratorPythonWrapperStandaloneApi;
import de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.ComponentPortInformation;
import de.monticore.lang.tagging._symboltable.TaggingResolver;
import de.se_rwth.commons.logging.Log;
import java.io.File;
......@@ -54,13 +52,6 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
@Override
public ConfigurationSymbol getConfigurationSymbol(Path modelsDirPath, String rootModelName) {
ConfigurationSymbol configurationSymbol = super.getConfigurationSymbol(modelsDirPath, rootModelName);
// Generate Reward function if necessary
if (configurationSymbol.getLearningMethod().equals(LearningMethod.REINFORCEMENT)
&& configurationSymbol.getRlRewardFunction().isPresent()) {
generateRewardFunction(configurationSymbol.getRlRewardFunction().get(), modelsDirPath);
}
return configurationSymbol;
}
......@@ -93,17 +84,25 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
}
}
public void generate(Path modelsDirPath, String rootModelName, TrainedArchitecture trainedArchitecture) {
public void generate(Path modelsDirPath,
String rootModelName,
NNArchitectureSymbol trainedArchitecture,
NNArchitectureSymbol criticNetwork) {
ConfigurationSymbol configurationSymbol = this.getConfigurationSymbol(modelsDirPath, rootModelName);
configurationSymbol.setTrainedArchitecture(trainedArchitecture);
configurationSymbol.setCriticNetwork(criticNetwork);
this.setRootProjectModelsDir(modelsDirPath.toString());
generateFilesFromConfigurationSymbol(configurationSymbol);
}
public void generate(Path modelsDirPath, String rootModelName, NNArchitectureSymbol trainedArchitecture) {
generate(modelsDirPath, rootModelName, trainedArchitecture, null);
}
@Override
public Map<String, String> generateStrings(ConfigurationSymbol configuration) {
TemplateConfiguration templateConfiguration = new GluonTemplateConfiguration();
ReinforcementConfigurationData configData = new ReinforcementConfigurationData(configuration, getInstanceName());
GluonConfigurationData configData = new GluonConfigurationData(configuration, getInstanceName());
List<ConfigurationData> configDataList = new ArrayList<>();
configDataList.add(configData);
......@@ -119,24 +118,39 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
final String trainerName = "CNNTrainer_" + getInstanceName();
final RLAlgorithm rlAlgorithm = configData.getRlAlgorithm();
if (rlAlgorithm.equals(RLAlgorithm.DDPG)) {
CriticNetworkGenerator criticNetworkGenerator = new CriticNetworkGenerator();
criticNetworkGenerator.setGenerationTargetPath(
Paths.get(getGenerationTargetPath(), REINFORCEMENT_LEARNING_FRAMEWORK_MODULE).toString());
if (getRootProjectModelsDir().isPresent()) {
criticNetworkGenerator.setRootModelsDir(getRootProjectModelsDir().get());
} else {
Log.error("No root model dir set");
if (rlAlgorithm.equals(RLAlgorithm.DDPG)
|| rlAlgorithm.equals(RLAlgorithm.TD3)) {
if (!configuration.getCriticNetwork().isPresent()) {
Log.error("No architecture model for critic available but is required for chosen " +
"actor-critic algorithm");
}
NNArchitectureSymbol genericArchitectureSymbol = configuration.getCriticNetwork().get();
ArchitectureSymbol architectureSymbol
= ((ArchitectureAdapter)genericArchitectureSymbol).getArchitectureSymbol();
CriticNetworkGenerationPair criticNetworkResult
= criticNetworkGenerator.generateCriticNetworkContent(templateConfiguration, configuration);
CNNArch2Gluon gluonGenerator = new CNNArch2Gluon();
gluonGenerator.setGenerationTargetPath(
Paths.get(getGenerationTargetPath(), REINFORCEMENT_LEARNING_FRAMEWORK_MODULE).toString());
Map<String, String> architectureFileContentMap
= gluonGenerator.generateStringsAllowMultipleIO(architectureSymbol, true);
fileContentMap.putAll(criticNetworkResult.getFileContent().entrySet().stream().collect(Collectors.toMap(
final String creatorName = architectureFileContentMap.keySet().iterator().next();
final String criticInstanceName = creatorName.substring(
creatorName.indexOf('_') + 1, creatorName.lastIndexOf(".py"));
fileContentMap.putAll(architectureFileContentMap.entrySet().stream().collect(Collectors.toMap(
k -> REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/" + k.getKey(),
Map.Entry::getValue))
);
ftlContext.put("criticInstanceName", criticNetworkResult.getCriticNetworkName());
ftlContext.put("criticInstanceName", criticInstanceName);
}
// Generate Reward function if necessary
if (configuration.getRlRewardFunction().isPresent()) {
generateRewardFunction(configuration.getRlRewardFunction().get(), Paths.get(rootProjectModelsDir));
}
ftlContext.put("trainerName", trainerName);
......@@ -164,8 +178,11 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
setRootProjectModelsDir(modelsDirPath.toString());
}
rewardFunctionSourceGenerator.generate(getRootProjectModelsDir().get(),
rewardFunctionRootModel, rewardFunctionOutputPath);
final TaggingResolver taggingResolver
= rewardFunctionSourceGenerator.createTaggingResolver(getRootProjectModelsDir().get());
final EMAComponentInstanceSymbol emaSymbol
= rewardFunctionSourceGenerator.resolveSymbol(taggingResolver, rewardFunctionRootModel);
rewardFunctionSourceGenerator.generate(emaSymbol, taggingResolver, rewardFunctionOutputPath);
fixArmadilloEmamGenerationOfFile(Paths.get(rewardFunctionOutputPath, String.join("_", fullNameOfComponent) + ".h"));
String pythonWrapperOutputPath = Paths.get(rewardFunctionOutputPath, "pylib").toString();
......@@ -175,12 +192,11 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
if (pythonWrapperApi.checkIfPythonModuleBuildAvailable()) {
final String rewardModuleOutput
= Paths.get(getGenerationTargetPath(), REINFORCEMENT_LEARNING_FRAMEWORK_MODULE).toString();
componentPortInformation = pythonWrapperApi.generateAndTryBuilding(getRootProjectModelsDir().get(),
rewardFunctionRootModel, pythonWrapperOutputPath, rewardModuleOutput);
componentPortInformation = pythonWrapperApi.generateAndTryBuilding(emaSymbol,
pythonWrapperOutputPath, rewardModuleOutput);
} else {
Log.warn("Cannot build wrapper automatically: OS not supported. Please build manually before starting training.");
componentPortInformation = pythonWrapperApi.generate(getRootProjectModelsDir().get(), rewardFunctionRootModel,
pythonWrapperOutputPath);
componentPortInformation = pythonWrapperApi.generate(emaSymbol, pythonWrapperOutputPath);
}
RewardFunctionParameterAdapter functionParameter = new RewardFunctionParameterAdapter(componentPortInformation);
new FunctionParameterChecker().check(functionParameter);
......
package de.monticore.lang.monticar.cnnarch.gluongenerator;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionParameterAdapter;
import de.monticore.lang.monticar.cnnarch.generator.ConfigurationData;
import de.monticore.lang.monticar.cnntrain._symboltable.*;
import de.monticore.lang.monticar.cnntrain.annotations.Range;
import de.monticore.lang.monticar.cnntrain.annotations.TrainedArchitecture;
import static de.monticore.lang.monticar.cnntrain.helper.ConfigEntryNameConstants.*;
import java.util.*;
public class ReinforcementConfigurationData extends ConfigurationData {
private static final String AST_ENTRY_LEARNING_METHOD = "learning_method";
private static final String AST_ENTRY_NUM_EPISODES = "num_episodes";
private static final String AST_ENTRY_DISCOUNT_FACTOR = "discount_factor";
private static final String AST_ENTRY_NUM_MAX_STEPS = "num_max_steps";
private static final String AST_ENTRY_TARGET_SCORE = "target_score";
private static final String AST_ENTRY_TRAINING_INTERVAL = "training_interval";
private static final String AST_ENTRY_USE_FIX_TARGET_NETWORK = "use_fix_target_network";
private static final String AST_ENTRY_TARGET_NETWORK_UPDATE_INTERVAL = "target_network_update_interval";
private static final String AST_ENTRY_SNAPSHOT_INTERVAL = "snapshot_interval";
private static final String AST_ENTRY_AGENT_NAME = "agent_name";
private static final String AST_ENTRY_USE_DOUBLE_DQN = "use_double_dqn";
private static final String AST_ENTRY_LOSS = "loss";
private static final String AST_ENTRY_RL_ALGORITHM = "rl_algorithm";
private static final String AST_ENTRY_REPLAY_MEMORY = "replay_memory";
private static final String AST_ENTRY_STRATEGY = "strategy";
private static final String AST_ENTRY_ENVIRONMENT = "environment";
private static final String AST_ENTRY_START_TRAINING_AT = "start_training_at";
private static final String AST_SOFT_TARGET_UPDATE_RATE = "soft_target_update_rate";
private static final String AST_EVALUATION_SAMPLES = "evaluation_samples";
private static final String ENVIRONMENT_PARAM_REWARD_TOPIC = "reward_topic";
private static final String ENVIRONMENT_ROS = "ros_interface";
private static final String ENVIRONMENT_GYM = "gym";
private static final String STRATEGY_ORNSTEIN_UHLENBECK = "ornstein_uhlenbeck";
public ReinforcementConfigurationData(ConfigurationSymbol configuration, String instanceName) {
public class GluonConfigurationData extends ConfigurationData {
public GluonConfigurationData(ConfigurationSymbol configuration, String instanceName) {
super(configuration, instanceName);
}
public Boolean isSupervisedLearning() {
if (configurationContainsKey(AST_ENTRY_LEARNING_METHOD)) {
return retrieveConfigurationEntryValueByKey(AST_ENTRY_LEARNING_METHOD)
if (configurationContainsKey(LEARNING_METHOD)) {
return retrieveConfigurationEntryValueByKey(LEARNING_METHOD)
.equals(LearningMethod.SUPERVISED);
}
return true;
}
public Boolean isReinforcementLearning() {
return configurationContainsKey(AST_ENTRY_LEARNING_METHOD)
&& retrieveConfigurationEntryValueByKey(AST_ENTRY_LEARNING_METHOD).equals(LearningMethod.REINFORCEMENT);
return configurationContainsKey(LEARNING_METHOD)
&& retrieveConfigurationEntryValueByKey(LEARNING_METHOD).equals(LearningMethod.REINFORCEMENT);
}
public Integer getNumEpisodes() {
return !configurationContainsKey(AST_ENTRY_NUM_EPISODES)
? null : (Integer)retrieveConfigurationEntryValueByKey(AST_ENTRY_NUM_EPISODES);
return !configurationContainsKey(NUM_EPISODES)
? null : (Integer)retrieveConfigurationEntryValueByKey(NUM_EPISODES);
}
public Double getDiscountFactor() {
return !configurationContainsKey(AST_ENTRY_DISCOUNT_FACTOR)
? null : (Double)retrieveConfigurationEntryValueByKey(AST_ENTRY_DISCOUNT_FACTOR);
return !configurationContainsKey(DISCOUNT_FACTOR)
? null : (Double)retrieveConfigurationEntryValueByKey(DISCOUNT_FACTOR);
}
public Integer getNumMaxSteps() {
return !configurationContainsKey(AST_ENTRY_NUM_MAX_STEPS)
? null : (Integer)retrieveConfigurationEntryValueByKey(AST_ENTRY_NUM_MAX_STEPS);
return !configurationContainsKey(NUM_MAX_STEPS)
? null : (Integer)retrieveConfigurationEntryValueByKey(NUM_MAX_STEPS);
}
public Double getTargetScore() {
return !configurationContainsKey(AST_ENTRY_TARGET_SCORE)
? null : (Double)retrieveConfigurationEntryValueByKey(AST_ENTRY_TARGET_SCORE);
return !configurationContainsKey(TARGET_SCORE)
? null : (Double)retrieveConfigurationEntryValueByKey(TARGET_SCORE);
}
public Integer getTrainingInterval() {
return !configurationContainsKey(AST_ENTRY_TRAINING_INTERVAL)
? null : (Integer)retrieveConfigurationEntryValueByKey(AST_ENTRY_TRAINING_INTERVAL);
return !configurationContainsKey(TRAINING_INTERVAL)
? null : (Integer)retrieveConfigurationEntryValueByKey(TRAINING_INTERVAL);
}
public Boolean getUseFixTargetNetwork() {
return !configurationContainsKey(AST_ENTRY_USE_FIX_TARGET_NETWORK)
? null : (Boolean)retrieveConfigurationEntryValueByKey(AST_ENTRY_USE_FIX_TARGET_NETWORK);
return !configurationContainsKey(USE_FIX_TARGET_NETWORK)
? null : (Boolean)retrieveConfigurationEntryValueByKey(USE_FIX_TARGET_NETWORK);
}
public Integer getTargetNetworkUpdateInterval() {
return !configurationContainsKey(AST_ENTRY_TARGET_NETWORK_UPDATE_INTERVAL)
? null : (Integer)retrieveConfigurationEntryValueByKey(AST_ENTRY_TARGET_NETWORK_UPDATE_INTERVAL);
return !configurationContainsKey(TARGET_NETWORK_UPDATE_INTERVAL)
? null : (Integer)retrieveConfigurationEntryValueByKey(TARGET_NETWORK_UPDATE_INTERVAL);
}
public Integer getSnapshotInterval() {
return !configurationContainsKey(AST_ENTRY_SNAPSHOT_INTERVAL)
? null : (Integer)retrieveConfigurationEntryValueByKey(AST_ENTRY_SNAPSHOT_INTERVAL);
return !configurationContainsKey(SNAPSHOT_INTERVAL)
? null : (Integer)retrieveConfigurationEntryValueByKey(SNAPSHOT_INTERVAL);
}
public String getAgentName() {
return !configurationContainsKey(AST_ENTRY_AGENT_NAME)
? null : (String)retrieveConfigurationEntryValueByKey(AST_ENTRY_AGENT_NAME);
return !configurationContainsKey(AGENT_NAME)
? null : (String)retrieveConfigurationEntryValueByKey(AGENT_NAME);
}
public Boolean getUseDoubleDqn() {
return !configurationContainsKey(AST_ENTRY_USE_DOUBLE_DQN)
? null : (Boolean)retrieveConfigurationEntryValueByKey(AST_ENTRY_USE_DOUBLE_DQN);
return !configurationContainsKey(USE_DOUBLE_DQN)
? null : (Boolean)retrieveConfigurationEntryValueByKey(USE_DOUBLE_DQN);
}
public Double getSoftTargetUpdateRate() {
return !configurationContainsKey(AST_SOFT_TARGET_UPDATE_RATE)
? null : (Double)retrieveConfigurationEntryValueByKey(AST_SOFT_TARGET_UPDATE_RATE);
return !configurationContainsKey(SOFT_TARGET_UPDATE_RATE)
? null : (Double)retrieveConfigurationEntryValueByKey(SOFT_TARGET_UPDATE_RATE);
}
public Integer getStartTrainingAt() {
return !configurationContainsKey(AST_ENTRY_START_TRAINING_AT)
? null : (Integer)retrieveConfigurationEntryValueByKey(AST_ENTRY_START_TRAINING_AT);
return !configurationContainsKey(START_TRAINING_AT)
? null : (Integer)retrieveConfigurationEntryValueByKey(START_TRAINING_AT);
}
public Integer getEvaluationSamples() {
return !configurationContainsKey(AST_EVALUATION_SAMPLES)
? null : (Integer)retrieveConfigurationEntryValueByKey(AST_EVALUATION_SAMPLES);
return !configurationContainsKey(EVALUATION_SAMPLES)
? null : (Integer)retrieveConfigurationEntryValueByKey(EVALUATION_SAMPLES);
}
public Double getPolicyNoise() {
return !configurationContainsKey(POLICY_NOISE)
? null : (Double) retrieveConfigurationEntryValueByKey(POLICY_NOISE);
}
public Double getNoiseClip() {
return !configurationContainsKey(NOISE_CLIP)
? null : (Double) retrieveConfigurationEntryValueByKey(NOISE_CLIP);
}
public Integer getPolicyDelay() {
return !configurationContainsKey(POLICY_DELAY)
? null : (Integer) retrieveConfigurationEntryValueByKey(POLICY_DELAY);
}
public RLAlgorithm getRlAlgorithm() {
if (!isReinforcementLearning()) {
return null;
}
return !configurationContainsKey(AST_ENTRY_RL_ALGORITHM)
? RLAlgorithm.DQN : (RLAlgorithm)retrieveConfigurationEntryValueByKey(AST_ENTRY_RL_ALGORITHM);
return !configurationContainsKey(RL_ALGORITHM)
? RLAlgorithm.DQN : (RLAlgorithm)retrieveConfigurationEntryValueByKey(RL_ALGORITHM);
}
public String getInputNameOfTrainedArchitecture() {
if (!this.getConfiguration().getTrainedArchitecture().isPresent()) {
throw new IllegalStateException("No trained architecture set");
}
TrainedArchitecture trainedArchitecture = getConfiguration().getTrainedArchitecture().get();
NNArchitectureSymbol trainedArchitecture = getConfiguration().getTrainedArchitecture().get();
// We allow only one input, the first one is the only input
return trainedArchitecture.getInputs().get(0);
}
......@@ -141,7 +128,7 @@ public class ReinforcementConfigurationData extends ConfigurationData {
if (!this.getConfiguration().getTrainedArchitecture().isPresent()) {
throw new IllegalStateException("No trained architecture set");
}
TrainedArchitecture trainedArchitecture = getConfiguration().getTrainedArchitecture().get();
NNArchitectureSymbol trainedArchitecture = getConfiguration().getTrainedArchitecture().get();
// We allow only one output, the first one is the only output
return trainedArchitecture.getOutputs().get(0);
}
......@@ -151,7 +138,7 @@ public class ReinforcementConfigurationData extends ConfigurationData {
return null;
}
final String inputName = getInputNameOfTrainedArchitecture();
TrainedArchitecture trainedArchitecture = this.getConfiguration().getTrainedArchitecture().get();
NNArchitectureSymbol trainedArchitecture = this.getConfiguration().getTrainedArchitecture().get();
return trainedArchitecture.getDimensions().get(inputName);
}
......@@ -160,53 +147,51 @@ public class ReinforcementConfigurationData extends ConfigurationData {
return null;
}
final String outputName = getOutputNameOfTrainedArchitecture();
TrainedArchitecture trainedArchitecture = this.getConfiguration().getTrainedArchitecture().get();
NNArchitectureSymbol trainedArchitecture = this.getConfiguration().getTrainedArchitecture().get();
return trainedArchitecture.getDimensions().get(outputName);
}
public String getLoss() {
return !configurationContainsKey(AST_ENTRY_LOSS)
? null : retrieveConfigurationEntryValueByKey(AST_ENTRY_LOSS).toString();
return !configurationContainsKey(LOSS)
? null : retrieveConfigurationEntryValueByKey(LOSS).toString();
}
public Map<String, Object> getReplayMemory() {
return getMultiParamEntry(AST_ENTRY_REPLAY_MEMORY, "method");
return getMultiParamEntry(REPLAY_MEMORY, "method");
}
public Map<String, Object> getStrategy() {
assert isReinforcementLearning(): "Strategy parameter only for reinforcement learning but called in a " +
" non reinforcement learning context";
Map<String, Object> strategyParams = getMultiParamEntry(AST_ENTRY_STRATEGY, "method");
if (strategyParams.get("method").equals(STRATEGY_ORNSTEIN_UHLENBECK)) {
assert getConfiguration().getTrainedArchitecture().isPresent(): "Architecture not present," +
" but reinforcement training";
TrainedArchitecture trainedArchitecture = getConfiguration().getTrainedArchitecture().get();
final String actionPortName = getOutputNameOfTrainedArchitecture();
Range actionRange = trainedArchitecture.getRanges().get(actionPortName);
if (actionRange.isLowerLimitInfinity() && actionRange.isUpperLimitInfinity()) {
strategyParams.put("action_low", null);
strategyParams.put("action_high", null);
} else if(!actionRange.isLowerLimitInfinity() && actionRange.isUpperLimitInfinity()) {
assert actionRange.getLowerLimit().isPresent();
strategyParams.put("action_low", actionRange.getLowerLimit().get());
strategyParams.put("action_high", null);
} else if (actionRange.isLowerLimitInfinity() && !actionRange.isUpperLimitInfinity()) {
assert actionRange.getUpperLimit().isPresent();
strategyParams.put("action_low", null);
strategyParams.put("action_high", actionRange.getUpperLimit().get());
} else {
assert actionRange.getLowerLimit().isPresent();
assert actionRange.getUpperLimit().isPresent();
strategyParams.put("action_low", actionRange.getLowerLimit().get());
strategyParams.put("action_high", actionRange.getUpperLimit().get());
}
Map<String, Object> strategyParams = getMultiParamEntry(STRATEGY, "method");
assert getConfiguration().getTrainedArchitecture().isPresent(): "Architecture not present," +
" but reinforcement training";
NNArchitectureSymbol trainedArchitecture = getConfiguration().getTrainedArchitecture().get();
final String actionPortName = getOutputNameOfTrainedArchitecture();
Range actionRange = trainedArchitecture.getRanges().get(actionPortName);
if (actionRange.isLowerLimitInfinity() && actionRange.isUpperLimitInfinity()) {
strategyParams.put("action_low", null);
strategyParams.put("action_high", null);
} else if(!actionRange.isLowerLimitInfinity() && actionRange.isUpperLimitInfinity()) {
assert actionRange.getLowerLimit().isPresent();
strategyParams.put("action_low", actionRange.getLowerLimit().get());
strategyParams.put("action_high", null);
} else if (actionRange.isLowerLimitInfinity() && !actionRange.isUpperLimitInfinity()) {
assert actionRange.getUpperLimit().isPresent();
strategyParams.put("action_low", null);
strategyParams.put("action_high", actionRange.getUpperLimit().get());
} else {
assert actionRange.getLowerLimit().isPresent();
assert actionRange.getUpperLimit().isPresent();
strategyParams.put("action_low", actionRange.getLowerLimit().get());
strategyParams.put("action_high", actionRange.getUpperLimit().get());
}
return strategyParams;
}
public Map<String, Object> getEnvironment() {
return getMultiParamEntry(AST_ENTRY_ENVIRONMENT, "environment");
return getMultiParamEntry(ENVIRONMENT, "environment");
}
public Boolean hasRewardFunction() {
......@@ -300,12 +285,12 @@ public class ReinforcementConfigurationData extends ConfigurationData {
}
public boolean hasRosRewardTopic() {
Map<String, Object> environmentParameters = getMultiParamEntry(AST_ENTRY_ENVIRONMENT, "environment");
Map<String, Object> environmentParameters = getMultiParamEntry(ENVIRONMENT, "environment");
if (environmentParameters == null
|| !environmentParameters.containsKey("environment")) {
return false;
}
return environmentParameters.containsKey(ENVIRONMENT_PARAM_REWARD_TOPIC);
return environmentParameters.containsKey(ENVIRONMENT_REWARD_TOPIC);
}
private Map<String, Object> getMultiParamEntry(final String key, final String valueName) {
......
......@@ -3,8 +3,8 @@ package de.monticore.lang.monticar.cnnarch.gluongenerator.annotations;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.IODeclarationSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.VariableSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
import de.monticore.lang.monticar.cnntrain.annotations.Range;
import de.monticore.lang.monticar.cnntrain.annotations.TrainedArchitecture;
import de.monticore.lang.monticar.ranges._ast.ASTRange;
import de.monticore.symboltable.CommonSymbol;
......@@ -15,11 +15,13 @@ import java.util.stream.Collectors;
import static com.google.common.base.Preconditions.checkNotNull;
public class ArchitectureAdapter implements TrainedArchitecture {
public class ArchitectureAdapter extends NNArchitectureSymbol {
private ArchitectureSymbol architectureSymbol;
public ArchitectureAdapter(final ArchitectureSymbol architectureSymbol) {
public ArchitectureAdapter(final String name,
final ArchitectureSymbol architectureSymbol) {
super(name);
checkNotNull(architectureSymbol);
this.architectureSymbol = architectureSymbol;
}
......@@ -56,6 +58,10 @@ public class ArchitectureAdapter implements TrainedArchitecture {
s -> ((IODeclarationSymbol) s.getDeclaration()).getType().getDomain().getName()));
}
public ArchitectureSymbol getArchitectureSymbol() {
return this.architectureSymbol;
}
private Range astRangeToTrainRange(final ASTRange range) {
if (range == null || (range.hasNoLowerLimit() && range.hasNoUpperLimit())) {
return Range.withInfinityLimits();
......