Commit f9465a60 authored by Nicola Gatto's avatar Nicola Gatto Committed by Evgeny Kusmenko

Implement reinforcement learning

parent 2dd5aae6
......@@ -8,3 +8,4 @@ nppBackup
*.iml
.vscode
......@@ -27,7 +27,7 @@ masterJobLinux:
stage: linux
image: maven:3-jdk-8
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean deploy --settings settings.xml
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean deploy --settings settings.xml -Dtest=\!Integration*
- cat target/site/jacoco/index.html
- mvn package sonar:sonar -s settings.xml
only:
......@@ -36,7 +36,7 @@ masterJobLinux:
masterJobWindows:
stage: windows
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml -Dtest=\!Integration*
tags:
- Windows10
......@@ -44,7 +44,13 @@ BranchJobLinux:
stage: linux
image: maven:3-jdk-8
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml -Dtest=\!Integration*
- cat target/site/jacoco/index.html
except:
- master
PythonWrapperIntegrationTest:
stage: linux
image: registry.git.rwth-aachen.de/monticore/embeddedmontiarc/generators/emadl2pythonwrapper/tests/mvn-swig:latest
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml -Dtest=IntegrationPythonWrapperTest
......@@ -8,7 +8,7 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-gluon-generator</artifactId>
<version>0.1.6</version>
<version>0.2.0-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
......@@ -16,9 +16,10 @@
<!-- .. SE-Libraries .................................................. -->
<CNNArch.version>0.3.0-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.2.6</CNNTrain.version>
<CNNTrain.version>0.3.0-SNAPSHOT</CNNTrain.version>
<CNNArch2MXNet.version>0.2.14-SNAPSHOT</CNNArch2MXNet.version>
<embedded-montiarc-math-opt-generator>0.1.4</embedded-montiarc-math-opt-generator>
<EMADL2PythonWrapper.version>0.0.1</EMADL2PythonWrapper.version>
<!-- .. Libraries .................................................. -->
<guava.version>18.0</guava.version>
......@@ -100,6 +101,12 @@
<version>${embedded-montiarc-math-opt-generator}</version>
</dependency>
<dependency>
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>embedded-montiarc-emadl-pythonwrapper-generator</artifactId>
<version>${EMADL2PythonWrapper.version}</version>
</dependency>
<!-- .. Test Libraries ............................................... -->
<dependency>
......@@ -109,6 +116,13 @@
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<version>1.10.19</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
......
......@@ -47,8 +47,10 @@ public class CNNArch2Gluon extends CNNArch2MxNet {
temp = archTc.process("CNNNet", Target.PYTHON);
fileContentMap.put(temp.getKey(), temp.getValue());
temp = archTc.process("CNNDataLoader", Target.PYTHON);
fileContentMap.put(temp.getKey(), temp.getValue());
if (architecture.getDataPath() != null) {
temp = archTc.process("CNNDataLoader", Target.PYTHON);
fileContentMap.put(temp.getKey(), temp.getValue());
}
temp = archTc.process("CNNCreator", Target.PYTHON);
fileContentMap.put(temp.getKey(), temp.getValue());
......
package de.monticore.lang.monticar.cnnarch.gluongenerator;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionParameterAdapter;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.ConfigurationData;
import de.monticore.lang.monticar.cnntrain._symboltable.*;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
/**
*
*/
public class ReinforcementConfigurationData extends ConfigurationData {
private static final String AST_ENTRY_LEARNING_METHOD = "learning_method";
private static final String AST_ENTRY_NUM_EPISODES = "num_episodes";
private static final String AST_ENTRY_DISCOUNT_FACTOR = "discount_factor";
private static final String AST_ENTRY_NUM_MAX_STEPS = "num_max_steps";
private static final String AST_ENTRY_TARGET_SCORE = "target_score";
private static final String AST_ENTRY_TRAINING_INTERVAL = "training_interval";
private static final String AST_ENTRY_USE_FIX_TARGET_NETWORK = "use_fix_target_network";
private static final String AST_ENTRY_TARGET_NETWORK_UPDATE_INTERVAL = "target_network_update_interval";
private static final String AST_ENTRY_SNAPSHOT_INTERVAL = "snapshot_interval";
private static final String AST_ENTRY_AGENT_NAME = "agent_name";
private static final String AST_ENTRY_USE_DOUBLE_DQN = "use_double_dqn";
private static final String AST_ENTRY_LOSS = "loss";
private static final String AST_ENTRY_REPLAY_MEMORY = "replay_memory";
private static final String AST_ENTRY_ACTION_SELECTION = "action_selection";
private static final String AST_ENTRY_ENVIRONMENT = "environment";
public ReinforcementConfigurationData(ConfigurationSymbol configuration, String instanceName) {
super(configuration, instanceName);
}
public Boolean isSupervisedLearning() {
if (configurationContainsKey(AST_ENTRY_LEARNING_METHOD)) {
return retrieveConfigurationEntryValueByKey(AST_ENTRY_LEARNING_METHOD)
.equals(LearningMethod.SUPERVISED);
}
return true;
}
public Boolean isReinforcementLearning() {
return configurationContainsKey(AST_ENTRY_LEARNING_METHOD)
&& retrieveConfigurationEntryValueByKey(AST_ENTRY_LEARNING_METHOD).equals(LearningMethod.REINFORCEMENT);
}
public Integer getNumEpisodes() {
return !configurationContainsKey(AST_ENTRY_NUM_EPISODES)
? null : (Integer)retrieveConfigurationEntryValueByKey(AST_ENTRY_NUM_EPISODES);
}
public Double getDiscountFactor() {
return !configurationContainsKey(AST_ENTRY_DISCOUNT_FACTOR)
? null : (Double)retrieveConfigurationEntryValueByKey(AST_ENTRY_DISCOUNT_FACTOR);
}
public Integer getNumMaxSteps() {
return !configurationContainsKey(AST_ENTRY_NUM_MAX_STEPS)
? null : (Integer)retrieveConfigurationEntryValueByKey(AST_ENTRY_NUM_MAX_STEPS);
}
public Double getTargetScore() {
return !configurationContainsKey(AST_ENTRY_TARGET_SCORE)
? null : (Double)retrieveConfigurationEntryValueByKey(AST_ENTRY_TARGET_SCORE);
}
public Integer getTrainingInterval() {
return !configurationContainsKey(AST_ENTRY_TRAINING_INTERVAL)
? null : (Integer)retrieveConfigurationEntryValueByKey(AST_ENTRY_TRAINING_INTERVAL);
}
public Boolean getUseFixTargetNetwork() {
return !configurationContainsKey(AST_ENTRY_USE_FIX_TARGET_NETWORK)
? null : (Boolean)retrieveConfigurationEntryValueByKey(AST_ENTRY_USE_FIX_TARGET_NETWORK);
}
public Integer getTargetNetworkUpdateInterval() {
return !configurationContainsKey(AST_ENTRY_TARGET_NETWORK_UPDATE_INTERVAL)
? null : (Integer)retrieveConfigurationEntryValueByKey(AST_ENTRY_TARGET_NETWORK_UPDATE_INTERVAL);
}
public Integer getSnapshotInterval() {
return !configurationContainsKey(AST_ENTRY_SNAPSHOT_INTERVAL)
? null : (Integer)retrieveConfigurationEntryValueByKey(AST_ENTRY_SNAPSHOT_INTERVAL);
}
public String getAgentName() {
return !configurationContainsKey(AST_ENTRY_AGENT_NAME)
? null : (String)retrieveConfigurationEntryValueByKey(AST_ENTRY_AGENT_NAME);
}
public Boolean getUseDoubleDqn() {
return !configurationContainsKey(AST_ENTRY_USE_DOUBLE_DQN)
? null : (Boolean)retrieveConfigurationEntryValueByKey(AST_ENTRY_USE_DOUBLE_DQN);
}
public String getLoss() {
return !configurationContainsKey(AST_ENTRY_LOSS)
? null : retrieveConfigurationEntryValueByKey(AST_ENTRY_LOSS).toString();
}
public Map<String, Object> getReplayMemory() {
return getMultiParamEntry(AST_ENTRY_REPLAY_MEMORY, "method");
}
public Map<String, Object> getActionSelection() {
return getMultiParamEntry(AST_ENTRY_ACTION_SELECTION, "method");
}
public Map<String, Object> getEnvironment() {
return getMultiParamEntry(AST_ENTRY_ENVIRONMENT, "environment");
}
public Boolean hasRewardFunction() {
return this.getConfiguration().getRlRewardFunction().isPresent();
}
public String getRewardFunctionName() {
if (!this.getConfiguration().getRlRewardFunction().isPresent()) {
return null;
}
return String.join("_", this.getConfiguration().getRlRewardFunction()
.get().getRewardFunctionComponentName());
}
private Optional<RewardFunctionParameterAdapter> getRlRewardFunctionParameter() {
if (!this.getConfiguration().getRlRewardFunction().isPresent()
|| !this.getConfiguration().getRlRewardFunction().get().getRewardFunctionParameter().isPresent()) {
return Optional.empty();
}
return Optional.ofNullable(
(RewardFunctionParameterAdapter)this.getConfiguration().getRlRewardFunction().get()
.getRewardFunctionParameter().orElse(null));
}
public Map<String, Object> getRewardFunctionStateParameter() {
if (!getRlRewardFunctionParameter().isPresent()
|| !getRlRewardFunctionParameter().get().getInputStateParameterName().isPresent()) {
return null;
}
return getInputParameterWithName(getRlRewardFunctionParameter().get().getInputStateParameterName().get());
}
public Map<String, Object> getRewardFunctionTerminalParameter() {
if (!getRlRewardFunctionParameter().isPresent()
|| !getRlRewardFunctionParameter().get().getInputTerminalParameter().isPresent()) {
return null;
}
return getInputParameterWithName(getRlRewardFunctionParameter().get().getInputTerminalParameter().get());
}
public String getRewardFunctionOutputName() {
if (!getRlRewardFunctionParameter().isPresent()) {
return null;
}
return getRlRewardFunctionParameter().get().getOutputParameterName().orElse(null);
}
private Map<String, Object> getMultiParamEntry(final String key, final String valueName) {
if (!configurationContainsKey(key)) {
return null;
}
Map<String, Object> resultView = new HashMap<>();
MultiParamValueSymbol multiParamValue = (MultiParamValueSymbol)this.getConfiguration().getEntryMap()
.get(key).getValue();
resultView.put(valueName, multiParamValue.getValue());
resultView.putAll(multiParamValue.getParameters());
return resultView;
}
private Boolean configurationContainsKey(final String key) {
return this.getConfiguration().getEntryMap().containsKey(key);
}
private Object retrieveConfigurationEntryValueByKey(final String key) {
return this.getConfiguration().getEntry(key).getValue().getValue();
}
private Map<String, Object> getInputParameterWithName(final String parameterName) {
if (!getRlRewardFunctionParameter().isPresent()
|| !getRlRewardFunctionParameter().get().getTypeOfInputPort(parameterName).isPresent()
|| !getRlRewardFunctionParameter().get().getInputPortDimensionOfPort(parameterName).isPresent()) {
return null;
}
Map<String, Object> functionStateParameter = new HashMap<>();;
final String portType = getRlRewardFunctionParameter().get().getTypeOfInputPort(parameterName).get();
final List<Integer> dimension = getRlRewardFunctionParameter().get().getInputPortDimensionOfPort(parameterName).get();
String dtype = null;
if (portType.equals("Q")) {
dtype = "double";
} else if (portType.equals("Z")) {
dtype = "int";
} else if (portType.equals("B")) {
dtype = "bool";
}
Boolean isMultiDimensional = dimension.size() > 1
|| (dimension.size() == 1 && dimension.get(0) > 1);
functionStateParameter.put("name", parameterName);
functionStateParameter.put("dtype", dtype);
functionStateParameter.put("isMultiDimensional", isMultiDimensional);
return functionStateParameter;
}
}
package de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement;
import de.se_rwth.commons.logging.Log;
/**
*
*/
public class FunctionParameterChecker {
private String inputStateParameterName;
private String inputTerminalParameterName;
private String outputParameterName;
private RewardFunctionParameterAdapter rewardFunctionParameter;
public FunctionParameterChecker() {
}
public void check(final RewardFunctionParameterAdapter rewardFunctionParameter) {
this.rewardFunctionParameter = rewardFunctionParameter;
retrieveParameterNames();
checkHasExactlyTwoInputs();
checkHasExactlyOneOutput();
checkHasStateAndTerminalInput();
checkInputStateDimension();
checkInputTerminalTypeAndDimension();
checkOutputDimension();
}
private void checkHasExactlyTwoInputs() {
failIfConditionFails(functionHasTwoInputs(), "Reward function must have exactly two input parameters: "
+ "One input needs to represents the environment's state and another input needs to be a "
+ "boolean value which expresses whether the environment's state is terminal or not");
}
private void checkHasExactlyOneOutput() {
failIfConditionFails(functionHasOneOutput(), "Reward function must have exactly one output");
}
private void checkHasStateAndTerminalInput() {
failIfConditionFails(inputParametersArePresent(),
"Reward function must have exactly two input parameters: "
+"One input needs to represents the environment's state as a numerical scalar, vector or matrice, "
+ "and another input needs to be a "
+ "boolean value which expresses whether the environment's state is terminal or not");
}
private void checkInputStateDimension() {
failIfConditionFails(isInputStateParameterDimensionBetweenOneAndThree(),
"Reward function state parameter with dimension higher than three is not supported");
}
private void checkInputTerminalTypeAndDimension() {
failIfConditionFails(inputTerminalIsBooleanScalar(), "Reward functions needs a terminal input which"
+ " is a boolean scalar");
}
private void checkOutputDimension() {
failIfConditionFails(outputParameterIsScalar(), "Reward function output must be a scalar");
}
private void retrieveParameterNames() {
this.inputStateParameterName = rewardFunctionParameter.getInputStateParameterName().orElse(null);
this.inputTerminalParameterName = rewardFunctionParameter.getInputTerminalParameter().orElse(null);
this.outputParameterName = rewardFunctionParameter.getOutputParameterName().orElse(null);
}
private boolean inputParametersArePresent() {
return rewardFunctionParameter.getInputStateParameterName().isPresent()
&& rewardFunctionParameter.getInputTerminalParameter().isPresent();
}
private boolean functionHasOneOutput() {
return rewardFunctionParameter.getOutputNames().size() == 1;
}
private boolean functionHasTwoInputs() {
return rewardFunctionParameter.getInputNames().size() == 2;
}
private boolean isInputStateParameterDimensionBetweenOneAndThree() {
return (rewardFunctionParameter.getInputPortDimensionOfPort(inputStateParameterName).isPresent())
&& (rewardFunctionParameter.getInputPortDimensionOfPort(inputStateParameterName).get().size() <= 3)
&& (rewardFunctionParameter.getInputPortDimensionOfPort(inputStateParameterName).get().size() > 0);
}
private boolean outputParameterIsScalar() {
return (rewardFunctionParameter.getOutputPortDimensionOfPort(outputParameterName).isPresent())
&& (rewardFunctionParameter.getOutputPortDimensionOfPort(outputParameterName).get().size() == 1)
&& (rewardFunctionParameter.getOutputPortDimensionOfPort(outputParameterName).get().get(0) == 1);
}
private boolean inputTerminalIsBooleanScalar() {
return (rewardFunctionParameter.getInputPortDimensionOfPort(inputTerminalParameterName).isPresent())
&& (rewardFunctionParameter.getTypeOfInputPort(inputTerminalParameterName).isPresent())
&& (rewardFunctionParameter.getInputPortDimensionOfPort(inputTerminalParameterName).get().size() == 1)
&& (rewardFunctionParameter.getInputPortDimensionOfPort(inputTerminalParameterName).get().get(0) == 1)
&& (rewardFunctionParameter.getTypeOfInputPort(inputTerminalParameterName).get().equals("B"));
}
private void failIfConditionFails(final boolean condition, final String message) {
if (!condition) {
fail(message);
}
}
private void fail(final String message) {
Log.error(message);
//System.exit(-1);
}
}
\ No newline at end of file
package de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement;
import de.monticore.lang.monticar.cnntrain.annotations.RewardFunctionParameter;
import de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.ComponentPortInformation;
import de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.EmadlType;
import de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.PortVariable;
import jdk.nashorn.internal.runtime.options.Option;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
/**
*
*/
public class RewardFunctionParameterAdapter implements RewardFunctionParameter {
private final ComponentPortInformation adaptee;
private String outputParameterName;
private String inputStateParameterName;
private String inputTerminalParameterName;
public RewardFunctionParameterAdapter(final ComponentPortInformation componentPortInformation) {
this.adaptee = componentPortInformation;
}
@Override
public List<String> getInputNames() {
return this.adaptee.getAllInputs().stream()
.map(PortVariable::getVariableName)
.collect(Collectors.toList());
}
@Override
public List<String> getOutputNames() {
return this.adaptee.getAllOutputs().stream()
.map(PortVariable::getVariableName)
.collect(Collectors.toList());
}
@Override
public Optional<String> getTypeOfInputPort(String portName) {
return this.adaptee.getAllInputs().stream()
.filter(port -> port.getVariableName().equals(portName))
.map(port -> port.getEmadlType().toString())
.findFirst();
}
@Override
public Optional<String> getTypeOfOutputPort(String portName) {
return this.adaptee.getAllOutputs().stream()
.filter(port -> port.getVariableName().equals(portName))
.map(port -> port.getEmadlType().toString())
.findFirst();
}
@Override
public Optional<List<Integer>> getInputPortDimensionOfPort(String portName) {
return this.adaptee.getAllInputs().stream()
.filter(port -> port.getVariableName().equals(portName))
.map(PortVariable::getDimension)
.findFirst();
}
@Override
public Optional<List<Integer>> getOutputPortDimensionOfPort(String portName) {
return this.adaptee.getAllOutputs().stream()
.filter(port -> port.getVariableName().equals(portName))
.map(PortVariable::getDimension)
.findFirst();
}
public Optional<String> getOutputParameterName() {
if (this.outputParameterName == null) {
if (this.getOutputNames().size() == 1) {
this.outputParameterName = this.getOutputNames().get(0);
} else {
return Optional.empty();
}
}
return Optional.of(this.outputParameterName);
}
private boolean isBooleanScalar(final PortVariable portVariable) {
return portVariable.getEmadlType().equals(EmadlType.B)
&& portVariable.getDimension().size() == 1
&& portVariable.getDimension().get(0) == 1;
}
private boolean determineInputNames() {
if (this.getInputNames().size() != 2) {
return false;
}
Optional<String> terminalInput = this.adaptee.getAllInputs()
.stream()
.filter(this::isBooleanScalar)
.map(PortVariable::getVariableName)
.findFirst();
if (terminalInput.isPresent()) {
this.inputTerminalParameterName = terminalInput.get();
} else {
return false;
}
Optional<String> stateInput = this.adaptee.getAllInputs().stream()
.filter(portVariable -> !portVariable.getVariableName().equals(this.inputTerminalParameterName))
.filter(portVariable -> !isBooleanScalar(portVariable))
.map(PortVariable::getVariableName)
.findFirst();
if (stateInput.isPresent()) {
this.inputStateParameterName = stateInput.get();
} else {
this.inputTerminalParameterName = null;
return false;
}
return true;
}
public Optional<String> getInputStateParameterName() {
if (this.inputStateParameterName == null) {
this.determineInputNames();
}
return Optional.ofNullable(this.inputStateParameterName);
}
public Optional<String> getInputTerminalParameter() {
if (this.inputTerminalParameterName == null) {
this.determineInputNames();
}
return Optional.ofNullable(this.inputTerminalParameterName);
}