Commit f9465a60 authored by Nicola Gatto's avatar Nicola Gatto Committed by Evgeny Kusmenko

Implement reinforcement learning

parent 2dd5aae6
...@@ -8,3 +8,4 @@ nppBackup ...@@ -8,3 +8,4 @@ nppBackup
*.iml *.iml
.vscode
...@@ -27,7 +27,7 @@ masterJobLinux: ...@@ -27,7 +27,7 @@ masterJobLinux:
stage: linux stage: linux
image: maven:3-jdk-8 image: maven:3-jdk-8
script: script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean deploy --settings settings.xml - mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean deploy --settings settings.xml -Dtest=\!Integration*
- cat target/site/jacoco/index.html - cat target/site/jacoco/index.html
- mvn package sonar:sonar -s settings.xml - mvn package sonar:sonar -s settings.xml
only: only:
...@@ -36,7 +36,7 @@ masterJobLinux: ...@@ -36,7 +36,7 @@ masterJobLinux:
masterJobWindows: masterJobWindows:
stage: windows stage: windows
script: script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml - mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml -Dtest=\!Integration*
tags: tags:
- Windows10 - Windows10
...@@ -44,7 +44,13 @@ BranchJobLinux: ...@@ -44,7 +44,13 @@ BranchJobLinux:
stage: linux stage: linux
image: maven:3-jdk-8 image: maven:3-jdk-8
script: script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml - mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml -Dtest=\!Integration*
- cat target/site/jacoco/index.html - cat target/site/jacoco/index.html
except: except:
- master - master
PythonWrapperIntegrationTest:
stage: linux
image: registry.git.rwth-aachen.de/monticore/embeddedmontiarc/generators/emadl2pythonwrapper/tests/mvn-swig:latest
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml -Dtest=IntegrationPythonWrapperTest
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
<groupId>de.monticore.lang.monticar</groupId> <groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-gluon-generator</artifactId> <artifactId>cnnarch-gluon-generator</artifactId>
<version>0.1.6</version> <version>0.2.0-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= --> <!-- == PROJECT DEPENDENCIES ============================================= -->
...@@ -16,9 +16,10 @@ ...@@ -16,9 +16,10 @@
<!-- .. SE-Libraries .................................................. --> <!-- .. SE-Libraries .................................................. -->
<CNNArch.version>0.3.0-SNAPSHOT</CNNArch.version> <CNNArch.version>0.3.0-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.2.6</CNNTrain.version> <CNNTrain.version>0.3.0-SNAPSHOT</CNNTrain.version>
<CNNArch2MXNet.version>0.2.14-SNAPSHOT</CNNArch2MXNet.version> <CNNArch2MXNet.version>0.2.14-SNAPSHOT</CNNArch2MXNet.version>
<embedded-montiarc-math-opt-generator>0.1.4</embedded-montiarc-math-opt-generator> <embedded-montiarc-math-opt-generator>0.1.4</embedded-montiarc-math-opt-generator>
<EMADL2PythonWrapper.version>0.0.1</EMADL2PythonWrapper.version>
<!-- .. Libraries .................................................. --> <!-- .. Libraries .................................................. -->
<guava.version>18.0</guava.version> <guava.version>18.0</guava.version>
...@@ -100,6 +101,12 @@ ...@@ -100,6 +101,12 @@
<version>${embedded-montiarc-math-opt-generator}</version> <version>${embedded-montiarc-math-opt-generator}</version>
</dependency> </dependency>
<dependency>
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>embedded-montiarc-emadl-pythonwrapper-generator</artifactId>
<version>${EMADL2PythonWrapper.version}</version>
</dependency>
<!-- .. Test Libraries ............................................... --> <!-- .. Test Libraries ............................................... -->
<dependency> <dependency>
...@@ -109,6 +116,13 @@ ...@@ -109,6 +116,13 @@
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<version>1.10.19</version>
<scope>test</scope>
</dependency>
<dependency> <dependency>
<groupId>ch.qos.logback</groupId> <groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId> <artifactId>logback-classic</artifactId>
......
...@@ -47,8 +47,10 @@ public class CNNArch2Gluon extends CNNArch2MxNet { ...@@ -47,8 +47,10 @@ public class CNNArch2Gluon extends CNNArch2MxNet {
temp = archTc.process("CNNNet", Target.PYTHON); temp = archTc.process("CNNNet", Target.PYTHON);
fileContentMap.put(temp.getKey(), temp.getValue()); fileContentMap.put(temp.getKey(), temp.getValue());
temp = archTc.process("CNNDataLoader", Target.PYTHON); if (architecture.getDataPath() != null) {
fileContentMap.put(temp.getKey(), temp.getValue()); temp = archTc.process("CNNDataLoader", Target.PYTHON);
fileContentMap.put(temp.getKey(), temp.getValue());
}
temp = archTc.process("CNNCreator", Target.PYTHON); temp = archTc.process("CNNCreator", Target.PYTHON);
fileContentMap.put(temp.getKey(), temp.getValue()); fileContentMap.put(temp.getKey(), temp.getValue());
......
package de.monticore.lang.monticar.cnnarch.gluongenerator; package de.monticore.lang.monticar.cnnarch.gluongenerator;
import com.google.common.collect.Maps;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.FunctionParameterChecker;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionParameterAdapter;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionSourceGenerator;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.ConfigurationData; import de.monticore.lang.monticar.cnnarch.mxnetgenerator.ConfigurationData;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.CNNTrain2MxNet; import de.monticore.lang.monticar.cnnarch.mxnetgenerator.CNNTrain2MxNet;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.TemplateConfiguration; import de.monticore.lang.monticar.cnnarch.mxnetgenerator.TemplateConfiguration;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol; import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.LearningMethod;
import de.monticore.lang.monticar.cnntrain._symboltable.RewardFunctionSymbol;
import de.monticore.lang.monticar.generator.FileContent;
import de.monticore.lang.monticar.generator.cpp.GeneratorCPP;
import de.monticore.lang.monticar.generator.pythonwrapper.GeneratorPythonWrapperStandaloneApi;
import de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.ComponentPortInformation;
import de.se_rwth.commons.logging.Log;
import java.io.File;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*; import java.util.*;
public class CNNTrain2Gluon extends CNNTrain2MxNet { public class CNNTrain2Gluon extends CNNTrain2MxNet {
public CNNTrain2Gluon() { private static final String REINFORCEMENT_LEARNING_FRAMEWORK_MODULE = "reinforcement_learning";
private final RewardFunctionSourceGenerator rewardFunctionSourceGenerator;
private String rootProjectModelsDir;
public Optional<String> getRootProjectModelsDir() {
return Optional.ofNullable(rootProjectModelsDir);
}
public void setRootProjectModelsDir(String rootProjectModelsDir) {
this.rootProjectModelsDir = rootProjectModelsDir;
}
public CNNTrain2Gluon(RewardFunctionSourceGenerator rewardFunctionSourceGenerator) {
super(); super();
this.rewardFunctionSourceGenerator = rewardFunctionSourceGenerator;
}
@Override
public ConfigurationSymbol getConfigurationSymbol(Path modelsDirPath, String rootModelName) {
ConfigurationSymbol configurationSymbol = super.getConfigurationSymbol(modelsDirPath, rootModelName);
// Generate Reward function if necessary
if (configurationSymbol.getLearningMethod().equals(LearningMethod.REINFORCEMENT)
&& configurationSymbol.getRlRewardFunction().isPresent()) {
generateRewardFunction(configurationSymbol.getRlRewardFunction().get(), modelsDirPath);
}
return configurationSymbol;
}
@Override
public void generate(Path modelsDirPath, String rootModelName) {
ConfigurationSymbol configuration = this.getConfigurationSymbol(modelsDirPath, rootModelName);
Map<String, String> fileContents = this.generateStrings(configuration);
GeneratorCPP genCPP = new GeneratorCPP();
genCPP.setGenerationTargetPath(this.getGenerationTargetPath());
try {
Iterator var6 = fileContents.keySet().iterator();
while(var6.hasNext()) {
String fileName = (String)var6.next();
genCPP.generateFile(new FileContent((String)fileContents.get(fileName), fileName));
}
} catch (IOException var8) {
Log.error("CNNTrainer file could not be generated" + var8.getMessage());
}
} }
@Override @Override
public Map<String, String> generateStrings(ConfigurationSymbol configuration) { public Map<String, String> generateStrings(ConfigurationSymbol configuration) {
TemplateConfiguration templateConfiguration = new GluonTemplateConfiguration(); TemplateConfiguration templateConfiguration = new GluonTemplateConfiguration();
ConfigurationData configData = new ConfigurationData(configuration, getInstanceName()); ReinforcementConfigurationData configData = new ReinforcementConfigurationData(configuration, getInstanceName());
List<ConfigurationData> configDataList = new ArrayList<>(); List<ConfigurationData> configDataList = new ArrayList<>();
configDataList.add(configData); configDataList.add(configData);
Map<String, Object> ftlContext = Collections.singletonMap("configurations", configDataList);
Map<String, Object> ftlContext = Maps.newHashMap();
ftlContext.put("configurations", configDataList);
Map<String, String> fileContentMap = new HashMap<>(); Map<String, String> fileContentMap = new HashMap<>();
String cnnTrainTemplateContent = templateConfiguration.processTemplate(ftlContext, "CNNTrainer.ftl"); if (configData.isSupervisedLearning()) {
fileContentMap.put("CNNTrainer_" + getInstanceName() + ".py", cnnTrainTemplateContent); String cnnTrainTemplateContent = templateConfiguration.processTemplate(ftlContext, "CNNTrainer.ftl");
fileContentMap.put("CNNTrainer_" + getInstanceName() + ".py", cnnTrainTemplateContent);
String cnnSupervisedTrainerContent = templateConfiguration.processTemplate(ftlContext, "CNNSupervisedTrainer.ftl");
fileContentMap.put("supervised_trainer.py", cnnSupervisedTrainerContent);
} else if (configData.isReinforcementLearning()) {
final String trainerName = "CNNTrainer_" + getInstanceName();
ftlContext.put("trainerName", trainerName);
Map<String, String> rlFrameworkContentMap = constructReinforcementLearningFramework(templateConfiguration, ftlContext);
fileContentMap.putAll(rlFrameworkContentMap);
final String reinforcementTrainerContent = templateConfiguration.processTemplate(ftlContext, "reinforcement/Trainer.ftl");
fileContentMap.put(trainerName + ".py", reinforcementTrainerContent);
final String startTrainerScriptContent = templateConfiguration.processTemplate(ftlContext, "reinforcement/StartTrainer.ftl");
fileContentMap.put("start_training.sh", startTrainerScriptContent);
}
return fileContentMap;
}
private void generateRewardFunction(RewardFunctionSymbol rewardFunctionSymbol, Path modelsDirPath) {
GeneratorPythonWrapperStandaloneApi pythonWrapperApi = new GeneratorPythonWrapperStandaloneApi();
List<String> fullNameOfComponent = rewardFunctionSymbol.getRewardFunctionComponentName();
String rewardFunctionRootModel = String.join(".", fullNameOfComponent);
String rewardFunctionOutputPath = Paths.get(this.getGenerationTargetPath(), "reward").toString();
if (!getRootProjectModelsDir().isPresent()) {
setRootProjectModelsDir(modelsDirPath.toString());
}
rewardFunctionSourceGenerator.generate(getRootProjectModelsDir().get(),
rewardFunctionRootModel, rewardFunctionOutputPath);
fixArmadilloEmamGenerationOfFile(Paths.get(rewardFunctionOutputPath, String.join("_", fullNameOfComponent) + ".h"));
String pythonWrapperOutputPath = Paths.get(rewardFunctionOutputPath, "pylib").toString();
Log.info("Generating reward function python wrapper...", "CNNTrain2Gluon");
ComponentPortInformation componentPortInformation;
if (pythonWrapperApi.checkIfPythonModuleBuildAvailable()) {
final String rewardModuleOutput
= Paths.get(getGenerationTargetPath(), REINFORCEMENT_LEARNING_FRAMEWORK_MODULE).toString();
componentPortInformation = pythonWrapperApi.generateAndTryBuilding(getRootProjectModelsDir().get(),
rewardFunctionRootModel, pythonWrapperOutputPath, rewardModuleOutput);
} else {
Log.warn("Cannot build wrapper automatically: OS not supported. Please build manually before starting training.");
componentPortInformation = pythonWrapperApi.generate(getRootProjectModelsDir().get(), rewardFunctionRootModel,
pythonWrapperOutputPath);
}
RewardFunctionParameterAdapter functionParameter = new RewardFunctionParameterAdapter(componentPortInformation);
new FunctionParameterChecker().check(functionParameter);
rewardFunctionSymbol.setRewardFunctionParameter(functionParameter);
}
private void fixArmadilloEmamGenerationOfFile(Path pathToBrokenFile){
final File brokenFile = pathToBrokenFile.toFile();
if (brokenFile.exists()) {
try {
Charset charset = StandardCharsets.UTF_8;
String fileContent = new String(Files.readAllBytes(pathToBrokenFile), charset);
fileContent = fileContent.replace("armadillo.h", "armadillo");
Files.write(pathToBrokenFile, fileContent.getBytes());
} catch (IOException e) {
Log.warn("Cannot fix wrong armadillo library in " + pathToBrokenFile.toString());
}
}
}
private Map<String, String> constructReinforcementLearningFramework(
final TemplateConfiguration templateConfiguration, final Map<String, Object> ftlContext) {
Map<String, String> fileContentMap = Maps.newHashMap();
ftlContext.put("rlFrameworkModule", REINFORCEMENT_LEARNING_FRAMEWORK_MODULE);
final String reinforcementAgentContent = templateConfiguration.processTemplate(ftlContext, "reinforcement/agent/Agent.ftl");
fileContentMap.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/agent.py", reinforcementAgentContent);
final String reinforcementPolicyContent = templateConfiguration.processTemplate(ftlContext, "reinforcement/agent/ActionPolicy.ftl");
fileContentMap.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/action_policy.py", reinforcementPolicyContent);
final String replayMemoryContent = templateConfiguration.processTemplate(ftlContext, "reinforcement/agent/ReplayMemory.ftl");
fileContentMap.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/replay_memory.py", replayMemoryContent);
final String environmentContent = templateConfiguration.processTemplate(ftlContext, "reinforcement/environment/Environment.ftl");
fileContentMap.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/environment.py", environmentContent);
final String utilContent = templateConfiguration.processTemplate(ftlContext, "reinforcement/util/Util.ftl");
fileContentMap.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/util.py", utilContent);
String cnnSupervisedTrainerContent = templateConfiguration.processTemplate(ftlContext, "CNNSupervisedTrainer.ftl"); final String initContent = "";
fileContentMap.put("supervised_trainer.py", cnnSupervisedTrainerContent); fileContentMap.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/__init__.py", initContent);
return fileContentMap; return fileContentMap;
} }
......
package de.monticore.lang.monticar.cnnarch.gluongenerator;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionParameterAdapter;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.ConfigurationData;
import de.monticore.lang.monticar.cnntrain._symboltable.*;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
/**
*
*/
public class ReinforcementConfigurationData extends ConfigurationData {
private static final String AST_ENTRY_LEARNING_METHOD = "learning_method";
private static final String AST_ENTRY_NUM_EPISODES = "num_episodes";
private static final String AST_ENTRY_DISCOUNT_FACTOR = "discount_factor";
private static final String AST_ENTRY_NUM_MAX_STEPS = "num_max_steps";
private static final String AST_ENTRY_TARGET_SCORE = "target_score";
private static final String AST_ENTRY_TRAINING_INTERVAL = "training_interval";
private static final String AST_ENTRY_USE_FIX_TARGET_NETWORK = "use_fix_target_network";
private static final String AST_ENTRY_TARGET_NETWORK_UPDATE_INTERVAL = "target_network_update_interval";
private static final String AST_ENTRY_SNAPSHOT_INTERVAL = "snapshot_interval";
private static final String AST_ENTRY_AGENT_NAME = "agent_name";
private static final String AST_ENTRY_USE_DOUBLE_DQN = "use_double_dqn";
private static final String AST_ENTRY_LOSS = "loss";
private static final String AST_ENTRY_REPLAY_MEMORY = "replay_memory";
private static final String AST_ENTRY_ACTION_SELECTION = "action_selection";
private static final String AST_ENTRY_ENVIRONMENT = "environment";
public ReinforcementConfigurationData(ConfigurationSymbol configuration, String instanceName) {
super(configuration, instanceName);
}
public Boolean isSupervisedLearning() {
if (configurationContainsKey(AST_ENTRY_LEARNING_METHOD)) {
return retrieveConfigurationEntryValueByKey(AST_ENTRY_LEARNING_METHOD)
.equals(LearningMethod.SUPERVISED);
}
return true;
}
public Boolean isReinforcementLearning() {
return configurationContainsKey(AST_ENTRY_LEARNING_METHOD)
&& retrieveConfigurationEntryValueByKey(AST_ENTRY_LEARNING_METHOD).equals(LearningMethod.REINFORCEMENT);
}
public Integer getNumEpisodes() {
return !configurationContainsKey(AST_ENTRY_NUM_EPISODES)
? null : (Integer)retrieveConfigurationEntryValueByKey(AST_ENTRY_NUM_EPISODES);
}
public Double getDiscountFactor() {
return !configurationContainsKey(AST_ENTRY_DISCOUNT_FACTOR)
? null : (Double)retrieveConfigurationEntryValueByKey(AST_ENTRY_DISCOUNT_FACTOR);
}
public Integer getNumMaxSteps() {
return !configurationContainsKey(AST_ENTRY_NUM_MAX_STEPS)
? null : (Integer)retrieveConfigurationEntryValueByKey(AST_ENTRY_NUM_MAX_STEPS);
}
public Double getTargetScore() {
return !configurationContainsKey(AST_ENTRY_TARGET_SCORE)
? null : (Double)retrieveConfigurationEntryValueByKey(AST_ENTRY_TARGET_SCORE);
}
public Integer getTrainingInterval() {
return !configurationContainsKey(AST_ENTRY_TRAINING_INTERVAL)
? null : (Integer)retrieveConfigurationEntryValueByKey(AST_ENTRY_TRAINING_INTERVAL);
}
public Boolean getUseFixTargetNetwork() {
return !configurationContainsKey(AST_ENTRY_USE_FIX_TARGET_NETWORK)
? null : (Boolean)retrieveConfigurationEntryValueByKey(AST_ENTRY_USE_FIX_TARGET_NETWORK);
}
public Integer getTargetNetworkUpdateInterval() {
return !configurationContainsKey(AST_ENTRY_TARGET_NETWORK_UPDATE_INTERVAL)
? null : (Integer)retrieveConfigurationEntryValueByKey(AST_ENTRY_TARGET_NETWORK_UPDATE_INTERVAL);
}
public Integer getSnapshotInterval() {
return !configurationContainsKey(AST_ENTRY_SNAPSHOT_INTERVAL)
? null : (Integer)retrieveConfigurationEntryValueByKey(AST_ENTRY_SNAPSHOT_INTERVAL);
}
public String getAgentName() {
return !configurationContainsKey(AST_ENTRY_AGENT_NAME)
? null : (String)retrieveConfigurationEntryValueByKey(AST_ENTRY_AGENT_NAME);
}
public Boolean getUseDoubleDqn() {
return !configurationContainsKey(AST_ENTRY_USE_DOUBLE_DQN)
? null : (Boolean)retrieveConfigurationEntryValueByKey(AST_ENTRY_USE_DOUBLE_DQN);
}
public String getLoss() {
return !configurationContainsKey(AST_ENTRY_LOSS)
? null : retrieveConfigurationEntryValueByKey(AST_ENTRY_LOSS).toString();
}
public Map<String, Object> getReplayMemory() {
return getMultiParamEntry(AST_ENTRY_REPLAY_MEMORY, "method");
}
public Map<String, Object> getActionSelection() {
return getMultiParamEntry(AST_ENTRY_ACTION_SELECTION, "method");
}
public Map<String, Object> getEnvironment() {
return getMultiParamEntry(AST_ENTRY_ENVIRONMENT, "environment");
}
public Boolean hasRewardFunction() {
return this.getConfiguration().getRlRewardFunction().isPresent();
}
public String getRewardFunctionName() {
if (!this.getConfiguration().getRlRewardFunction().isPresent()) {
return null;
}
return String.join("_", this.getConfiguration().getRlRewardFunction()
.get().getRewardFunctionComponentName());
}
private Optional<RewardFunctionParameterAdapter> getRlRewardFunctionParameter() {
if (!this.getConfiguration().getRlRewardFunction().isPresent()
|| !this.getConfiguration().getRlRewardFunction().get().getRewardFunctionParameter().isPresent()) {
return Optional.empty();
}
return Optional.ofNullable(
(RewardFunctionParameterAdapter)this.getConfiguration().getRlRewardFunction().get()
.getRewardFunctionParameter().orElse(null));
}
public Map<String, Object> getRewardFunctionStateParameter() {
if (!getRlRewardFunctionParameter().isPresent()
|| !getRlRewardFunctionParameter().get().getInputStateParameterName().isPresent()) {
return null;
}
return getInputParameterWithName(getRlRewardFunctionParameter().get().getInputStateParameterName().get());
}
public Map<String, Object> getRewardFunctionTerminalParameter() {
if (!getRlRewardFunctionParameter().isPresent()
|| !getRlRewardFunctionParameter().get().getInputTerminalParameter().isPresent()) {
return null;
}
return getInputParameterWithName(getRlRewardFunctionParameter().get().getInputTerminalParameter().get());
}
public String getRewardFunctionOutputName() {
if (!getRlRewardFunctionParameter().isPresent()) {
return null;
}
return getRlRewardFunctionParameter().get().getOutputParameterName().orElse(null);
}
private Map<String, Object> getMultiParamEntry(final String key, final String valueName) {
if (!configurationContainsKey(key)) {
return null;
}
Map<String, Object> resultView = new HashMap<>();
MultiParamValueSymbol multiParamValue = (MultiParamValueSymbol)this.getConfiguration().getEntryMap()
.get(key).getValue();
resultView.put(valueName, multiParamValue.getValue());
resultView.putAll(multiParamValue.getParameters());
return resultView;
}
private Boolean configurationContainsKey(final String key) {
return this.getConfiguration().getEntryMap().containsKey(key);
}
private Object retrieveConfigurationEntryValueByKey(final String key) {
return this.getConfiguration().getEntry(key).getValue().getValue();
}
private Map<String, Object> getInputParameterWithName(final String parameterName) {
if (!getRlRewardFunctionParameter().isPresent()
|| !getRlRewardFunctionParameter().get().getTypeOfInputPort(parameterName).isPresent()
|| !getRlRewardFunctionParameter().get().getInputPortDimensionOfPort(parameterName).isPresent()) {
return null;
}
Map<String, Object> functionStateParameter = new HashMap<>();;
final String portType = getRlRewardFunctionParameter().get().getTypeOfInputPort(parameterName).get();
final List<Integer> dimension = getRlRewardFunctionParameter().get().getInputPortDimensionOfPort(parameterName).get();
String dtype = null;