diff --git a/.gitignore b/.gitignore
index 11780027328480bc4bac278fc1025f1fac376079..96c56c406e286b2d7a0200a7af2ebe5e69f934f5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,3 +8,4 @@ nppBackup
*.iml
+.vscode
diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index a45cd7c87301d8847cb1237464de25df85737f9b..575941b3070abd976c1470ec6e41e6523befbfb6 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -27,7 +27,7 @@ masterJobLinux:
stage: linux
image: maven:3-jdk-8
script:
- - mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean deploy --settings settings.xml
+ - mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean deploy --settings settings.xml -Dtest=\!Integration*
- cat target/site/jacoco/index.html
- mvn package sonar:sonar -s settings.xml
only:
@@ -36,7 +36,7 @@ masterJobLinux:
masterJobWindows:
stage: windows
script:
- - mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml
+ - mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml -Dtest=\!Integration*
tags:
- Windows10
@@ -44,7 +44,13 @@ BranchJobLinux:
stage: linux
image: maven:3-jdk-8
script:
- - mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml
+ - mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml -Dtest=\!Integration*
- cat target/site/jacoco/index.html
except:
- master
+
+PythonWrapperIntegrationTest:
+ stage: linux
+ image: registry.git.rwth-aachen.de/monticore/embeddedmontiarc/generators/emadl2pythonwrapper/tests/mvn-swig:latest
+ script:
+ - mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml -Dtest=IntegrationPythonWrapperTest
diff --git a/pom.xml b/pom.xml
index 0cb7bc93763d04b91904ba98896ac850dac47bb1..0f5f083578fd22a4f4d4773825583e7a3716390b 100644
--- a/pom.xml
+++ b/pom.xml
@@ -8,7 +8,7 @@
  <groupId>de.monticore.lang.monticar</groupId>
  <artifactId>cnnarch-gluon-generator</artifactId>
- <version>0.1.6</version>
+ <version>0.2.0-SNAPSHOT</version>
@@ -16,9 +16,10 @@
0.3.0-SNAPSHOT
- 0.2.6
+ 0.3.0-SNAPSHOT
0.2.14-SNAPSHOT
0.1.4
+ <EMADL2PythonWrapper.version>0.0.1</EMADL2PythonWrapper.version>
18.0
@@ -100,6 +101,12 @@
            <version>${embedded-montiarc-math-opt-generator}</version>
+        <dependency>
+            <groupId>de.monticore.lang.monticar</groupId>
+            <artifactId>embedded-montiarc-emadl-pythonwrapper-generator</artifactId>
+            <version>${EMADL2PythonWrapper.version}</version>
+        </dependency>
+
@@ -109,6 +116,13 @@
            <scope>test</scope>
+        <dependency>
+            <groupId>org.mockito</groupId>
+            <artifactId>mockito-core</artifactId>
+            <version>1.10.19</version>
+            <scope>test</scope>
+        </dependency>
+
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2Gluon.java b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2Gluon.java
index e0a185a91175095e553f26bd472bf5a3a2c23739..2a6106fb7dc6fc47241256ccfa5da786016bd928 100644
--- a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2Gluon.java
+++ b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2Gluon.java
@@ -47,8 +47,10 @@ public class CNNArch2Gluon extends CNNArch2MxNet {
temp = archTc.process("CNNNet", Target.PYTHON);
fileContentMap.put(temp.getKey(), temp.getValue());
- temp = archTc.process("CNNDataLoader", Target.PYTHON);
- fileContentMap.put(temp.getKey(), temp.getValue());
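+ // Generate the data loader only if the architecture specifies a data path.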
+ if (architecture.getDataPath() != null) {
+ temp = archTc.process("CNNDataLoader", Target.PYTHON);
+ fileContentMap.put(temp.getKey(), temp.getValue());
+ }
temp = archTc.process("CNNCreator", Target.PYTHON);
fileContentMap.put(temp.getKey(), temp.getValue());
diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNTrain2Gluon.java b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNTrain2Gluon.java
index a59beb14d9f737cf1b5fcaa6f6e50c6672e9fd66..980b517a157232ee85afb8309e1eadff4b1275f6 100644
--- a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNTrain2Gluon.java
+++ b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNTrain2Gluon.java
@@ -1,33 +1,187 @@
package de.monticore.lang.monticar.cnnarch.gluongenerator;
+import com.google.common.collect.Maps;
+import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.FunctionParameterChecker;
+import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionParameterAdapter;
+import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionSourceGenerator;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.ConfigurationData;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.CNNTrain2MxNet;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.TemplateConfiguration;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
+import de.monticore.lang.monticar.cnntrain._symboltable.LearningMethod;
+import de.monticore.lang.monticar.cnntrain._symboltable.RewardFunctionSymbol;
+import de.monticore.lang.monticar.generator.FileContent;
+import de.monticore.lang.monticar.generator.cpp.GeneratorCPP;
+import de.monticore.lang.monticar.generator.pythonwrapper.GeneratorPythonWrapperStandaloneApi;
+import de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.ComponentPortInformation;
+import de.se_rwth.commons.logging.Log;
+import java.io.File;
+import java.io.IOException;
+import java.nio.charset.Charset;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
import java.util.*;
public class CNNTrain2Gluon extends CNNTrain2MxNet {
- public CNNTrain2Gluon() {
+ private static final String REINFORCEMENT_LEARNING_FRAMEWORK_MODULE = "reinforcement_learning";
+
+ private final RewardFunctionSourceGenerator rewardFunctionSourceGenerator;
+ private String rootProjectModelsDir;
+
+ public Optional<String> getRootProjectModelsDir() {
+ return Optional.ofNullable(rootProjectModelsDir);
+ }
+
+ public void setRootProjectModelsDir(String rootProjectModelsDir) {
+ this.rootProjectModelsDir = rootProjectModelsDir;
+ }
+
+ public CNNTrain2Gluon(RewardFunctionSourceGenerator rewardFunctionSourceGenerator) {
super();
+ this.rewardFunctionSourceGenerator = rewardFunctionSourceGenerator;
+ }
+
+ @Override
+ public ConfigurationSymbol getConfigurationSymbol(Path modelsDirPath, String rootModelName) {
+ ConfigurationSymbol configurationSymbol = super.getConfigurationSymbol(modelsDirPath, rootModelName);
+
+ // Generate Reward function if necessary
+ if (configurationSymbol.getLearningMethod().equals(LearningMethod.REINFORCEMENT)
+ && configurationSymbol.getRlRewardFunction().isPresent()) {
+ generateRewardFunction(configurationSymbol.getRlRewardFunction().get(), modelsDirPath);
+ }
+
+ return configurationSymbol;
+ }
+
+ @Override
+ public void generate(Path modelsDirPath, String rootModelName) {
+ ConfigurationSymbol configuration = this.getConfigurationSymbol(modelsDirPath, rootModelName);
+
+ Map<String, String> fileContents = this.generateStrings(configuration);
+ GeneratorCPP genCPP = new GeneratorCPP();
+ genCPP.setGenerationTargetPath(this.getGenerationTargetPath());
+
+ try {
+ for (Map.Entry<String, String> fileContent : fileContents.entrySet()) {
+ genCPP.generateFile(new FileContent(fileContent.getValue(), fileContent.getKey()));
+ }
+ } catch (IOException e) {
+ Log.error("CNNTrainer file could not be generated: " + e.getMessage());
+ }
}
@Override
public Map<String, String> generateStrings(ConfigurationSymbol configuration) {
TemplateConfiguration templateConfiguration = new GluonTemplateConfiguration();
- ConfigurationData configData = new ConfigurationData(configuration, getInstanceName());
+ ReinforcementConfigurationData configData = new ReinforcementConfigurationData(configuration, getInstanceName());
List<ConfigurationData> configDataList = new ArrayList<>();
configDataList.add(configData);
- Map<String, Object> ftlContext = Collections.singletonMap("configurations", configDataList);
+
+ Map<String, Object> ftlContext = Maps.newHashMap();
+ ftlContext.put("configurations", configDataList);
Map<String, String> fileContentMap = new HashMap<>();
- String cnnTrainTemplateContent = templateConfiguration.processTemplate(ftlContext, "CNNTrainer.ftl");
- fileContentMap.put("CNNTrainer_" + getInstanceName() + ".py", cnnTrainTemplateContent);
+ if (configData.isSupervisedLearning()) {
+ String cnnTrainTemplateContent = templateConfiguration.processTemplate(ftlContext, "CNNTrainer.ftl");
+ fileContentMap.put("CNNTrainer_" + getInstanceName() + ".py", cnnTrainTemplateContent);
+
+ String cnnSupervisedTrainerContent = templateConfiguration.processTemplate(ftlContext, "CNNSupervisedTrainer.ftl");
+ fileContentMap.put("supervised_trainer.py", cnnSupervisedTrainerContent);
+ } else if (configData.isReinforcementLearning()) {
+ final String trainerName = "CNNTrainer_" + getInstanceName();
+ ftlContext.put("trainerName", trainerName);
+ Map<String, String> rlFrameworkContentMap = constructReinforcementLearningFramework(templateConfiguration, ftlContext);
+ fileContentMap.putAll(rlFrameworkContentMap);
+
+ final String reinforcementTrainerContent = templateConfiguration.processTemplate(ftlContext, "reinforcement/Trainer.ftl");
+ fileContentMap.put(trainerName + ".py", reinforcementTrainerContent);
+
+ final String startTrainerScriptContent = templateConfiguration.processTemplate(ftlContext, "reinforcement/StartTrainer.ftl");
+ fileContentMap.put("start_training.sh", startTrainerScriptContent);
+ }
+ return fileContentMap;
+ }
+
+ private void generateRewardFunction(RewardFunctionSymbol rewardFunctionSymbol, Path modelsDirPath) {
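+ // Generate C++ code for the EMAM reward function component and wrap it in a Python module so the trainer can call it during reinforcement learning.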
+ GeneratorPythonWrapperStandaloneApi pythonWrapperApi = new GeneratorPythonWrapperStandaloneApi();
+
+ List<String> fullNameOfComponent = rewardFunctionSymbol.getRewardFunctionComponentName();
+
+ String rewardFunctionRootModel = String.join(".", fullNameOfComponent);
+ String rewardFunctionOutputPath = Paths.get(this.getGenerationTargetPath(), "reward").toString();
+
+ if (!getRootProjectModelsDir().isPresent()) {
+ setRootProjectModelsDir(modelsDirPath.toString());
+ }
+
+ rewardFunctionSourceGenerator.generate(getRootProjectModelsDir().get(),
+ rewardFunctionRootModel, rewardFunctionOutputPath);
+ fixArmadilloEmamGenerationOfFile(Paths.get(rewardFunctionOutputPath, String.join("_", fullNameOfComponent) + ".h"));
+
+ String pythonWrapperOutputPath = Paths.get(rewardFunctionOutputPath, "pylib").toString();
+
+ Log.info("Generating reward function python wrapper...", "CNNTrain2Gluon");
+ ComponentPortInformation componentPortInformation;
+ if (pythonWrapperApi.checkIfPythonModuleBuildAvailable()) {
+ final String rewardModuleOutput
+ = Paths.get(getGenerationTargetPath(), REINFORCEMENT_LEARNING_FRAMEWORK_MODULE).toString();
+ componentPortInformation = pythonWrapperApi.generateAndTryBuilding(getRootProjectModelsDir().get(),
+ rewardFunctionRootModel, pythonWrapperOutputPath, rewardModuleOutput);
+ } else {
+ Log.warn("Cannot build wrapper automatically: OS not supported. Please build manually before starting training.");
+ componentPortInformation = pythonWrapperApi.generate(getRootProjectModelsDir().get(), rewardFunctionRootModel,
+ pythonWrapperOutputPath);
+ }
+ RewardFunctionParameterAdapter functionParameter = new RewardFunctionParameterAdapter(componentPortInformation);
+ new FunctionParameterChecker().check(functionParameter);
+ rewardFunctionSymbol.setRewardFunctionParameter(functionParameter);
+ }
+
+ private void fixArmadilloEmamGenerationOfFile(Path pathToBrokenFile){
+ final File brokenFile = pathToBrokenFile.toFile();
+ if (brokenFile.exists()) {
+ try {
+ Charset charset = StandardCharsets.UTF_8;
+ String fileContent = new String(Files.readAllBytes(pathToBrokenFile), charset);
+ fileContent = fileContent.replace("armadillo.h", "armadillo");
+ Files.write(pathToBrokenFile, fileContent.getBytes());
+ } catch (IOException e) {
+ Log.warn("Cannot fix wrong armadillo library in " + pathToBrokenFile.toString());
+ }
+ }
+ }
+
+ private Map<String, String> constructReinforcementLearningFramework(
+ final TemplateConfiguration templateConfiguration, final Map<String, Object> ftlContext) {
+ Map<String, String> fileContentMap = Maps.newHashMap();
+ ftlContext.put("rlFrameworkModule", REINFORCEMENT_LEARNING_FRAMEWORK_MODULE);
+
+ final String reinforcementAgentContent = templateConfiguration.processTemplate(ftlContext, "reinforcement/agent/Agent.ftl");
+ fileContentMap.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/agent.py", reinforcementAgentContent);
+
+ final String reinforcementPolicyContent = templateConfiguration.processTemplate(ftlContext, "reinforcement/agent/ActionPolicy.ftl");
+ fileContentMap.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/action_policy.py", reinforcementPolicyContent);
+
+ final String replayMemoryContent = templateConfiguration.processTemplate(ftlContext, "reinforcement/agent/ReplayMemory.ftl");
+ fileContentMap.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/replay_memory.py", replayMemoryContent);
+
+ final String environmentContent = templateConfiguration.processTemplate(ftlContext, "reinforcement/environment/Environment.ftl");
+ fileContentMap.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/environment.py", environmentContent);
+
+ final String utilContent = templateConfiguration.processTemplate(ftlContext, "reinforcement/util/Util.ftl");
+ fileContentMap.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/util.py", utilContent);
- String cnnSupervisedTrainerContent = templateConfiguration.processTemplate(ftlContext, "CNNSupervisedTrainer.ftl");
- fileContentMap.put("supervised_trainer.py", cnnSupervisedTrainerContent);
+ final String initContent = "";
+ fileContentMap.put(REINFORCEMENT_LEARNING_FRAMEWORK_MODULE + "/__init__.py", initContent);
return fileContentMap;
}
diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/ReinforcementConfigurationData.java b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/ReinforcementConfigurationData.java
new file mode 100644
index 0000000000000000000000000000000000000000..c5991e7d8190f6b5492d5d506bd94295c82cfab9
--- /dev/null
+++ b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/ReinforcementConfigurationData.java
@@ -0,0 +1,216 @@
+package de.monticore.lang.monticar.cnnarch.gluongenerator;
+
+import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionParameterAdapter;
+import de.monticore.lang.monticar.cnnarch.mxnetgenerator.ConfigurationData;
+import de.monticore.lang.monticar.cnntrain._symboltable.*;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+/**
+ * Exposes the reinforcement-learning related entries of a CNNTrain configuration
+ * (episodes, discount factor, replay memory, action selection, environment, reward function)
+ * to the FreeMarker templates.
+ */
+public class ReinforcementConfigurationData extends ConfigurationData {
+ private static final String AST_ENTRY_LEARNING_METHOD = "learning_method";
+ private static final String AST_ENTRY_NUM_EPISODES = "num_episodes";
+ private static final String AST_ENTRY_DISCOUNT_FACTOR = "discount_factor";
+ private static final String AST_ENTRY_NUM_MAX_STEPS = "num_max_steps";
+ private static final String AST_ENTRY_TARGET_SCORE = "target_score";
+ private static final String AST_ENTRY_TRAINING_INTERVAL = "training_interval";
+ private static final String AST_ENTRY_USE_FIX_TARGET_NETWORK = "use_fix_target_network";
+ private static final String AST_ENTRY_TARGET_NETWORK_UPDATE_INTERVAL = "target_network_update_interval";
+ private static final String AST_ENTRY_SNAPSHOT_INTERVAL = "snapshot_interval";
+ private static final String AST_ENTRY_AGENT_NAME = "agent_name";
+ private static final String AST_ENTRY_USE_DOUBLE_DQN = "use_double_dqn";
+ private static final String AST_ENTRY_LOSS = "loss";
+
+ private static final String AST_ENTRY_REPLAY_MEMORY = "replay_memory";
+ private static final String AST_ENTRY_ACTION_SELECTION = "action_selection";
+ private static final String AST_ENTRY_ENVIRONMENT = "environment";
+
+ public ReinforcementConfigurationData(ConfigurationSymbol configuration, String instanceName) {
+ super(configuration, instanceName);
+ }
+
+ public Boolean isSupervisedLearning() {
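+ // Supervised learning is the default when no learning_method entry is present.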
+ if (configurationContainsKey(AST_ENTRY_LEARNING_METHOD)) {
+ return retrieveConfigurationEntryValueByKey(AST_ENTRY_LEARNING_METHOD)
+ .equals(LearningMethod.SUPERVISED);
+ }
+ return true;
+ }
+
+ public Boolean isReinforcementLearning() {
+ return configurationContainsKey(AST_ENTRY_LEARNING_METHOD)
+ && retrieveConfigurationEntryValueByKey(AST_ENTRY_LEARNING_METHOD).equals(LearningMethod.REINFORCEMENT);
+ }
+
+ public Integer getNumEpisodes() {
+ return !configurationContainsKey(AST_ENTRY_NUM_EPISODES)
+ ? null : (Integer)retrieveConfigurationEntryValueByKey(AST_ENTRY_NUM_EPISODES);
+ }
+
+ public Double getDiscountFactor() {
+ return !configurationContainsKey(AST_ENTRY_DISCOUNT_FACTOR)
+ ? null : (Double)retrieveConfigurationEntryValueByKey(AST_ENTRY_DISCOUNT_FACTOR);
+ }
+
+ public Integer getNumMaxSteps() {
+ return !configurationContainsKey(AST_ENTRY_NUM_MAX_STEPS)
+ ? null : (Integer)retrieveConfigurationEntryValueByKey(AST_ENTRY_NUM_MAX_STEPS);
+ }
+
+ public Double getTargetScore() {
+ return !configurationContainsKey(AST_ENTRY_TARGET_SCORE)
+ ? null : (Double)retrieveConfigurationEntryValueByKey(AST_ENTRY_TARGET_SCORE);
+ }
+
+ public Integer getTrainingInterval() {
+ return !configurationContainsKey(AST_ENTRY_TRAINING_INTERVAL)
+ ? null : (Integer)retrieveConfigurationEntryValueByKey(AST_ENTRY_TRAINING_INTERVAL);
+ }
+
+ public Boolean getUseFixTargetNetwork() {
+ return !configurationContainsKey(AST_ENTRY_USE_FIX_TARGET_NETWORK)
+ ? null : (Boolean)retrieveConfigurationEntryValueByKey(AST_ENTRY_USE_FIX_TARGET_NETWORK);
+ }
+
+ public Integer getTargetNetworkUpdateInterval() {
+ return !configurationContainsKey(AST_ENTRY_TARGET_NETWORK_UPDATE_INTERVAL)
+ ? null : (Integer)retrieveConfigurationEntryValueByKey(AST_ENTRY_TARGET_NETWORK_UPDATE_INTERVAL);
+ }
+
+ public Integer getSnapshotInterval() {
+ return !configurationContainsKey(AST_ENTRY_SNAPSHOT_INTERVAL)
+ ? null : (Integer)retrieveConfigurationEntryValueByKey(AST_ENTRY_SNAPSHOT_INTERVAL);
+ }
+
+ public String getAgentName() {
+ return !configurationContainsKey(AST_ENTRY_AGENT_NAME)
+ ? null : (String)retrieveConfigurationEntryValueByKey(AST_ENTRY_AGENT_NAME);
+ }
+
+ public Boolean getUseDoubleDqn() {
+ return !configurationContainsKey(AST_ENTRY_USE_DOUBLE_DQN)
+ ? null : (Boolean)retrieveConfigurationEntryValueByKey(AST_ENTRY_USE_DOUBLE_DQN);
+ }
+
+ public String getLoss() {
+ return !configurationContainsKey(AST_ENTRY_LOSS)
+ ? null : retrieveConfigurationEntryValueByKey(AST_ENTRY_LOSS).toString();
+ }
+
+ public Map<String, Object> getReplayMemory() {
+ return getMultiParamEntry(AST_ENTRY_REPLAY_MEMORY, "method");
+ }
+
+ public Map<String, Object> getActionSelection() {
+ return getMultiParamEntry(AST_ENTRY_ACTION_SELECTION, "method");
+ }
+
+ public Map<String, Object> getEnvironment() {
+ return getMultiParamEntry(AST_ENTRY_ENVIRONMENT, "environment");
+ }
+
+ public Boolean hasRewardFunction() {
+ return this.getConfiguration().getRlRewardFunction().isPresent();
+ }
+
+ public String getRewardFunctionName() {
+ if (!this.getConfiguration().getRlRewardFunction().isPresent()) {
+ return null;
+ }
+ return String.join("_", this.getConfiguration().getRlRewardFunction()
+ .get().getRewardFunctionComponentName());
+ }
+
+ private Optional<RewardFunctionParameterAdapter> getRlRewardFunctionParameter() {
+ if (!this.getConfiguration().getRlRewardFunction().isPresent()
+ || !this.getConfiguration().getRlRewardFunction().get().getRewardFunctionParameter().isPresent()) {
+ return Optional.empty();
+ }
+ return Optional.ofNullable(
+ (RewardFunctionParameterAdapter)this.getConfiguration().getRlRewardFunction().get()
+ .getRewardFunctionParameter().orElse(null));
+ }
+
+ public Map<String, Object> getRewardFunctionStateParameter() {
+ if (!getRlRewardFunctionParameter().isPresent()
+ || !getRlRewardFunctionParameter().get().getInputStateParameterName().isPresent()) {
+ return null;
+ }
+ return getInputParameterWithName(getRlRewardFunctionParameter().get().getInputStateParameterName().get());
+ }
+
+ public Map<String, Object> getRewardFunctionTerminalParameter() {
+ if (!getRlRewardFunctionParameter().isPresent()
+ || !getRlRewardFunctionParameter().get().getInputTerminalParameter().isPresent()) {
+ return null;
+ }
+ return getInputParameterWithName(getRlRewardFunctionParameter().get().getInputTerminalParameter().get());
+ }
+
+ public String getRewardFunctionOutputName() {
+ if (!getRlRewardFunctionParameter().isPresent()) {
+ return null;
+ }
+ return getRlRewardFunctionParameter().get().getOutputParameterName().orElse(null);
+ }
+
+ private Map<String, Object> getMultiParamEntry(final String key, final String valueName) {
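+ // Flattens a multi-parameter entry (e.g. replay_memory, action_selection, environment) into a single map for the templates.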
+ if (!configurationContainsKey(key)) {
+ return null;
+ }
+
+ Map<String, Object> resultView = new HashMap<>();
+
+ MultiParamValueSymbol multiParamValue = (MultiParamValueSymbol)this.getConfiguration().getEntryMap()
+ .get(key).getValue();
+
+ resultView.put(valueName, multiParamValue.getValue());
+ resultView.putAll(multiParamValue.getParameters());
+ return resultView;
+ }
+
+ private Boolean configurationContainsKey(final String key) {
+ return this.getConfiguration().getEntryMap().containsKey(key);
+ }
+
+ private Object retrieveConfigurationEntryValueByKey(final String key) {
+ return this.getConfiguration().getEntry(key).getValue().getValue();
+ }
+
+ private Map<String, Object> getInputParameterWithName(final String parameterName) {
+ if (!getRlRewardFunctionParameter().isPresent()
+ || !getRlRewardFunctionParameter().get().getTypeOfInputPort(parameterName).isPresent()
+ || !getRlRewardFunctionParameter().get().getInputPortDimensionOfPort(parameterName).isPresent()) {
+ return null;
+ }
+
+ Map<String, Object> functionStateParameter = new HashMap<>();
+
+ final String portType = getRlRewardFunctionParameter().get().getTypeOfInputPort(parameterName).get();
+ final List<Integer> dimension = getRlRewardFunctionParameter().get().getInputPortDimensionOfPort(parameterName).get();
+
+ String dtype = null;
+ if (portType.equals("Q")) {
+ dtype = "double";
+ } else if (portType.equals("Z")) {
+ dtype = "int";
+ } else if (portType.equals("B")) {
+ dtype = "bool";
+ }
+
+ Boolean isMultiDimensional = dimension.size() > 1
+ || (dimension.size() == 1 && dimension.get(0) > 1);
+
+
+ functionStateParameter.put("name", parameterName);
+ functionStateParameter.put("dtype", dtype);
+ functionStateParameter.put("isMultiDimensional", isMultiDimensional);
+
+ return functionStateParameter;
+ }
+}
diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/reinforcement/FunctionParameterChecker.java b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/reinforcement/FunctionParameterChecker.java
new file mode 100644
index 0000000000000000000000000000000000000000..7b72bd8f5c5343bcdfe6a878c3382b8054142b0d
--- /dev/null
+++ b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/reinforcement/FunctionParameterChecker.java
@@ -0,0 +1,109 @@
+package de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement;
+
+import de.se_rwth.commons.logging.Log;
+
+/**
+ * Validates that a reward function component has the expected signature:
+ * a state input, a boolean scalar terminal input and a single scalar reward output.
+ */
+public class FunctionParameterChecker {
+ private String inputStateParameterName;
+ private String inputTerminalParameterName;
+ private String outputParameterName;
+ private RewardFunctionParameterAdapter rewardFunctionParameter;
+
+ public FunctionParameterChecker() {
+ }
+
+ public void check(final RewardFunctionParameterAdapter rewardFunctionParameter) {
+ this.rewardFunctionParameter = rewardFunctionParameter;
+ retrieveParameterNames();
+ checkHasExactlyTwoInputs();
+ checkHasExactlyOneOutput();
+ checkHasStateAndTerminalInput();
+ checkInputStateDimension();
+ checkInputTerminalTypeAndDimension();
+ checkOutputDimension();
+ }
+
+ private void checkHasExactlyTwoInputs() {
+ failIfConditionFails(functionHasTwoInputs(), "Reward function must have exactly two input parameters: "
+ + "One input needs to represents the environment's state and another input needs to be a "
+ + "boolean value which expresses whether the environment's state is terminal or not");
+ }
+
+ private void checkHasExactlyOneOutput() {
+ failIfConditionFails(functionHasOneOutput(), "Reward function must have exactly one output");
+ }
+
+ private void checkHasStateAndTerminalInput() {
+ failIfConditionFails(inputParametersArePresent(),
+ "Reward function must have exactly two input parameters: "
+ +"One input needs to represents the environment's state as a numerical scalar, vector or matrice, "
+ + "and another input needs to be a "
+ + "boolean value which expresses whether the environment's state is terminal or not");
+ }
+
+ private void checkInputStateDimension() {
+ failIfConditionFails(isInputStateParameterDimensionBetweenOneAndThree(),
+ "Reward function state parameter with dimension higher than three is not supported");
+ }
+
+ private void checkInputTerminalTypeAndDimension() {
+ failIfConditionFails(inputTerminalIsBooleanScalar(), "Reward function needs a terminal input which"
+ + " is a boolean scalar");
+ }
+
+ private void checkOutputDimension() {
+ failIfConditionFails(outputParameterIsScalar(), "Reward function output must be a scalar");
+ }
+
+ private void retrieveParameterNames() {
+ this.inputStateParameterName = rewardFunctionParameter.getInputStateParameterName().orElse(null);
+ this.inputTerminalParameterName = rewardFunctionParameter.getInputTerminalParameter().orElse(null);
+ this.outputParameterName = rewardFunctionParameter.getOutputParameterName().orElse(null);
+ }
+
+ private boolean inputParametersArePresent() {
+ return rewardFunctionParameter.getInputStateParameterName().isPresent()
+ && rewardFunctionParameter.getInputTerminalParameter().isPresent();
+ }
+
+ private boolean functionHasOneOutput() {
+ return rewardFunctionParameter.getOutputNames().size() == 1;
+ }
+
+ private boolean functionHasTwoInputs() {
+ return rewardFunctionParameter.getInputNames().size() == 2;
+ }
+
+ private boolean isInputStateParameterDimensionBetweenOneAndThree() {
+ return (rewardFunctionParameter.getInputPortDimensionOfPort(inputStateParameterName).isPresent())
+ && (rewardFunctionParameter.getInputPortDimensionOfPort(inputStateParameterName).get().size() <= 3)
+ && (rewardFunctionParameter.getInputPortDimensionOfPort(inputStateParameterName).get().size() > 0);
+ }
+
+ private boolean outputParameterIsScalar() {
+ return (rewardFunctionParameter.getOutputPortDimensionOfPort(outputParameterName).isPresent())
+ && (rewardFunctionParameter.getOutputPortDimensionOfPort(outputParameterName).get().size() == 1)
+ && (rewardFunctionParameter.getOutputPortDimensionOfPort(outputParameterName).get().get(0) == 1);
+ }
+
+ private boolean inputTerminalIsBooleanScalar() {
+ return (rewardFunctionParameter.getInputPortDimensionOfPort(inputTerminalParameterName).isPresent())
+ && (rewardFunctionParameter.getTypeOfInputPort(inputTerminalParameterName).isPresent())
+ && (rewardFunctionParameter.getInputPortDimensionOfPort(inputTerminalParameterName).get().size() == 1)
+ && (rewardFunctionParameter.getInputPortDimensionOfPort(inputTerminalParameterName).get().get(0) == 1)
+ && (rewardFunctionParameter.getTypeOfInputPort(inputTerminalParameterName).get().equals("B"));
+ }
+
+ private void failIfConditionFails(final boolean condition, final String message) {
+ if (!condition) {
+ fail(message);
+ }
+ }
+
+ private void fail(final String message) {
+ Log.error(message);
+ //System.exit(-1);
+ }
+}
\ No newline at end of file
diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/reinforcement/RewardFunctionParameterAdapter.java b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/reinforcement/RewardFunctionParameterAdapter.java
new file mode 100644
index 0000000000000000000000000000000000000000..2796b2aef5165580b588dc3b53390ad2cdcc8455
--- /dev/null
+++ b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/reinforcement/RewardFunctionParameterAdapter.java
@@ -0,0 +1,136 @@
+package de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement;
+
+import de.monticore.lang.monticar.cnntrain.annotations.RewardFunctionParameter;
+import de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.ComponentPortInformation;
+import de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.EmadlType;
+import de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.PortVariable;
+
+import java.util.List;
+import java.util.Optional;
+import java.util.stream.Collectors;
+
+/**
+ * Adapts the port information of a generated reward function component to the
+ * RewardFunctionParameter interface and determines which input is the state and which the terminal flag.
+ */
+public class RewardFunctionParameterAdapter implements RewardFunctionParameter {
+ private final ComponentPortInformation adaptee;
+ private String outputParameterName;
+ private String inputStateParameterName;
+ private String inputTerminalParameterName;
+
+ public RewardFunctionParameterAdapter(final ComponentPortInformation componentPortInformation) {
+ this.adaptee = componentPortInformation;
+ }
+
+ @Override
+ public List<String> getInputNames() {
+ return this.adaptee.getAllInputs().stream()
+ .map(PortVariable::getVariableName)
+ .collect(Collectors.toList());
+ }
+
+ @Override
+ public List<String> getOutputNames() {
+ return this.adaptee.getAllOutputs().stream()
+ .map(PortVariable::getVariableName)
+ .collect(Collectors.toList());
+ }
+
+ @Override
+ public Optional<String> getTypeOfInputPort(String portName) {
+ return this.adaptee.getAllInputs().stream()
+ .filter(port -> port.getVariableName().equals(portName))
+ .map(port -> port.getEmadlType().toString())
+ .findFirst();
+ }
+
+ @Override
+ public Optional<String> getTypeOfOutputPort(String portName) {
+ return this.adaptee.getAllOutputs().stream()
+ .filter(port -> port.getVariableName().equals(portName))
+ .map(port -> port.getEmadlType().toString())
+ .findFirst();
+ }
+
+ @Override
+ public Optional<List<Integer>> getInputPortDimensionOfPort(String portName) {
+ return this.adaptee.getAllInputs().stream()
+ .filter(port -> port.getVariableName().equals(portName))
+ .map(PortVariable::getDimension)
+ .findFirst();
+ }
+
+ @Override
+ public Optional<List<Integer>> getOutputPortDimensionOfPort(String portName) {
+ return this.adaptee.getAllOutputs().stream()
+ .filter(port -> port.getVariableName().equals(portName))
+ .map(PortVariable::getDimension)
+ .findFirst();
+ }
+
+ public Optional<String> getOutputParameterName() {
+ if (this.outputParameterName == null) {
+ if (this.getOutputNames().size() == 1) {
+ this.outputParameterName = this.getOutputNames().get(0);
+ } else {
+ return Optional.empty();
+ }
+ }
+ return Optional.of(this.outputParameterName);
+ }
+
+
+ private boolean isBooleanScalar(final PortVariable portVariable) {
+ return portVariable.getEmadlType().equals(EmadlType.B)
+ && portVariable.getDimension().size() == 1
+ && portVariable.getDimension().get(0) == 1;
+ }
+
+ private boolean determineInputNames() {
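+ // Heuristic: the boolean scalar input is treated as the terminal flag, the remaining input as the state.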
+ if (this.getInputNames().size() != 2) {
+ return false;
+ }
+ Optional<String> terminalInput = this.adaptee.getAllInputs()
+ .stream()
+ .filter(this::isBooleanScalar)
+ .map(PortVariable::getVariableName)
+ .findFirst();
+
+ if (terminalInput.isPresent()) {
+ this.inputTerminalParameterName = terminalInput.get();
+ } else {
+ return false;
+ }
+
+ Optional<String> stateInput = this.adaptee.getAllInputs().stream()
+ .filter(portVariable -> !portVariable.getVariableName().equals(this.inputTerminalParameterName))
+ .filter(portVariable -> !isBooleanScalar(portVariable))
+ .map(PortVariable::getVariableName)
+ .findFirst();
+
+ if (stateInput.isPresent()) {
+ this.inputStateParameterName = stateInput.get();
+ } else {
+ this.inputTerminalParameterName = null;
+ return false;
+ }
+ return true;
+ }
+
+ public Optional<String> getInputStateParameterName() {
+ if (this.inputStateParameterName == null) {
+ this.determineInputNames();
+ }
+
+ return Optional.ofNullable(this.inputStateParameterName);
+ }
+
+ public Optional<String> getInputTerminalParameter() {
+ if (this.inputTerminalParameterName == null) {
+ this.determineInputNames();
+ }
+
+ return Optional.ofNullable(this.inputTerminalParameterName);
+ }
+}
\ No newline at end of file
diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/reinforcement/RewardFunctionSourceGenerator.java b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/reinforcement/RewardFunctionSourceGenerator.java
new file mode 100644
index 0000000000000000000000000000000000000000..b94e58d55e1b0734ee4803871d7cfb9d828ef3d9
--- /dev/null
+++ b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/reinforcement/RewardFunctionSourceGenerator.java
@@ -0,0 +1,8 @@
+package de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement;
+
+/**
+ * Generates the source code of a reward function component, given the model path and the
+ * component's qualified name.
+ */
+public interface RewardFunctionSourceGenerator {
+ void generate(String modelPath, String qualifiedName, String targetPath);
+}
\ No newline at end of file
diff --git a/src/main/resources/templates/gluon/CNNCreator.ftl b/src/main/resources/templates/gluon/CNNCreator.ftl
index fd2e406106fa02ed3302c71f4de240ae36cfb478..1a9c62fe6bb9bd0dda46c5d52c654d3fcb81c503 100644
--- a/src/main/resources/templates/gluon/CNNCreator.ftl
+++ b/src/main/resources/templates/gluon/CNNCreator.ftl
@@ -6,12 +6,15 @@ from CNNNet_${tc.fullArchitectureName} import Net
class ${tc.fileNameWithoutEnding}:
_model_dir_ = "model/${tc.componentName}/"
_model_prefix_ = "model"
- _input_shapes_ = [<#list tc.architecture.inputs as input>(${tc.join(input.definition.type.dimensions, ",")})</#list>]
+ _input_shapes_ = [<#list tc.architecture.inputs as input>(${tc.join(input.definition.type.dimensions, ",")},)</#list>]
def __init__(self):
self.weight_initializer = mx.init.Normal()
self.net = None
+ def get_input_shapes(self):
+ return self._input_shapes_
+
def load(self, context):
lastEpoch = 0
param_file = None
diff --git a/src/main/resources/templates/gluon/CNNPredictor.ftl b/src/main/resources/templates/gluon/CNNPredictor.ftl
index 283ed874833e8a0ed1738a5d1cf256fe59c81208..150033751319c61eaee847ae3df25ac1ede89d6d 100644
--- a/src/main/resources/templates/gluon/CNNPredictor.ftl
+++ b/src/main/resources/templates/gluon/CNNPredictor.ftl
@@ -31,8 +31,7 @@ public:
void predict(${tc.join(tc.architectureInputs, ", ", "const std::vector &", "")},
${tc.join(tc.architectureOutputs, ", ", "std::vector &", "")}){
<#list tc.architectureInputs as inputName>
- MXPredSetInput(handle, "data", ${inputName}.data(), ${inputName}.size());
- //MXPredSetInput(handle, "${inputName}", ${inputName}.data(), ${inputName}.size());
+ MXPredSetInput(handle, "data", ${inputName}.data(), static_cast(${inputName}.size()));
#list>
MXPredForward(handle);
@@ -65,8 +64,6 @@ public:
int dev_type = use_gpu ? 2 : 1;
int dev_id = 0;
- handle = 0;
-
if (json_data.GetLength() == 0 ||
param_data.GetLength() == 0) {
std::exit(-1);
@@ -74,10 +71,15 @@ public:
const mx_uint num_input_nodes = input_keys.size();
+<#if (tc.architectureInputs?size >= 2)>
const char* input_keys_ptr[num_input_nodes];
for(mx_uint i = 0; i < num_input_nodes; i++){
input_keys_ptr[i] = input_keys[i].c_str();
}
+<#else>
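+ // With a single network input, the generator binds it to the conventional "data" input key (see MXPredSetInput above).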
+ const char* input_key[1] = { "data" };
+ const char** input_keys_ptr = input_key;
+</#if>
mx_uint shape_data_size = 0;
mx_uint input_shape_indptr[input_shapes.size() + 1];
@@ -96,8 +98,8 @@ public:
}
}
- MXPredCreate((const char*)json_data.GetBuffer(),
- (const char*)param_data.GetBuffer(),
+ MXPredCreate(static_cast<const char*>(json_data.GetBuffer()),
+ static_cast<const char*>(param_data.GetBuffer()),
static_cast<int>(param_data.GetLength()),
dev_type,
dev_id,
diff --git a/src/main/resources/templates/gluon/reinforcement/StartTrainer.ftl b/src/main/resources/templates/gluon/reinforcement/StartTrainer.ftl
new file mode 100644
index 0000000000000000000000000000000000000000..f2e8afde5362d1f65887a15f74792d2baf174ce7
--- /dev/null
+++ b/src/main/resources/templates/gluon/reinforcement/StartTrainer.ftl
@@ -0,0 +1,3 @@
+<#assign config = configurations[0]>
+#!/bin/bash
+python ${trainerName}.py
\ No newline at end of file
diff --git a/src/main/resources/templates/gluon/reinforcement/Trainer.ftl b/src/main/resources/templates/gluon/reinforcement/Trainer.ftl
new file mode 100644
index 0000000000000000000000000000000000000000..6b28a79c9a053e4016079f0803bdaf07dd368201
--- /dev/null
+++ b/src/main/resources/templates/gluon/reinforcement/Trainer.ftl
@@ -0,0 +1,182 @@
+<#setting number_format="computer">
+<#assign config = configurations[0]>
+from ${rlFrameworkModule}.agent import DqnAgent
+from ${rlFrameworkModule}.util import AgentSignalHandler
+import ${rlFrameworkModule}.environment
+import CNNCreator_${config.instanceName}
+
+import os
+import sys
+import re
+import logging
+import mxnet as mx
+
+session_output_dir = 'session'
+<#if (config.agentName)??>
+agent_name='${config.agentName}'
+<#else>
+agent_name='${config.instanceName}'
+</#if>
+session_param_output = os.path.join(session_output_dir, agent_name)
+
+def resume_session():
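+ # Look for an '.interrupted_session' directory from a previous run and ask the user whether to resume it.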
+ session_param_output = os.path.join(session_output_dir, agent_name)
+ resume_session = False
+ resume_directory = None
+ if os.path.isdir(session_output_dir) and os.path.isdir(session_param_output):
+ regex = re.compile(r'\d\d\d\d-\d\d-\d\d-\d\d-\d\d')
+ dir_content = os.listdir(session_param_output)
+ session_files = filter(regex.search, dir_content)
+ session_files.sort(reverse=True)
+ for d in session_files:
+ interrupted_session_dir = os.path.join(session_param_output, d, '.interrupted_session')
+ if os.path.isdir(interrupted_session_dir):
+ resume = raw_input('Interrupted session from {} found. Do you want to resume? (y/n) '.format(d))
+ if resume == 'y':
+ resume_session = True
+ resume_directory = interrupted_session_dir
+ break
+ return resume_session, resume_directory
+
+if __name__ == "__main__":
+<#if config.environment.environment == "gym">
+ env = ${rlFrameworkModule}.environment.GymEnvironment(<#if config.environment.name??>'${config.environment.name}'<#else>'CartPole-v0'</#if>)
+<#else>
+ env_params = {
+ 'ros_node_name' : '${config.instanceName}TrainerNode',
+<#if config.environment.state_topic??>
+ 'state_topic' : '${config.environment.state_topic}',
+</#if>
+<#if config.environment.action_topic??>
+ 'action_topic' : '${config.environment.action_topic}',
+</#if>
+<#if config.environment.reset_topic??>
+ 'reset_topic' : '${config.environment.reset_topic}',
+</#if>
+<#if config.environment.meta_topic??>
+ 'meta_topic' : '${config.environment.meta_topic}',
+</#if>
+<#if config.environment.greeting_topic??>
+ 'greeting_topic' : '${config.environment.greeting_topic}',
+</#if>
+<#if config.environment.terminal_state_topic??>
+ 'terminal_state_topic' : '${config.environment.terminal_state_topic}',
+</#if>
+ }
+ env = ${rlFrameworkModule}.environment.RosEnvironment(**env_params)
+</#if>
+<#if (config.context)??>
+ context = mx.${config.context}()
+<#else>
+ context = mx.cpu()
+</#if>
+ net_creator = CNNCreator_${config.instanceName}.CNNCreator_${config.instanceName}()
+ net_creator.construct(context)
+
+ replay_memory_params = {
+<#if (config.replayMemory)??>
+ 'method':'${config.replayMemory.method}',
+<#if (config.replayMemory.memory_size)??>
+ 'memory_size':${config.replayMemory.memory_size},
+</#if>
+<#if (config.replayMemory.sample_size)??>
+ 'sample_size':${config.replayMemory.sample_size},
+</#if>
+<#else>
+ 'method':'online',
+</#if>
+ 'state_dtype':'float32',
+ 'action_dtype':'uint8',
+ 'rewards_dtype':'float32'
+ }
+
+ policy_params = {
+<#if (config.actionSelection)??>
+ 'method':'${config.actionSelection.method}',
+<#else>
+ 'method':'epsgreedy'
+</#if>
+<#if (config.actionSelection.epsilon)??>
+ 'epsilon': ${config.actionSelection.epsilon},
+</#if>
+<#if (config.actionSelection.min_epsilon)??>
+ 'min_epsilon': ${config.actionSelection.min_epsilon},
+</#if>
+<#if (config.actionSelection.epsilon_decay_method)??>
+ 'epsilon_decay_method': '${config.actionSelection.epsilon_decay_method}',
+</#if>
+<#if (config.actionSelection.epsilon_decay)??>
+ 'epsilon_decay': ${config.actionSelection.epsilon_decay},
+</#if>
+ }
+
+ resume_session, resume_directory = resume_session()
+
+ if resume_session:
+ agent = DqnAgent.resume_from_session(resume_directory, net_creator.net, env)
+ else:
+ agent = DqnAgent(
+ network = net_creator.net,
+ environment=env,
+ replay_memory_params=replay_memory_params,
+ policy_params=policy_params,
+ state_dim=net_creator.get_input_shapes()[0],
+<#if (config.context)??>
+ ctx='${config.context}',
+</#if>
+<#if (config.discountFactor)??>
+ discount_factor=${config.discountFactor},
+</#if>
+<#if (config.loss)??>
+ loss_function='${config.loss}',
+</#if>
+<#if (config.configuration.optimizer)??>
+ optimizer='${config.optimizerName}',
+ optimizer_params={
+<#list config.optimizerParams?keys as param>
+ '${param}': ${config.optimizerParams[param]}<#sep>,
+</#list>
+ },
+</#if>
+<#if (config.numEpisodes)??>
+ training_episodes=${config.numEpisodes},
+</#if>
+<#if (config.trainingInterval)??>
+ train_interval=${config.trainingInterval},
+</#if>
+<#if (config.useFixTargetNetwork)?? && config.useFixTargetNetwork>
+ use_fix_target=True,
+ target_update_interval=${config.targetNetworkUpdateInterval},
+<#else>
+ use_fix_target=False,
+</#if>
+<#if (config.useDoubleDqn)?? && config.useDoubleDqn>
+ double_dqn = True,
+<#else>
+ double_dqn = False,
+</#if>
+<#if (config.snapshotInterval)??>
+ snapshot_interval=${config.snapshotInterval},
+</#if>
+ agent_name=agent_name,
+<#if (config.numMaxSteps)??>
+ max_episode_step=${config.numMaxSteps},
+</#if>
+ output_directory=session_output_dir,
+ verbose=True,
+ live_plot = True,
+ make_logfile=True,
+<#if (config.targetScore)??>
+ target_score=${config.targetScore}
+<#else>
+ target_score=None
+</#if>
+ )
+
+ signal_handler = AgentSignalHandler()
+ signal_handler.register_agent(agent)
+
+ train_successful = agent.train()
+
+ if train_successful:
+ agent.save_best_network(net_creator._model_dir_ + net_creator._model_prefix_ + '_newest', epoch=0)
\ No newline at end of file
diff --git a/src/main/resources/templates/gluon/reinforcement/agent/ActionPolicy.ftl b/src/main/resources/templates/gluon/reinforcement/agent/ActionPolicy.ftl
new file mode 100644
index 0000000000000000000000000000000000000000..f43a211fe353f5fcb95fb8dd38d7e412fd1d1ab4
--- /dev/null
+++ b/src/main/resources/templates/gluon/reinforcement/agent/ActionPolicy.ftl
@@ -0,0 +1,73 @@
+import numpy as np
+
+class ActionPolicyBuilder(object):
+ def __init__(self):
+ pass
+
+ def build_by_params(self,
+ method='epsgreedy',
+ epsilon=0.5,
+ min_epsilon=0.05,
+ epsilon_decay_method='no',
+ epsilon_decay=0.0,
+ action_dim=None):
+
+ if epsilon_decay_method == 'linear':
+ decay = LinearDecay(eps_decay=epsilon_decay, min_eps=min_epsilon)
+ else:
+ decay = NoDecay()
+
+ if method == 'epsgreedy':
+ assert action_dim is not None
+ assert len(action_dim) == 1
+ return EpsilonGreedyActionPolicy(eps=epsilon,
+ number_of_actions=action_dim[0], decay=decay)
+ else:
+ assert action_dim is not None
+ assert len(action_dim) == 1
+ return GreedyActionPolicy()
+
+class EpsilonGreedyActionPolicy(object):
+ def __init__(self, eps, number_of_actions, decay):
+ self.eps = eps
+ self.cur_eps = eps
+ self.__number_of_actions = number_of_actions
+ self.__decay_method = decay
+
+ def select_action(self, values):
+ do_exploration = (np.random.rand() < self.cur_eps)
+ if do_exploration:
+ action = np.random.randint(low=0, high=self.__number_of_actions)
+ else:
+ action = values.asnumpy().argmax()
+ return action
+
+ def decay(self):
+ self.cur_eps = self.__decay_method.decay(self.cur_eps)
+
+
+class GreedyActionPolicy(object):
+ def __init__(self):
+ pass
+
+ def select_action(self, values):
+ return values.asnumpy().argmax()
+
+ def decay(self):
+ pass
+
+
+class NoDecay(object):
+ def __init__(self):
+ pass
+
+ def decay(self, cur_eps):
+ return cur_eps
+
+class LinearDecay(object):
+ def __init__(self, eps_decay, min_eps=0):
+ self.eps_decay = eps_decay
+ self.min_eps = min_eps
+
+ def decay(self, cur_eps):
+ return max(cur_eps - self.eps_decay, self.min_eps)
\ No newline at end of file
diff --git a/src/main/resources/templates/gluon/reinforcement/agent/Agent.ftl b/src/main/resources/templates/gluon/reinforcement/agent/Agent.ftl
new file mode 100644
index 0000000000000000000000000000000000000000..acbc974d5295431ea6a88a562cff2fbd7d951066
--- /dev/null
+++ b/src/main/resources/templates/gluon/reinforcement/agent/Agent.ftl
@@ -0,0 +1,506 @@
+import mxnet as mx
+import numpy as np
+import time
+import os
+import logging
+import sys
+import util
+import matplotlib.pyplot as plt
+from replay_memory import ReplayMemoryBuilder
+from action_policy import ActionPolicyBuilder
+from util import copy_net, get_loss_function
+from mxnet import nd, gluon, autograd
+
+class DqnAgent(object):
+ def __init__(self,
+ network,
+ environment,
+ replay_memory_params,
+ policy_params,
+ state_dim,
+ ctx=None,
+ discount_factor=.9,
+ loss_function='euclidean',
+ optimizer='rmsprop',
+ optimizer_params = {'learning_rate':0.09},
+ training_episodes=50,
+ train_interval=1,
+ use_fix_target=False,
+ double_dqn = False,
+ target_update_interval=10,
+ snapshot_interval=200,
+ agent_name='Dqn_agent',
+ max_episode_step=99999,
+ output_directory='model_parameters',
+ verbose=True,
+ live_plot = True,
+ make_logfile=True,
+ target_score=None):
+ assert 0 < discount_factor <= 1
+ assert train_interval > 0
+ assert target_update_interval > 0
+ assert snapshot_interval > 0
+ assert max_episode_step > 0
+ assert training_episodes > 0
+ assert replay_memory_params is not None
+ assert type(state_dim) is tuple
+
+ self.__ctx = mx.gpu() if ctx == 'gpu' else mx.cpu()
+ self.__qnet = network
+
+ self.__environment = environment
+ self.__discount_factor = discount_factor
+ self.__training_episodes = training_episodes
+ self.__train_interval = train_interval
+ self.__verbose = verbose
+ self.__state_dim = state_dim
+ self.__action_dim = self.__qnet(nd.random_normal(shape=((1,) + self.__state_dim), ctx=self.__ctx)).shape[1:]
+
+ replay_memory_params['state_dim'] = state_dim
+ self.__replay_memory_params = replay_memory_params
+ rm_builder = ReplayMemoryBuilder()
+ self.__memory = rm_builder.build_by_params(**replay_memory_params)
+ self.__minibatch_size = self.__memory.sample_size
+
+ policy_params['action_dim'] = self.__action_dim
+ self.__policy_params = policy_params
+ p_builder = ActionPolicyBuilder()
+ self.__policy = p_builder.build_by_params(**policy_params)
+
+ self.__target_update_interval = target_update_interval
+ self.__target_qnet = copy_net(self.__qnet, self.__state_dim, ctx=self.__ctx)
+ self.__loss_function_str = loss_function
+ self.__loss_function = get_loss_function(loss_function)
+ self.__agent_name = agent_name
+ self.__snapshot_interval = snapshot_interval
+ self.__creation_time = time.time()
+ self.__max_episode_step = max_episode_step
+ self.__optimizer = optimizer
+ self.__optimizer_params = optimizer_params
+ self.__make_logfile = make_logfile
+ self.__double_dqn = double_dqn
+ self.__use_fix_target = use_fix_target
+ self.__live_plot = live_plot
+ self.__user_given_directory = output_directory
+ self.__target_score = target_score
+
+ self.__interrupt_flag = False
+
+ # Training Context
+ self.__current_episode = 0
+ self.__total_steps = 0
+
+ # Initialize best network
+ self.__best_net = copy_net(self.__qnet, self.__state_dim, self.__ctx)
+ self.__best_avg_score = None
+
+ # Gluon Trainer definition
+ self.__training_stats = None
+
+ # Prepare output directory and logger
+ self.__output_directory = output_directory\
+ + '/' + self.__agent_name\
+ + '/' + time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(self.__creation_time))
+ self.__logger = self.__setup_logging()
+ self.__logger.info('Agent created with following parameters: {}'.format(self.__make_config_dict()))
+
+ @classmethod
+ def from_config_file(cls, network, environment, config_file_path, ctx=None):
+ import json
+ # Load config
+ with open(config_file_path, 'r') as config_file:
+ config_dict = json.load(config_file)
+ return cls(network, environment, ctx=ctx, **config_dict)
+
+ @classmethod
+ def resume_from_session(cls, session_dir, net, environment):
+ import pickle
+ if not os.path.exists(session_dir):
+ raise ValueError('Session directory does not exist')
+
+ files = dict()
+ files['agent'] = os.path.join(session_dir, 'agent.p')
+ files['best_net_params'] = os.path.join(session_dir, 'best_net.params')
+ files['q_net_params'] = os.path.join(session_dir, 'qnet.params')
+ files['target_net_params'] = os.path.join(session_dir, 'target_net.params')
+
+ for file in files.values():
+ if not os.path.exists(file):
+ raise ValueError('Session directory is not complete: {} is missing'.format(file))
+
+ with open(files['agent'], 'rb') as f:
+ agent = pickle.load(f)
+
+ agent.__environment = environment
+ agent.__qnet = net
+ agent.__qnet.load_parameters(files['q_net_params'], agent.__ctx)
+ agent.__qnet.hybridize()
+ agent.__qnet(nd.random_normal(shape=((1,) + agent.__state_dim), ctx=agent.__ctx))
+ agent.__best_net = copy_net(agent.__qnet, agent.__state_dim, agent.__ctx)
+ agent.__best_net.load_parameters(files['best_net_params'], agent.__ctx)
+ agent.__target_qnet = copy_net(agent.__qnet, agent.__state_dim, agent.__ctx)
+ agent.__target_qnet.load_parameters(files['target_net_params'], agent.__ctx)
+
+ agent.__logger = agent.__setup_logging(append=True)
+ agent.__training_stats.logger = agent.__logger
+ agent.__logger.info('Agent was retrieved; Training can be continued')
+
+ return agent
+
+ def __interrupt_training(self):
+ import pickle
+ self.__logger.info('Training interrupted; Store state for resuming')
+ session_dir = os.path.join(self.__output_directory, '.interrupted_session')
+ if not os.path.exists(session_dir):
+ os.mkdir(session_dir)
+
+ del self.__training_stats.logger
+ logger = self.__logger
+ self.__logger = None
+ self.__environment.close()
+ self.__environment = None
+
+ self.__save_net(self.__qnet, 'qnet', session_dir)
+ self.__qnet = None
+ self.__save_net(self.__best_net, 'best_net', session_dir)
+ self.__best_net = None
+ self.__save_net(self.__target_qnet, 'target_net', session_dir)
+ self.__target_qnet = None
+
+ agent_session_file = os.path.join(session_dir, 'agent.p')
+
+ with open(agent_session_file, 'wb') as f:
+ pickle.dump(self, f)
+ self.__logger = logger
+ logger.info('State successfully stored')
+
+ @property
+ def current_episode(self):
+ return self.__current_episode
+
+ @property
+ def environment(self):
+ return self.__environment
+
+ def __adjust_optimizer_params(self, optimizer_params):
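+ # Map CNNTrain option names to the argument names expected by mxnet optimizers (e.g. weight_decay -> wd, learning_rate_decay -> lr_scheduler).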
+ if 'weight_decay' in optimizer_params:
+ optimizer_params['wd'] = optimizer_params['weight_decay']
+ del optimizer_params['weight_decay']
+ if 'learning_rate_decay' in optimizer_params:
+ min_learning_rate = 1e-8
+ if 'learning_rate_minimum' in optimizer_params:
+ min_learning_rate = optimizer_params['learning_rate_minimum']
+ del optimizer_params['learning_rate_minimum']
+ optimizer_params['lr_scheduler'] = mx.lr_scheduler.FactorScheduler(
+ optimizer_params['step_size'],
+ factor=optimizer_params['learning_rate_decay'],
+ stop_factor_lr=min_learning_rate)
+ del optimizer_params['step_size']
+ del optimizer_params['learning_rate_decay']
+
+ return optimizer_params
+
+ def set_interrupt_flag(self, interrupt):
+ self.__interrupt_flag = interrupt
+
+
+ def __make_output_directory_if_not_exist(self):
+ assert self.__output_directory
+ if not os.path.exists(self.__output_directory):
+ os.makedirs(self.__output_directory)
+
+ def __setup_logging(self, append=False):
+ assert self.__output_directory
+ assert self.__agent_name
+
+ output_level = logging.DEBUG if self.__verbose else logging.WARNING
+ filemode = 'a' if append else 'w'
+
+ logformat = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+ dateformat = '%d-%b-%y %H:%M:%S'
+ formatter = logging.Formatter(fmt=logformat, datefmt=dateformat)
+
+ logger = logging.getLogger('DQNAgent')
+ logger.setLevel(output_level)
+
+ stream_handler = logging.StreamHandler(sys.stdout)
+ stream_handler.setLevel(output_level)
+ stream_handler.setFormatter(formatter)
+ logger.addHandler(stream_handler)
+
+ if self.__make_logfile:
+ self.__make_output_directory_if_not_exist()
+ log_file = os.path.join(self.__output_directory, self.__agent_name + '.log')
+ file_handler = logging.FileHandler(log_file, mode=filemode)
+ file_handler.setLevel(output_level)
+ file_handler.setFormatter(formatter)
+ logger.addHandler(file_handler)
+
+ return logger
+
+ def __is_target_reached(self, avg_reward):
+ return self.__target_score is not None\
+ and avg_reward > self.__target_score
+
+
+ def get_q_values(self, state, with_best=False):
+ return self.get_batch_q_values(nd.array([state], ctx=self.__ctx), with_best=with_best)[0]
+
+ def get_batch_q_values(self, state_batch, with_best=False):
+ return self.__best_net(state_batch) if with_best else self.__qnet(state_batch)
+
+ def get_next_action(self, state, with_best=False):
+ q_values = self.get_q_values(state, with_best=with_best)
+ action = q_values.asnumpy().argmax()
+ return action
+
+ def __sample_from_memory(self):
+ states, actions, rewards, next_states, terminals\
+ = self.__memory.sample(batch_size=self.__minibatch_size)
+ states = nd.array(states, ctx=self.__ctx)
+ actions = nd.array(actions, ctx=self.__ctx)
+ rewards = nd.array(rewards, ctx=self.__ctx)
+ next_states = nd.array(next_states, ctx=self.__ctx)
+ terminals = nd.array(terminals, ctx=self.__ctx)
+ return states, actions, rewards, next_states, terminals
+
+ def __determine_target_q_values(self, states, actions, rewards, next_states, terminals):
+ if self.__use_fix_target:
+ q_max_val = self.__target_qnet(next_states)
+ else:
+ q_max_val = self.__qnet(next_states)
+
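+ # Double DQN: the online network selects the best next action, but its value is taken from
+ # q_max_val (the target network when fix targets are enabled); otherwise the plain maximum of q_max_val is used.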
+ if self.__double_dqn:
+ q_values_next_states = self.__qnet(next_states)
+ target_rewards = rewards + nd.choose_element_0index(q_max_val, nd.argmax_channel(q_values_next_states))\
+ * (1.0 - terminals) * self.__discount_factor
+ else:
+ target_rewards = rewards + nd.choose_element_0index(q_max_val, nd.argmax_channel(q_max_val))\
+ * (1.0 - terminals) * self.__discount_factor
+
+ target_qval = self.__qnet(states)
+ for t in range(target_rewards.shape[0]):
+ target_qval[t][actions[t]] = target_rewards[t]
+
+ return target_qval
+
+ def __train_q_net_step(self, trainer):
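+ # One training step: sample a minibatch from the replay memory, build the targets and perform a gradient update.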
+ states, actions, rewards, next_states, terminals = self.__sample_from_memory()
+ target_qval = self.__determine_target_q_values(states, actions, rewards, next_states, terminals)
+ with autograd.record():
+ q_values = self.__qnet(states)
+ loss = self.__loss_function(q_values, target_qval)
+ loss.backward()
+ trainer.step(self.__minibatch_size)
+ return loss
+
+ def __do_snapshot_if_in_interval(self, episode):
+ do_snapshot = (episode != 0 and (episode % self.__snapshot_interval == 0))
+ if do_snapshot:
+ self.save_parameters(episode=episode)
+ self.__evaluate()
+
+ def __do_target_update_if_in_interval(self, total_steps):
+ do_target_update = (self.__use_fix_target and total_steps % self.__target_update_interval == 0)
+ if do_target_update:
+ self.__logger.info('Target network is updated after {} steps'.format(total_steps))
+ self.__target_qnet = copy_net(self.__qnet, self.__state_dim, self.__ctx)
+
+ def train(self, episodes=None):
+ self.__logger.info("--- Start training ---")
+ trainer = gluon.Trainer(self.__qnet.collect_params(), self.__optimizer, self.__adjust_optimizer_params(self.__optimizer_params))
+ episodes = episodes if episodes != None else self.__training_episodes
+
+ resume = (self.__current_episode > 0)
+ if resume:
+ self.__logger.info("Training session resumed")
+ self.__logger.info("Starting from episode {}".format(self.__current_episode))
+ else:
+ self.__training_stats = util.TrainingStats(self.__logger, episodes, self.__live_plot)
+
+ # Implementation of Deep Q-Learning as described by Mnih et al. in "Playing Atari with Deep Reinforcement Learning"
+ while self.__current_episode < episodes:
+ # Check interrupt flag
+ if self.__interrupt_flag:
+ self.__interrupt_flag = False
+ self.__interrupt_training()
+ return False
+
+ step = 0
+ episode_reward = 0
+ start = time.time()
+ state = self.__environment.reset()
+ episode_loss = 0
+ training_steps = 0
+ while step < self.__max_episode_step:
+ #1. Choose an action based on current game state and policy
+ q_values = self.__qnet(nd.array([state], ctx=self.__ctx))
+ action = self.__policy.select_action(q_values[0])
+
+ #2. Play the game for a single step
+ next_state, reward, terminal, _ = self.__environment.step(action)
+
+ #3. Store transition in replay memory
+ self.__memory.append(state, action, reward, next_state, terminal)
+
+ #4. Train the network if in interval
+ do_training = (self.__total_steps % self.__train_interval == 0\
+ and self.__memory.is_sample_possible(self.__minibatch_size))
+ if do_training:
+ loss = self.__train_q_net_step(trainer)
+ loss_sum = sum(loss).asnumpy()[0]
+ episode_loss += float(loss_sum)/float(self.__minibatch_size)
+ training_steps += 1
+
+ # Update target network if in interval
+ self.__do_target_update_if_in_interval(self.__total_steps)
+
+ step += 1
+ self.__total_steps += 1
+ episode_reward += reward
+ state = next_state
+
+ if terminal:
+ episode_loss = episode_loss if training_steps > 0 else None
+ _, _, avg_reward = self.__training_stats.log_episode(self.__current_episode, start, training_steps,
+ episode_loss, self.__policy.cur_eps, episode_reward)
+ break
+
+ self.__do_snapshot_if_in_interval(self.__current_episode)
+ self.__policy.decay()
+
+ if self.__is_target_reached(avg_reward):
+ self.__logger.info('Average target score reached; stopping training')
+ break
+
+ self.__current_episode += 1
+
+ self.__evaluate()
+ training_stats_file = os.path.join(self.__output_directory, 'training_stats.pdf')
+ self.__training_stats.save_stats(training_stats_file)
+ self.__logger.info('--------- Training finished ---------')
+ return True
+
+ def __save_net(self, net, filename, filedir=None):
+ filedir = self.__output_directory if filedir is None else filedir
+ filename = os.path.join(filedir, filename + '.params')
+ net.save_parameters(filename)
+
+
+ def save_parameters(self, episode=None, filename='dqn-agent-params'):
+ assert self.__output_directory
+ self.__make_output_directory_if_not_exist()
+
+ if episode is not None:
+ self.__logger.info('Saving model parameters after episode %d' % episode)
+ filename = filename + '-ep{}'.format(episode)
+ else:
+ self.__logger.info('Saving model parameters')
+ self.__save_net(self.__qnet, filename)
+
+ def evaluate(self, target=None, sample_games=100, verbose=True):
+ target = self.__target_score if target is None else target
+ if target:
+ target_achieved = 0
+ total_reward = 0
+
+ for g in range(sample_games):
+ state = self.__environment.reset()
+ step = 0
+ game_reward = 0
+ while step < self.__max_episode_step:
+ action = self.get_next_action(state)
+ state, reward, terminal, _ = self.__environment.step(action)
+ game_reward += reward
+
+ if terminal:
+ if verbose:
+ info = 'Game %d: Reward %f' % (g,game_reward)
+ self.__logger.debug(info)
+ if target:
+ if game_reward >= target:
+ target_achieved += 1
+ total_reward += game_reward
+ break
+
+ step += 1
+
+ avg_reward = float(total_reward)/float(sample_games)
+ info = 'Avg. Reward: %f' % avg_reward
+ if target:
+ target_achieved_ratio = int((float(target_achieved)/float(sample_games))*100)
+ info += '; Target Achieved in %d%% of games' % (target_achieved_ratio)
+
+ if verbose:
+ self.__logger.info(info)
+ return avg_reward
+
+ def __evaluate(self, verbose=True):
+ sample_games = 100
+ avg_reward = self.evaluate(sample_games=sample_games, verbose=False)
+ info = 'Evaluation -> Average Reward in {} games: {}'.format(sample_games, avg_reward)
+
+ if self.__best_avg_score is None or self.__best_avg_score <= avg_reward:
+ self.__best_net = copy_net(self.__qnet, self.__state_dim, self.__ctx)
+ self.__best_avg_score = avg_reward
+ info += ' (NEW BEST)'
+
+ if verbose:
+ self.__logger.info(info)
+
+
+
+ def play(self, update_frame=1, with_best=False):
+ step = 0
+ state = self.__environment.reset()
+ total_reward = 0
+ while step < self.__max_episode_step:
+ action = self.get_next_action(state, with_best=with_best)
+ state, reward, terminal, _ = self.__environment.step(action)
+ total_reward += reward
+ do_update_frame = (step % update_frame == 0)
+ if do_update_frame:
+ self.__environment.render()
+ time.sleep(.100)
+
+ if terminal:
+ break
+
+ step += 1
+ return total_reward
+
+ def save_best_network(self, path, epoch=0):
+ self.__logger.info('Saving best network with average reward of {}'.format(self.__best_avg_score))
+ self.__best_net.export(path, epoch=epoch)
+
+ def __make_config_dict(self):
+ config = dict()
+ config['discount_factor'] = self.__discount_factor
+ config['optimizer'] = self.__optimizer
+ config['optimizer_params'] = self.__optimizer_params
+ config['policy_params'] = self.__policy_params
+ config['replay_memory_params'] = self.__replay_memory_params
+ config['loss_function'] = self.__loss_function_str
+ config['training_episodes'] = self.__training_episodes
+ config['train_interval'] = self.__train_interval
+ config['use_fix_target'] = self.__use_fix_target
+ config['double_dqn'] = self.__double_dqn
+ config['target_update_interval'] = self.__target_update_interval
+ config['snapshot_interval'] = self.__snapshot_interval
+ config['agent_name'] = self.__agent_name
+ config['max_episode_step'] = self.__max_episode_step
+ config['output_directory'] = self.__user_given_directory
+ config['verbose'] = self.__verbose
+ config['live_plot'] = self.__live_plot
+ config['make_logfile'] = self.__make_logfile
+ config['target_score'] = self.__target_score
+ return config
+
+ def save_config_file(self):
+ import json
+ self.__make_output_directory_if_not_exist()
+ filename = os.path.join(self.__output_directory, 'config.json')
+ config = self.__make_config_dict()
+ with open(filename, mode='w') as fp:
+ json.dump(config, fp, indent=4)
\ No newline at end of file
diff --git a/src/main/resources/templates/gluon/reinforcement/agent/ReplayMemory.ftl b/src/main/resources/templates/gluon/reinforcement/agent/ReplayMemory.ftl
new file mode 100644
index 0000000000000000000000000000000000000000..e66cd9350cab02144f994cab706249ef5a1e4288
--- /dev/null
+++ b/src/main/resources/templates/gluon/reinforcement/agent/ReplayMemory.ftl
@@ -0,0 +1,155 @@
+import numpy as np
+
+class ReplayMemoryBuilder(object):
+ def __init__(self):
+ self.__supported_methods = ['online', 'buffer', 'combined']
+
+ def build_by_params(self,
+ state_dim,
+ method='online',
+ state_dtype='float32',
+ action_dtype='uint8',
+ rewards_dtype='float32',
+ memory_size=1000,
+ sample_size=32):
+ assert state_dim is not None
+ assert method in self.__supported_methods
+
+ if method == 'online':
+ return self.build_online_memory(state_dim=state_dim, state_dtype=state_dtype,
+ action_dtype=action_dtype, rewards_dtype=rewards_dtype)
+ else:
+ assert memory_size is not None and memory_size > 0
+ assert sample_size is not None and sample_size > 0
+ if method == 'buffer':
+ return self.build_buffered_memory(state_dim=state_dim, sample_size=sample_size,
+ memory_size=memory_size, state_dtype=state_dtype, action_dtype=action_dtype,
+ rewards_dtype=rewards_dtype)
+ else:
+ return self.build_combined_memory(state_dim=state_dim, sample_size=sample_size,
+ memory_size=memory_size, state_dtype=state_dtype, action_dtype=action_dtype,
+ rewards_dtype=rewards_dtype)
+
+ def build_buffered_memory(self, state_dim, memory_size=1000, sample_size=1, state_dtype='float32',
+ action_dtype='uint8', rewards_dtype='float32'):
+ assert memory_size > 0
+ assert sample_size > 0
+ return ReplayMemory(state_dim, size=memory_size, sample_size=sample_size,
+ state_dtype=state_dtype, action_dtype=action_dtype, rewards_dtype=rewards_dtype)
+
+ def build_combined_memory(self, state_dim, memory_size=1000, sample_size=1, state_dtype='float32',
+ action_dtype='uint8', rewards_dtype='float32'):
+ assert memory_size > 0
+ assert sample_size > 0
+ return CombinedReplayMemory(state_dim, size=memory_size, sample_size=sample_size,
+ state_dtype=state_dtype, action_dtype=action_dtype, rewards_dtype=rewards_dtype)
+
+ def build_online_memory(self, state_dim, state_dtype='float32', action_dtype='uint8',
+ rewards_dtype='float32'):
+ return OnlineReplayMemory(state_dim, state_dtype=state_dtype, action_dtype=action_dtype,
+ rewards_dtype=rewards_dtype)
+
+class ReplayMemory(object):
+ def __init__(self, state_dim, sample_size, size=1000, state_dtype='uint8', action_dtype='uint8', rewards_dtype='float32'):
+ assert size > 0, "Size must be greater than zero"
+ assert type(state_dim) is tuple, "State dimension must be a tuple"
+ assert sample_size > 0
+ self._size = size
+ self._sample_size = sample_size
+ self._cur_size = 0
+ self._pointer = 0
+ self._state_dim = state_dim
+ self._state_dtype = state_dtype
+ self._action_dtype = action_dtype
+ self._rewards_dtype = rewards_dtype
+ self._states = np.zeros((self._size,) + state_dim, dtype=state_dtype)
+ self._actions = np.array([0] * self._size, dtype=action_dtype)
+ self._rewards = np.array([0] * self._size, dtype=rewards_dtype)
+ self._next_states = np.zeros((self._size,) + state_dim, dtype=state_dtype)
+ self._terminals = np.array([0] * self._size, dtype='bool')
+
+ @property
+ def sample_size(self):
+ return self._sample_size
+
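+ # The memory is used as a ring buffer: once full, the oldest transition is overwritten.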
+ def append(self, state, action, reward, next_state, terminal):
+ self._states[self._pointer] = state
+ self._actions[self._pointer] = action
+ self._rewards[self._pointer] = reward
+ self._next_states[self._pointer] = next_state
+ self._terminals[self._pointer] = terminal
+
+ self._pointer = self._pointer + 1
+ if self._pointer == self._size:
+ self._pointer = 0
+
+ self._cur_size = min(self._size, self._cur_size + 1)
+
+ def at(self, index):
+ return self._states[index],\
+ self._actions[index],\
+ self._rewards[index],\
+ self._next_states[index],\
+ self._terminals[index]
+
+ def is_sample_possible(self, batch_size=None):
+ batch_size = batch_size if batch_size is not None else self._sample_size
+ return self._cur_size >= batch_size
+
+ def sample(self, batch_size=None):
+ batch_size = batch_size if batch_size is not None else self._sample_size
+ assert self._cur_size >= batch_size, "Size of replay memory must be larger than batch size"
+ i=0
+ states = np.zeros((batch_size,)+self._state_dim, dtype=self._state_dtype)
+ actions = np.zeros(batch_size, dtype=self._action_dtype)
+ rewards = np.zeros(batch_size, dtype=self._rewards_dtype)
+ next_states = np.zeros((batch_size,)+self._state_dim, dtype=self._state_dtype)
+ terminals = np.zeros(batch_size, dtype='bool')
+
+ while i < batch_size:
+ rnd_index = np.random.randint(low=0, high=self._cur_size)
+ states[i] = self._states.take(rnd_index, axis=0)
+ actions[i] = self._actions.take(rnd_index, axis=0)
+ rewards[i] = self._rewards.take(rnd_index, axis=0)
+ next_states[i] = self._next_states.take(rnd_index, axis=0)
+ terminals[i] = self._terminals.take(rnd_index, axis=0)
+ i += 1
+
+ return states, actions, rewards, next_states, terminals
+
+
+class OnlineReplayMemory(ReplayMemory):
+ def __init__(self, state_dim, state_dtype='float32', action_dtype='uint8', rewards_dtype='float32'):
+ super(OnlineReplayMemory, self).__init__(state_dim, sample_size=1, size=1,
+ state_dtype=state_dtype, action_dtype=action_dtype, rewards_dtype=rewards_dtype)
+
+
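+# Combined experience replay: the most recent transition is always added to each sampled
+# minibatch in addition to the uniformly drawn ones (see append() and sample() below).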
+class CombinedReplayMemory(ReplayMemory):
+ def __init__(self, state_dim, sample_size, size=1000,
+ state_dtype='uint8', action_dtype='uint8', rewards_dtype='float32'):
+ super(CombinedReplayMemory, self).__init__(state_dim, sample_size=(sample_size - 1), size=size,
+ state_dtype=state_dtype, action_dtype=action_dtype, rewards_dtype=rewards_dtype)
+
+ self._last_state = np.zeros((1,) + state_dim, dtype=state_dtype)
+ self._last_action = np.array([0], dtype=action_dtype)
+ self._last_reward = np.array([0], dtype=rewards_dtype)
+ self._last_next_state = np.zeros((1,) + state_dim, dtype=state_dtype)
+ self._last_terminal = np.array([0], dtype='bool')
+
+ def append(self, state, action, reward, next_state, terminal):
+ super(CombinedReplayMemory, self).append(state, action, reward, next_state, terminal)
+ self._last_state = state
+ self._last_action = action
+ self._last_reward = reward
+ self._last_next_state = next_state
+ self._last_terminal = terminal
+
+ def sample(self, batch_size=None):
+ batch_size = (batch_size-1) if batch_size is not None else self._sample_size
+ states, actions, rewards, next_states, terminals = super(CombinedReplayMemory, self).sample(batch_size=batch_size)
+ states = np.append(states, [self._last_state], axis=0)
+ actions = np.append(actions, [self._last_action], axis=0)
+ rewards = np.append(rewards, [self._last_reward], axis=0)
+ next_states = np.append(next_states, [self._last_next_state], axis=0)
+ terminals = np.append(terminals, [self._last_terminal], axis=0)
+ return states, actions, rewards, next_states, terminals
\ No newline at end of file
diff --git a/src/main/resources/templates/gluon/reinforcement/environment/Environment.ftl b/src/main/resources/templates/gluon/reinforcement/environment/Environment.ftl
new file mode 100644
index 0000000000000000000000000000000000000000..e2104ee8cc4507634b126aac6fb285738cdaad0d
--- /dev/null
+++ b/src/main/resources/templates/gluon/reinforcement/environment/Environment.ftl
@@ -0,0 +1,217 @@
+<#setting number_format="computer">
+<#assign config = configurations[0]>
+import abc
+import logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+<#if config.hasRewardFunction() >
+import ${config.rewardFunctionName}_executor
+
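+# Thin wrapper around the generated reward-function executor (the Python wrapper module
+# imported above), which calls into the C++ implementation of the reward function.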
+class RewardFunction(object):
+ def __init__(self):
+ self.__reward_wrapper = ${config.rewardFunctionName}_executor.${config.rewardFunctionName}_executor()
+ self.__reward_wrapper.init()
+
+ def reward(self, state, terminal):
+ inp = ${config.rewardFunctionName}_executor.${config.rewardFunctionName}_input()
+ inp.${config.rewardFunctionStateParameter.name} = state
+ inp.${config.rewardFunctionTerminalParameter.name} = terminal
+ output = self.__reward_wrapper.execute(inp)
+ return output.${config.rewardFunctionOutputName}
+
+
+</#if>
+
+class Environment:
+ __metaclass__ = abc.ABCMeta
+
+ def __init__(self):
+<#if config.hasRewardFunction() >
+ self._reward_function = RewardFunction()
+<#else>
+ pass
+</#if>
+
+ @abc.abstractmethod
+ def reset(self):
+ pass
+
+ @abc.abstractmethod
+ def step(self, action):
+ pass
+
+ @abc.abstractmethod
+ def close(self):
+ pass
+
+<#if config.environment.environment == "gym">
+import gym
+class GymEnvironment(Environment):
+ def __init__(self, env_name, **kwargs):
+ super(GymEnvironment, self).__init__(**kwargs)
+ self.__seed = 42
+ self.__env = gym.make(env_name)
+ self.__env.seed(self.__seed)
+
+ @property
+ def state_dim(self):
+ return self.__env.observation_space.shape
+
+ @property
+ def state_dtype(self):
+ return 'float32'
+
+ @property
+ def action_dtype(self):
+ return 'uint8'
+
+ @property
+ def number_of_actions(self):
+ return self.__env.action_space.n
+
+ @property
+ def rewards_dtype(self):
+ return 'float32'
+
+ def reset(self):
+ return self.__env.reset()
+
+ def step(self, action):
+<#if config.hasRewardFunction() >
+ next_state, reward, terminal, info = self.__env.step(action)
+ reward = self._reward_function.reward(next_state, terminal)
+ return next_state, reward, terminal, info
+<#else>
+ return self.__env.step(action)
+</#if>
+
+ def close(self):
+ self.__env.close()
+
+ def action_space(self):
+ return self.__env.action_space
+
+ def is_in_action_space(self, action):
+ return self.__env.action_space.contains(action)
+
+ def sample_action(self):
+ return self.__env.action_space.sample()
+
+ def render(self):
+ self.__env.render()
+<#else>
+import rospy
+import thread
+import numpy as np
+import time
+from std_msgs.msg import Float32MultiArray, Bool, Int32
+
+class RosEnvironment(Environment):
+ def __init__(self,
+ ros_node_name='RosTrainingAgent',
+ timeout_in_s=3,
+ state_topic='state',
+ action_topic='action',
+ reset_topic='reset',
+ terminal_state_topic='terminal',
+ meta_topic='meta',
+ greeting_topic='greeting'):
+ super(RosEnvironment, self).__init__()
+ self.__timeout_in_s = timeout_in_s
+
+ self.__waiting_for_state_update = False
+ self.__waiting_for_terminal_update = False
+ self.__last_received_state = 0
+ self.__last_received_terminal = 0
+
+ rospy.loginfo("Initialize node {0}".format(ros_node_name))
+
+ self.__step_publisher = rospy.Publisher(action_topic, Int32, queue_size=1)
+ rospy.loginfo('Step Publisher initialized with topic {}'.format(action_topic))
+
+ self.__reset_publisher = rospy.Publisher(reset_topic, Bool, queue_size=1)
+ rospy.loginfo('Reset Publisher initialized with topic {}'.format(reset_topic))
+
+ rospy.init_node(ros_node_name, anonymous=True)
+
+ self.__state_subscriber = rospy.Subscriber(state_topic, Float32MultiArray, self.__state_callback)
+ rospy.loginfo('State Subscriber registered with topic {}'.format(state_topic))
+
+ self.__terminal_state_subscriber = rospy.Subscriber(terminal_state_topic, Bool, self.__terminal_state_callback)
+ rospy.loginfo('Terminal State Subscriber registered with topic {}'.format(terminal_state_topic))
+
+ rate = rospy.Rate(10)
+
+ thread.start_new_thread(rospy.spin, ())
+ time.sleep(2)
+
+ def reset(self):
+ time.sleep(0.5)
+ reset_message = Bool()
+ reset_message.data = True
+ self.__waiting_for_state_update = True
+ self.__reset_publisher.publish(reset_message)
+ while self.__last_received_terminal:
+ self.__wait_for_new_state(self.__reset_publisher, reset_message)
+ return self.__last_received_state
+
+ def step(self, action):
+ action_rospy = Int32()
+ action_rospy.data = action
+
+ logger.debug('Send action: {}'.format(action))
+
+ self.__waiting_for_state_update = True
+ self.__waiting_for_terminal_update = True
+ self.__step_publisher.publish(action_rospy)
+ self.__wait_for_new_state(self.__step_publisher, action_rospy)
+ next_state = self.__last_received_state
+ terminal = self.__last_received_terminal
+ reward = self.__calc_reward(next_state, terminal)
+ rospy.logdebug('Calculated reward: {}'.format(reward))
+
+ return next_state, reward, terminal, 0
+
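+ # Blocks until both the state and the terminal callback have fired; re-publishes the
+ # last message after each timeout and aborts after three consecutive timeouts.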
+ def __wait_for_new_state(self, publisher, msg):
+ time_of_timeout = time.time() + self.__timeout_in_s
+ timeout_counter = 0
+ while(self.__waiting_for_state_update or self.__waiting_for_terminal_update):
+ is_timeout = (time.time() > time_of_timeout)
+ if (is_timeout):
+ if timeout_counter < 3:
+ rospy.logwarn("Timeout occurred: retrying message")
+ publisher.publish(msg)
+ timeout_counter += 1
+ time_of_timeout = time.time() + self.__timeout_in_s
+ else:
+ rospy.logerr("Timed out 3 times in a row: terminating application")
+ exit()
+ time.sleep(0.1)
+
+ def close(self):
+ rospy.signal_shutdown('Program ended!')
+
+
+ def __state_callback(self, data):
+<#if config.rewardFunctionStateParameter.isMultiDimensional>
+ self.__last_received_state = np.array(data.data, dtype='${config.rewardFunctionStateParameter.dtype}')
+<#else>
+ self.__last_received_state = data.data
+</#if>
+ rospy.logdebug('Received state: {}'.format(self.__last_received_state))
+ self.__waiting_for_state_update = False
+
+ def __terminal_state_callback(self, data):
+<#if config.rewardFunctionTerminalParameter.isMultiDimensional>
+ self.__last_received_terminal = np.array(data.data, dtype='${config.rewardFunctionTerminalParameter.dtype}')
+<#else>
+ self.__last_received_terminal = data.data
+</#if>
+ rospy.logdebug('Received terminal flag: {}'.format(self.__last_received_terminal))
+ logger.debug('Received terminal: {}'.format(self.__last_received_terminal))
+ self.__waiting_for_terminal_update = False
+
+ def __calc_reward(self, state, terminal):
+ # C++ Wrapper call
+ return self._reward_function.reward(state, terminal)
+</#if>
\ No newline at end of file
diff --git a/src/main/resources/templates/gluon/reinforcement/util/Util.ftl b/src/main/resources/templates/gluon/reinforcement/util/Util.ftl
new file mode 100644
index 0000000000000000000000000000000000000000..58578932a13f7d12f1aa840d06af1951bebc3f0f
--- /dev/null
+++ b/src/main/resources/templates/gluon/reinforcement/util/Util.ftl
@@ -0,0 +1,140 @@
+import signal
+import sys
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib import style
+import time
+import os
+import mxnet
+from mxnet import gluon, nd
+
+
+LOSS_FUNCTIONS = {
+ 'l1': gluon.loss.L1Loss(),
+ 'euclidean': gluon.loss.L2Loss(),
+ 'huber_loss': gluon.loss.HuberLoss(),
+ 'softmax_cross_entropy': gluon.loss.SoftmaxCrossEntropyLoss(),
+ 'sigmoid_cross_entropy': gluon.loss.SigmoidBinaryCrossEntropyLoss()}
+
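+# Creates an independent copy of a Gluon network by writing its parameters to a temporary
+# file and loading them into a fresh instance; the dummy forward pass triggers shape
+# inference for the hybridized copy.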
+def copy_net(net, input_state_dim, ctx, tmp_filename='tmp.params'):
+ assert isinstance(net, gluon.HybridBlock)
+ assert type(net.__class__) is type
+ net.save_parameters(tmp_filename)
+ net2 = net.__class__()
+ net2.load_parameters(tmp_filename, ctx=ctx)
+ os.remove(tmp_filename)
+ net2.hybridize()
+ net2(nd.ones((1,) + input_state_dim, ctx=ctx))
+ return net2
+
+def get_loss_function(loss_function_name):
+ if loss_function_name not in LOSS_FUNCTIONS:
+ raise ValueError('Loss function does not exist')
+ return LOSS_FUNCTIONS[loss_function_name]
+
+
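+# Translates SIGINT into a soft interrupt request on the registered agent so training can
+# checkpoint itself before exiting; repeated interrupts (more than three) force-quit.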
+class AgentSignalHandler(object):
+ def __init__(self):
+ signal.signal(signal.SIGINT, self.interrupt_training)
+ self.__agent = None
+ self.__times_interrupted = 0
+
+ def register_agent(self, agent):
+ self.__agent = agent
+
+ def interrupt_training(self, sig, frame):
+ self.__times_interrupted = self.__times_interrupted + 1
+ if self.__times_interrupted <= 3:
+ if self.__agent:
+ self.__agent.set_interrupt_flag(True)
+ else:
+ print('Interrupt called three times: Force quit')
+ sys.exit(1)
+
+style.use('fivethirtyeight')
+class TrainingStats(object):
+ def __init__(self, logger, max_episodes, live_plot=True):
+ self.__logger = logger
+ self.__max_episodes = max_episodes
+ self.__all_avg_loss = np.zeros((max_episodes,))
+ self.__all_total_rewards = np.zeros((max_episodes,))
+ self.__all_eps = np.zeros((max_episodes,))
+ self.__all_time = np.zeros((max_episodes,))
+ self.__all_mean_reward_last_100_episodes = np.zeros((max_episodes,))
+ self.__live_plot = live_plot
+
+ @property
+ def logger(self):
+ return self.__logger
+
+ @logger.setter
+ def logger(self, logger):
+ self.__logger = logger
+
+ @logger.deleter
+ def logger(self):
+ self.__logger = None
+
+ def add_avg_loss(self, episode, avg_loss):
+ self.__all_avg_loss[episode] = avg_loss
+
+ def add_total_reward(self, episode, total_reward):
+ self.__all_total_rewards[episode] = total_reward
+
+ def add_eps(self, episode, eps):
+ self.__all_eps[episode] = eps
+
+ def add_time(self, episode, time):
+ self.__all_time[episode] = time
+
+ def add_mean_reward_last_100(self, episode, mean_reward):
+ self.__all_mean_reward_last_100_episodes[episode] = mean_reward
+
+ def log_episode(self, episode, start_time, training_steps, loss, eps, reward):
+ self.add_eps(episode, eps)
+ self.add_total_reward(episode, reward)
+ end = time.time()
+ if training_steps == 0:
+ avg_loss = 0
+ else:
+ avg_loss = float(loss)/float(training_steps)
+
+ mean_reward_last_100 = self.mean_of_reward(episode, last=100)
+
+ time_elapsed = end - start_time
+ info = "Episode: %d, Total Reward: %.3f, Avg. Reward Last 100 Episodes: %.3f, Avg Loss: %.3f, Time: %.3f, Training Steps: %d, Eps: %.3f"\
+ % (episode, reward, mean_reward_last_100, avg_loss, time_elapsed, training_steps, eps)
+ self.__logger.info(info)
+ self.add_avg_loss(episode, avg_loss)
+ self.add_time(episode, time_elapsed)
+ self.add_mean_reward_last_100(episode, mean_reward_last_100)
+
+ return avg_loss, time_elapsed, mean_reward_last_100
+
+ def mean_of_reward(self, cur_episode, last=100):
+ if cur_episode > 0:
+ reward_last_100 = self.__all_total_rewards[max(0, cur_episode-last):cur_episode]
+ return np.mean(reward_last_100)
+ else:
+ return self.__all_total_rewards[0]
+
+ def save_stats(self, file):
+ fig = plt.figure(figsize=(20,20))
+
+ sub_rewards = fig.add_subplot(221)
+ sub_rewards.set_title('Total Rewards per episode')
+ sub_rewards.plot(np.arange(self.__max_episodes), self.__all_total_rewards)
+
+ sub_loss = fig.add_subplot(222)
+ sub_loss.set_title('Avg. Loss per episode')
+ sub_loss.plot(np.arange(self.__max_episodes), self.__all_avg_loss)
+
+ sub_eps = fig.add_subplot(223)
+ sub_eps.set_title('Epsilon per episode')
+ sub_eps.plot(np.arange(self.__max_episodes), self.__all_eps)
+
+ sub_rewards = fig.add_subplot(224)
+ sub_rewards.set_title('Avg. mean reward of last 100 episodes')
+ sub_rewards.plot(np.arange(self.__max_episodes), self.__all_mean_reward_last_100_episodes)
+
+ plt.savefig(file)
\ No newline at end of file
diff --git a/src/test/java/de/monticore/lang/monticar/cnnarch/gluongenerator/GenerationTest.java b/src/test/java/de/monticore/lang/monticar/cnnarch/gluongenerator/GenerationTest.java
index 545ad9215ec2890faaf4aecb449e1b3330a01456..abe45750044437ca9a85a3442c31e746289ed516 100644
--- a/src/test/java/de/monticore/lang/monticar/cnnarch/gluongenerator/GenerationTest.java
+++ b/src/test/java/de/monticore/lang/monticar/cnnarch/gluongenerator/GenerationTest.java
@@ -20,6 +20,8 @@
*/
package de.monticore.lang.monticar.cnnarch.gluongenerator;
+import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionSourceGenerator;
+import de.se_rwth.commons.logging.Finding;
import de.se_rwth.commons.logging.Log;
import freemarker.template.TemplateException;
import org.junit.Before;
@@ -29,16 +31,21 @@ import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
+import java.util.stream.Collector;
+import java.util.stream.Collectors;
import static junit.framework.TestCase.assertTrue;
+import static org.mockito.Mockito.mock;
public class GenerationTest extends AbstractSymtabTest {
+ private RewardFunctionSourceGenerator rewardFunctionSourceGenerator;
@Before
public void setUp() {
// ensure an empty log
Log.getFindings().clear();
Log.enableFailQuick(false);
+ rewardFunctionSourceGenerator = mock(RewardFunctionSourceGenerator.class);
}
@Test
@@ -125,7 +132,7 @@ public class GenerationTest extends AbstractSymtabTest {
public void testFullCfgGeneration() throws IOException, TemplateException {
Log.getFindings().clear();
String sourcePath = "src/test/resources/valid_tests";
- CNNTrain2Gluon trainGenerator = new CNNTrain2Gluon();
+ CNNTrain2Gluon trainGenerator = new CNNTrain2Gluon(rewardFunctionSourceGenerator);
trainGenerator.generate(Paths.get(sourcePath), "FullConfig");
assertTrue(Log.getFindings().isEmpty());
@@ -141,7 +148,7 @@ public class GenerationTest extends AbstractSymtabTest {
public void testSimpleCfgGeneration() throws IOException {
Log.getFindings().clear();
Path modelPath = Paths.get("src/test/resources/valid_tests");
- CNNTrain2Gluon trainGenerator = new CNNTrain2Gluon();
+ CNNTrain2Gluon trainGenerator = new CNNTrain2Gluon(rewardFunctionSourceGenerator);
trainGenerator.generate(modelPath, "SimpleConfig");
@@ -158,7 +165,7 @@ public class GenerationTest extends AbstractSymtabTest {
public void testEmptyCfgGeneration() throws IOException {
Log.getFindings().clear();
Path modelPath = Paths.get("src/test/resources/valid_tests");
- CNNTrain2Gluon trainGenerator = new CNNTrain2Gluon();
+ CNNTrain2Gluon trainGenerator = new CNNTrain2Gluon(rewardFunctionSourceGenerator);
trainGenerator.generate(modelPath, "EmptyConfig");
assertTrue(Log.getFindings().isEmpty());
@@ -170,6 +177,31 @@ public class GenerationTest extends AbstractSymtabTest {
"supervised_trainer.py"));
}
+ @Test
+ public void testReinforcementConfig2() {
+ Log.getFindings().clear();
+ Path modelPath = Paths.get("src/test/resources/valid_tests");
+ CNNTrain2Gluon trainGenerator = new CNNTrain2Gluon(rewardFunctionSourceGenerator);
+
+ trainGenerator.generate(modelPath, "ReinforcementConfig2");
+
+ assertTrue(Log.getFindings().isEmpty());
+ checkFilesAreEqual(
+ Paths.get("./target/generated-sources-cnnarch"),
+ Paths.get("./src/test/resources/target_code/ReinforcementConfig2"),
+ Arrays.asList(
+ "CNNTrainer_reinforcementConfig2.py",
+ "start_training.sh",
+ "reinforcement_learning/__init__.py",
+ "reinforcement_learning/action_policy.py",
+ "reinforcement_learning/agent.py",
+ "reinforcement_learning/environment.py",
+ "reinforcement_learning/replay_memory.py",
+ "reinforcement_learning/util.py"
+ )
+ );
+ }
+
@Test
public void testCMakeGeneration() {
diff --git a/src/test/java/de/monticore/lang/monticar/cnnarch/gluongenerator/IntegrationPythonWrapperTest.java b/src/test/java/de/monticore/lang/monticar/cnnarch/gluongenerator/IntegrationPythonWrapperTest.java
new file mode 100644
index 0000000000000000000000000000000000000000..5456e75a11578f49eee2e0b2e4d6b3ad9f4345fc
--- /dev/null
+++ b/src/test/java/de/monticore/lang/monticar/cnnarch/gluongenerator/IntegrationPythonWrapperTest.java
@@ -0,0 +1,53 @@
+package de.monticore.lang.monticar.cnnarch.gluongenerator;
+
+import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionSourceGenerator;
+import de.se_rwth.commons.logging.Finding;
+import de.se_rwth.commons.logging.Log;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.Arrays;
+import java.util.stream.Collectors;
+
+import static junit.framework.TestCase.assertTrue;
+import static org.mockito.Mockito.mock;
+
+public class IntegrationPythonWrapperTest extends AbstractSymtabTest{
+ private RewardFunctionSourceGenerator rewardFunctionSourceGenerator;
+
+ @Before
+ public void setUp() {
+ // ensure an empty log
+ Log.getFindings().clear();
+ Log.enableFailQuick(false);
+ rewardFunctionSourceGenerator = mock(RewardFunctionSourceGenerator.class);
+ }
+
+ @Test
+ public void testReinforcementConfigWithRewardGeneration() {
+ Log.getFindings().clear();
+ Path modelPath = Paths.get("src/test/resources/valid_tests");
+ CNNTrain2Gluon trainGenerator = new CNNTrain2Gluon(rewardFunctionSourceGenerator);
+
+ trainGenerator.generate(modelPath, "ReinforcementConfig1");
+
+ assertTrue(Log.getFindings().stream().filter(Finding::isError).collect(Collectors.toList()).isEmpty());
+ checkFilesAreEqual(
+ Paths.get("./target/generated-sources-cnnarch"),
+ Paths.get("./src/test/resources/target_code/ReinforcementConfig1"),
+ Arrays.asList(
+ "CNNTrainer_reinforcementConfig1.py",
+ "start_training.sh",
+ "reinforcement_learning/__init__.py",
+ "reinforcement_learning/action_policy.py",
+ "reinforcement_learning/agent.py",
+ "reinforcement_learning/environment.py",
+ "reinforcement_learning/replay_memory.py",
+ "reinforcement_learning/util.py"
+ )
+ );
+ assertTrue(Paths.get("./target/generated-sources-cnnarch/reward/pylib").toFile().isDirectory());
+ }
+}
diff --git a/src/test/resources/target_code/CNNCreator_Alexnet.py b/src/test/resources/target_code/CNNCreator_Alexnet.py
index 1b2befae05313959c0c3ec2d13c212e671764bec..4f2158e26a0edcf6bea159738077d0c03815127e 100644
--- a/src/test/resources/target_code/CNNCreator_Alexnet.py
+++ b/src/test/resources/target_code/CNNCreator_Alexnet.py
@@ -6,12 +6,15 @@ from CNNNet_Alexnet import Net
class CNNCreator_Alexnet:
_model_dir_ = "model/Alexnet/"
_model_prefix_ = "model"
- _input_shapes_ = [(3,224,224)]
+ _input_shapes_ = [(3,224,224,)]
def __init__(self):
self.weight_initializer = mx.init.Normal()
self.net = None
+ def get_input_shapes(self):
+ return self._input_shapes_
+
def load(self, context):
lastEpoch = 0
param_file = None
diff --git a/src/test/resources/target_code/CNNCreator_CifarClassifierNetwork.py b/src/test/resources/target_code/CNNCreator_CifarClassifierNetwork.py
index 5d63b10b03d997f924460768cd29cb8b4c3a6937..a1b864217df9e5c89fa215f3d929049b05ca38b2 100644
--- a/src/test/resources/target_code/CNNCreator_CifarClassifierNetwork.py
+++ b/src/test/resources/target_code/CNNCreator_CifarClassifierNetwork.py
@@ -6,12 +6,15 @@ from CNNNet_CifarClassifierNetwork import Net
class CNNCreator_CifarClassifierNetwork:
_model_dir_ = "model/CifarClassifierNetwork/"
_model_prefix_ = "model"
- _input_shapes_ = [(3,32,32)]
+ _input_shapes_ = [(3,32,32,)]
def __init__(self):
self.weight_initializer = mx.init.Normal()
self.net = None
+ def get_input_shapes(self):
+ return self._input_shapes_
+
def load(self, context):
lastEpoch = 0
param_file = None
diff --git a/src/test/resources/target_code/CNNCreator_VGG16.py b/src/test/resources/target_code/CNNCreator_VGG16.py
index af1eed2b9ff4bbe36f422e4e1cde83f2685e75a0..940d926e603ccf99f679c377b1db7260e1da992e 100644
--- a/src/test/resources/target_code/CNNCreator_VGG16.py
+++ b/src/test/resources/target_code/CNNCreator_VGG16.py
@@ -6,12 +6,15 @@ from CNNNet_VGG16 import Net
class CNNCreator_VGG16:
_model_dir_ = "model/VGG16/"
_model_prefix_ = "model"
- _input_shapes_ = [(3,224,224)]
+ _input_shapes_ = [(3,224,224,)]
def __init__(self):
self.weight_initializer = mx.init.Normal()
self.net = None
+ def get_input_shapes(self):
+ return self._input_shapes_
+
def load(self, context):
lastEpoch = 0
param_file = None
diff --git a/src/test/resources/target_code/CNNPredictor_Alexnet.h b/src/test/resources/target_code/CNNPredictor_Alexnet.h
index 9f161f38d9d33d8169068e5212343658de7ce2af..a8f46d311dd98deab4afb8eae8b071faf90d929d 100644
--- a/src/test/resources/target_code/CNNPredictor_Alexnet.h
+++ b/src/test/resources/target_code/CNNPredictor_Alexnet.h
@@ -30,8 +30,7 @@ public:
void predict(const std::vector<float> &data,
std::vector<float> &predictions){
- MXPredSetInput(handle, "data", data.data(), data.size());
- //MXPredSetInput(handle, "data", data.data(), data.size());
+ MXPredSetInput(handle, "data", data.data(), static_cast<mx_uint>(data.size()));
MXPredForward(handle);
@@ -61,8 +60,6 @@ public:
int dev_type = use_gpu ? 2 : 1;
int dev_id = 0;
- handle = 0;
-
if (json_data.GetLength() == 0 ||
param_data.GetLength() == 0) {
std::exit(-1);
@@ -70,10 +67,8 @@ public:
const mx_uint num_input_nodes = input_keys.size();
- const char* input_keys_ptr[num_input_nodes];
- for(mx_uint i = 0; i < num_input_nodes; i++){
- input_keys_ptr[i] = input_keys[i].c_str();
- }
+ const char* input_key[1] = { "data" };
+ const char** input_keys_ptr = input_key;
mx_uint shape_data_size = 0;
mx_uint input_shape_indptr[input_shapes.size() + 1];
@@ -92,8 +87,8 @@ public:
}
}
- MXPredCreate((const char*)json_data.GetBuffer(),
- (const char*)param_data.GetBuffer(),
+ MXPredCreate(static_cast<const char*>(json_data.GetBuffer()),
+ static_cast<const char*>(param_data.GetBuffer()),
static_cast<int>(param_data.GetLength()),
dev_type,
dev_id,
diff --git a/src/test/resources/target_code/CNNPredictor_CifarClassifierNetwork.h b/src/test/resources/target_code/CNNPredictor_CifarClassifierNetwork.h
index ad0ed59d3f826f5fe0cf5db9d9944aa2a3373615..5a7851ecb2af863b17930f39f044f3b8820ec0c2 100644
--- a/src/test/resources/target_code/CNNPredictor_CifarClassifierNetwork.h
+++ b/src/test/resources/target_code/CNNPredictor_CifarClassifierNetwork.h
@@ -30,8 +30,7 @@ public:
void predict(const std::vector<float> &data,
std::vector<float> &softmax){
- MXPredSetInput(handle, "data", data.data(), data.size());
- //MXPredSetInput(handle, "data", data.data(), data.size());
+ MXPredSetInput(handle, "data", data.data(), static_cast<mx_uint>(data.size()));
MXPredForward(handle);
@@ -61,8 +60,6 @@ public:
int dev_type = use_gpu ? 2 : 1;
int dev_id = 0;
- handle = 0;
-
if (json_data.GetLength() == 0 ||
param_data.GetLength() == 0) {
std::exit(-1);
@@ -70,10 +67,8 @@ public:
const mx_uint num_input_nodes = input_keys.size();
- const char* input_keys_ptr[num_input_nodes];
- for(mx_uint i = 0; i < num_input_nodes; i++){
- input_keys_ptr[i] = input_keys[i].c_str();
- }
+ const char* input_key[1] = { "data" };
+ const char** input_keys_ptr = input_key;
mx_uint shape_data_size = 0;
mx_uint input_shape_indptr[input_shapes.size() + 1];
@@ -92,8 +87,8 @@ public:
}
}
- MXPredCreate((const char*)json_data.GetBuffer(),
- (const char*)param_data.GetBuffer(),
+ MXPredCreate(static_cast<const char*>(json_data.GetBuffer()),
+ static_cast<const char*>(param_data.GetBuffer()),
static_cast<int>(param_data.GetLength()),
dev_type,
dev_id,
diff --git a/src/test/resources/target_code/CNNPredictor_VGG16.h b/src/test/resources/target_code/CNNPredictor_VGG16.h
index f3487cb602c6e201f2ae5545de4d32e73c7042c2..080f2efe6448a224c18db3c846ec1646fba06cef 100644
--- a/src/test/resources/target_code/CNNPredictor_VGG16.h
+++ b/src/test/resources/target_code/CNNPredictor_VGG16.h
@@ -30,8 +30,7 @@ public:
void predict(const std::vector<float> &data,
std::vector<float> &predictions){
- MXPredSetInput(handle, "data", data.data(), data.size());
- //MXPredSetInput(handle, "data", data.data(), data.size());
+ MXPredSetInput(handle, "data", data.data(), static_cast<mx_uint>(data.size()));
MXPredForward(handle);
@@ -61,8 +60,6 @@ public:
int dev_type = use_gpu ? 2 : 1;
int dev_id = 0;
- handle = 0;
-
if (json_data.GetLength() == 0 ||
param_data.GetLength() == 0) {
std::exit(-1);
@@ -70,10 +67,8 @@ public:
const mx_uint num_input_nodes = input_keys.size();
- const char* input_keys_ptr[num_input_nodes];
- for(mx_uint i = 0; i < num_input_nodes; i++){
- input_keys_ptr[i] = input_keys[i].c_str();
- }
+ const char* input_key[1] = { "data" };
+ const char** input_keys_ptr = input_key;
mx_uint shape_data_size = 0;
mx_uint input_shape_indptr[input_shapes.size() + 1];
@@ -92,8 +87,8 @@ public:
}
}
- MXPredCreate((const char*)json_data.GetBuffer(),
- (const char*)param_data.GetBuffer(),
+ MXPredCreate(static_cast<const char*>(json_data.GetBuffer()),
+ static_cast<const char*>(param_data.GetBuffer()),
static_cast<int>(param_data.GetLength()),
dev_type,
dev_id,
diff --git a/src/test/resources/target_code/ReinforcementConfig1/CNNTrainer_reinforcementConfig1.py b/src/test/resources/target_code/ReinforcementConfig1/CNNTrainer_reinforcementConfig1.py
new file mode 100644
index 0000000000000000000000000000000000000000..09c905c4288f6b66f10e9fc51829bc9b9d5a96ef
--- /dev/null
+++ b/src/test/resources/target_code/ReinforcementConfig1/CNNTrainer_reinforcementConfig1.py
@@ -0,0 +1,101 @@
+from reinforcement_learning.agent import DqnAgent
+from reinforcement_learning.util import AgentSignalHandler
+import reinforcement_learning.environment
+import CNNCreator_reinforcementConfig1
+
+import os
+import sys
+import re
+import logging
+import mxnet as mx
+
+session_output_dir = 'session'
+agent_name='reinforcement_agent'
+session_param_output = os.path.join(session_output_dir, agent_name)
+
+def resume_session():
+ session_param_output = os.path.join(session_output_dir, agent_name)
+ resume_session = False
+ resume_directory = None
+ if os.path.isdir(session_output_dir) and os.path.isdir(session_param_output):
+ regex = re.compile(r'\d\d\d\d-\d\d-\d\d-\d\d-\d\d')
+ dir_content = os.listdir(session_param_output)
+ session_files = filter(regex.search, dir_content)
+ session_files.sort(reverse=True)
+ for d in session_files:
+ interrupted_session_dir = os.path.join(session_param_output, d, '.interrupted_session')
+ if os.path.isdir(interrupted_session_dir):
+ resume = raw_input('Interrupted session from {} found. Do you want to resume? (y/n) '.format(d))
+ if resume == 'y':
+ resume_session = True
+ resume_directory = interrupted_session_dir
+ break
+ return resume_session, resume_directory
+
+if __name__ == "__main__":
+ env_params = {
+ 'ros_node_name' : 'reinforcementConfig1TrainerNode',
+ 'state_topic' : '/environment/state',
+ 'action_topic' : '/environment/action',
+ 'reset_topic' : '/environment/reset',
+ }
+ env = reinforcement_learning.environment.RosEnvironment(**env_params)
+ context = mx.cpu()
+ net_creator = CNNCreator_reinforcementConfig1.CNNCreator_reinforcementConfig1()
+ net_creator.construct(context)
+
+ replay_memory_params = {
+ 'method':'buffer',
+ 'memory_size':1000000,
+ 'sample_size':64,
+ 'state_dtype':'float32',
+ 'action_dtype':'uint8',
+ 'rewards_dtype':'float32'
+ }
+
+ policy_params = {
+ 'method':'epsgreedy',
+ 'epsilon': 1,
+ 'min_epsilon': 0.02,
+ 'epsilon_decay_method': 'linear',
+ 'epsilon_decay': 0.0001,
+ }
+
+ resume_session, resume_directory = resume_session()
+
+ if resume_session:
+ agent = DqnAgent.resume_from_session(resume_directory, net_creator.net, env)
+ else:
+ agent = DqnAgent(
+ network = net_creator.net,
+ environment=env,
+ replay_memory_params=replay_memory_params,
+ policy_params=policy_params,
+ state_dim=net_creator.get_input_shapes()[0],
+ discount_factor=0.99999,
+ loss_function='huber_loss',
+ optimizer='adam',
+ optimizer_params={
+ 'learning_rate': 0.001 },
+ training_episodes=1000,
+ train_interval=1,
+ use_fix_target=True,
+ target_update_interval=500,
+ double_dqn = True,
+ snapshot_interval=500,
+ agent_name=agent_name,
+ max_episode_step=10000,
+ output_directory=session_output_dir,
+ verbose=True,
+ live_plot = True,
+ make_logfile=True,
+ target_score=35000
+ )
+
+ signal_handler = AgentSignalHandler()
+ signal_handler.register_agent(agent)
+
+ train_successful = agent.train()
+
+ if train_successful:
+ agent.save_best_network(net_creator._model_dir_ + net_creator._model_prefix_ + '_newest', epoch=0)
\ No newline at end of file
diff --git a/src/test/resources/target_code/ReinforcementConfig1/reinforcement_learning/__init__.py b/src/test/resources/target_code/ReinforcementConfig1/reinforcement_learning/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/test/resources/target_code/ReinforcementConfig1/reinforcement_learning/action_policy.py b/src/test/resources/target_code/ReinforcementConfig1/reinforcement_learning/action_policy.py
new file mode 100644
index 0000000000000000000000000000000000000000..f43a211fe353f5fcb95fb8dd38d7e412fd1d1ab4
--- /dev/null
+++ b/src/test/resources/target_code/ReinforcementConfig1/reinforcement_learning/action_policy.py
@@ -0,0 +1,73 @@
+import numpy as np
+
+class ActionPolicyBuilder(object):
+ def __init__(self):
+ pass
+
+ def build_by_params(self,
+ method='epsgreedy',
+ epsilon=0.5,
+ min_epsilon=0.05,
+ epsilon_decay_method='no',
+ epsilon_decay=0.0,
+ action_dim=None):
+
+ if epsilon_decay_method == 'linear':
+ decay = LinearDecay(eps_decay=epsilon_decay, min_eps=min_epsilon)
+ else:
+ decay = NoDecay()
+
+ if method == 'epsgreedy':
+ assert action_dim is not None
+ assert len(action_dim) == 1
+ return EpsilonGreedyActionPolicy(eps=epsilon,
+ number_of_actions=action_dim[0], decay=decay)
+ else:
+ assert action_dim is not None
+ assert len(action_dim) == 1
+ return GreedyActionPolicy()
+
+class EpsilonGreedyActionPolicy(object):
+ def __init__(self, eps, number_of_actions, decay):
+ self.eps = eps
+ self.cur_eps = eps
+ self.__number_of_actions = number_of_actions
+ self.__decay_method = decay
+
+ def select_action(self, values):
+ do_exploration = (np.random.rand() < self.cur_eps)
+ if do_exploration:
+ action = np.random.randint(low=0, high=self.__number_of_actions)
+ else:
+ action = values.asnumpy().argmax()
+ return action
+
+ def decay(self):
+ self.cur_eps = self.__decay_method.decay(self.cur_eps)
+
+
+class GreedyActionPolicy(object):
+ def __init__(self):
+ pass
+
+ def select_action(self, values):
+ return values.asnumpy().argmax()
+
+ def decay(self):
+ pass
+
+
+class NoDecay(object):
+ def __init__(self):
+ pass
+
+ def decay(self, cur_eps):
+ return cur_eps
+
+class LinearDecay(object):
+ def __init__(self, eps_decay, min_eps=0):
+ self.eps_decay = eps_decay
+ self.min_eps = min_eps
+
+ def decay(self, cur_eps):
+ return max(cur_eps - self.eps_decay, self.min_eps)
\ No newline at end of file
diff --git a/src/test/resources/target_code/ReinforcementConfig1/reinforcement_learning/agent.py b/src/test/resources/target_code/ReinforcementConfig1/reinforcement_learning/agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..acbc974d5295431ea6a88a562cff2fbd7d951066
--- /dev/null
+++ b/src/test/resources/target_code/ReinforcementConfig1/reinforcement_learning/agent.py
@@ -0,0 +1,506 @@
+import mxnet as mx
+import numpy as np
+import time
+import os
+import logging
+import sys
+import util
+import matplotlib.pyplot as plt
+from replay_memory import ReplayMemoryBuilder
+from action_policy import ActionPolicyBuilder
+from util import copy_net, get_loss_function
+from mxnet import nd, gluon, autograd
+
+class DqnAgent(object):
+ def __init__(self,
+ network,
+ environment,
+ replay_memory_params,
+ policy_params,
+ state_dim,
+ ctx=None,
+ discount_factor=.9,
+ loss_function='euclidean',
+ optimizer='rmsprop',
+ optimizer_params = {'learning_rate':0.09},
+ training_episodes=50,
+ train_interval=1,
+ use_fix_target=False,
+ double_dqn = False,
+ target_update_interval=10,
+ snapshot_interval=200,
+ agent_name='Dqn_agent',
+ max_episode_step=99999,
+ output_directory='model_parameters',
+ verbose=True,
+ live_plot = True,
+ make_logfile=True,
+ target_score=None):
+ assert 0 < discount_factor <= 1
+ assert train_interval > 0
+ assert target_update_interval > 0
+ assert snapshot_interval > 0
+ assert max_episode_step > 0
+ assert training_episodes > 0
+ assert replay_memory_params is not None
+ assert type(state_dim) is tuple
+
+ self.__ctx = mx.gpu() if ctx == 'gpu' else mx.cpu()
+ self.__qnet = network
+
+ self.__environment = environment
+ self.__discount_factor = discount_factor
+ self.__training_episodes = training_episodes
+ self.__train_interval = train_interval
+ self.__verbose = verbose
+ self.__state_dim = state_dim
+ self.__action_dim = self.__qnet(nd.random_normal(shape=((1,) + self.__state_dim), ctx=self.__ctx)).shape[1:]
+
+ replay_memory_params['state_dim'] = state_dim
+ self.__replay_memory_params = replay_memory_params
+ rm_builder = ReplayMemoryBuilder()
+ self.__memory = rm_builder.build_by_params(**replay_memory_params)
+ self.__minibatch_size = self.__memory.sample_size
+
+ policy_params['action_dim'] = self.__action_dim
+ self.__policy_params = policy_params
+ p_builder = ActionPolicyBuilder()
+ self.__policy = p_builder.build_by_params(**policy_params)
+
+ self.__target_update_interval = target_update_interval
+ self.__target_qnet = copy_net(self.__qnet, self.__state_dim, ctx=self.__ctx)
+ self.__loss_function_str = loss_function
+ self.__loss_function = get_loss_function(loss_function)
+ self.__agent_name = agent_name
+ self.__snapshot_interval = snapshot_interval
+ self.__creation_time = time.time()
+ self.__max_episode_step = max_episode_step
+ self.__optimizer = optimizer
+ self.__optimizer_params = optimizer_params
+ self.__make_logfile = make_logfile
+ self.__double_dqn = double_dqn
+ self.__use_fix_target = use_fix_target
+ self.__live_plot = live_plot
+ self.__user_given_directory = output_directory
+ self.__target_score = target_score
+
+ self.__interrupt_flag = False
+
+ # Training Context
+ self.__current_episode = 0
+ self.__total_steps = 0
+
+ # Initialize best network
+ self.__best_net = copy_net(self.__qnet, self.__state_dim, self.__ctx)
+ self.__best_avg_score = None
+
+ # Gluon Trainer definition
+ self.__training_stats = None
+
+ # Prepare output directory and logger
+ self.__output_directory = output_directory\
+ + '/' + self.__agent_name\
+ + '/' + time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(self.__creation_time))
+ self.__logger = self.__setup_logging()
+ self.__logger.info('Agent created with following parameters: {}'.format(self.__make_config_dict()))
+
+ @classmethod
+ def from_config_file(cls, network, environment, config_file_path, ctx=None):
+ import json
+ # Load config
+ with open(config_file_path, 'r') as config_file:
+ config_dict = json.load(config_file)
+ return cls(network, environment, ctx=ctx, **config_dict)
+
+ @classmethod
+ def resume_from_session(cls, session_dir, net, environment):
+ import pickle
+ if not os.path.exists(session_dir):
+ raise ValueError('Session directory does not exist')
+
+ files = dict()
+ files['agent'] = os.path.join(session_dir, 'agent.p')
+ files['best_net_params'] = os.path.join(session_dir, 'best_net.params')
+ files['q_net_params'] = os.path.join(session_dir, 'qnet.params')
+ files['target_net_params'] = os.path.join(session_dir, 'target_net.params')
+
+ for file in files.values():
+ if not os.path.exists(file):
+ raise ValueError('Session directory is not complete: {} is missing'.format(file))
+
+ with open(files['agent'], 'rb') as f:
+ agent = pickle.load(f)
+
+ agent.__environment = environment
+ agent.__qnet = net
+ agent.__qnet.load_parameters(files['q_net_params'], agent.__ctx)
+ agent.__qnet.hybridize()
+ agent.__qnet(nd.random_normal(shape=((1,) + agent.__state_dim), ctx=agent.__ctx))
+ agent.__best_net = copy_net(agent.__qnet, agent.__state_dim, agent.__ctx)
+ agent.__best_net.load_parameters(files['best_net_params'], agent.__ctx)
+ agent.__target_qnet = copy_net(agent.__qnet, agent.__state_dim, agent.__ctx)
+ agent.__target_qnet.load_parameters(files['target_net_params'], agent.__ctx)
+
+ agent.__logger = agent.__setup_logging(append=True)
+ agent.__training_stats.logger = agent.__logger
+ agent.__logger.info('Agent was retrieved; Training can be continued')
+
+ return agent
+
+ def __interrupt_training(self):
+ import pickle
+ self.__logger.info('Training interrupted; Store state for resuming')
+ session_dir = os.path.join(self.__output_directory, '.interrupted_session')
+ if not os.path.exists(session_dir):
+ os.mkdir(session_dir)
+
+ del self.__training_stats.logger
+ logger = self.__logger
+ self.__logger = None
+ self.__environment.close()
+ self.__environment = None
+
+ self.__save_net(self.__qnet, 'qnet', session_dir)
+ self.__qnet = None
+ self.__save_net(self.__best_net, 'best_net', session_dir)
+ self.__best_net = None
+ self.__save_net(self.__target_qnet, 'target_net', session_dir)
+ self.__target_qnet = None
+
+ agent_session_file = os.path.join(session_dir, 'agent.p')
+
+ with open(agent_session_file, 'wb') as f:
+ pickle.dump(self, f)
+ self.__logger = logger
+ logger.info('State successfully stored')
+
+ @property
+ def current_episode(self):
+ return self.__current_episode
+
+ @property
+ def environment(self):
+ return self.__environment
+
+ def __adjust_optimizer_params(self, optimizer_params):
+ if 'weight_decay' in optimizer_params:
+ optimizer_params['wd'] = optimizer_params['weight_decay']
+ del optimizer_params['weight_decay']
+ if 'learning_rate_decay' in optimizer_params:
+ min_learning_rate = 1e-8
+ if 'learning_rate_minimum' in optimizer_params:
+ min_learning_rate = optimizer_params['learning_rate_minimum']
+ del optimizer_params['learning_rate_minimum']
+ optimizer_params['lr_scheduler'] = mx.lr_scheduler.FactorScheduler(
+ optimizer_params['step_size'],
+ factor=optimizer_params['learning_rate_decay'],
+ stop_factor_lr=min_learning_rate)
+ del optimizer_params['step_size']
+ del optimizer_params['learning_rate_decay']
+
+ return optimizer_params
+
+ def set_interrupt_flag(self, interrupt):
+ self.__interrupt_flag = interrupt
+
+
+ def __make_output_directory_if_not_exist(self):
+ assert self.__output_directory
+ if not os.path.exists(self.__output_directory):
+ os.makedirs(self.__output_directory)
+
+ def __setup_logging(self, append=False):
+ assert self.__output_directory
+ assert self.__agent_name
+
+ output_level = logging.DEBUG if self.__verbose else logging.WARNING
+ filemode = 'a' if append else 'w'
+
+ logformat = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+ dateformat = '%d-%b-%y %H:%M:%S'
+ formatter = logging.Formatter(fmt=logformat, datefmt=dateformat)
+
+ logger = logging.getLogger('DQNAgent')
+ logger.setLevel(output_level)
+
+ stream_handler = logging.StreamHandler(sys.stdout)
+ stream_handler.setLevel(output_level)
+ stream_handler.setFormatter(formatter)
+ logger.addHandler(stream_handler)
+
+ if self.__make_logfile:
+ self.__make_output_directory_if_not_exist()
+ log_file = os.path.join(self.__output_directory, self.__agent_name + '.log')
+ file_handler = logging.FileHandler(log_file, mode=filemode)
+ file_handler.setLevel(output_level)
+ file_handler.setFormatter(formatter)
+ logger.addHandler(file_handler)
+
+ return logger
+
+ def __is_target_reached(self, avg_reward):
+ return self.__target_score is not None\
+ and avg_reward > self.__target_score
+
+
+ def get_q_values(self, state, with_best=False):
+ return self.get_batch_q_values(nd.array([state], ctx=self.__ctx), with_best=with_best)[0]
+
+ def get_batch_q_values(self, state_batch, with_best=False):
+ return self.__best_net(state_batch) if with_best else self.__qnet(state_batch)
+
+ def get_next_action(self, state, with_best=False):
+ q_values = self.get_q_values(state, with_best=with_best)
+ return q_values.asnumpy().argmax()
+
+ def __sample_from_memory(self):
+ states, actions, rewards, next_states, terminals\
+ = self.__memory.sample(batch_size=self.__minibatch_size)
+ states = nd.array(states, ctx=self.__ctx)
+ actions = nd.array(actions, ctx=self.__ctx)
+ rewards = nd.array(rewards, ctx=self.__ctx)
+ next_states = nd.array(next_states, ctx=self.__ctx)
+ terminals = nd.array(terminals, ctx=self.__ctx)
+ return states, actions, rewards, next_states, terminals
+
+ def __determine_target_q_values(self, states, actions, rewards, next_states, terminals):
+ if self.__use_fix_target:
+ q_max_val = self.__target_qnet(next_states)
+ else:
+ q_max_val = self.__qnet(next_states)
+
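+ # Double DQN (van Hasselt et al.): the online network selects the best next action,
+ # while q_max_val (the target network when fixed targets are enabled) evaluates it.
+ # Otherwise both selection and evaluation use q_max_val directly.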
+ if self.__double_dqn:
+ q_values_next_states = self.__qnet(next_states)
+ target_rewards = rewards + nd.choose_element_0index(q_max_val, nd.argmax_channel(q_values_next_states))\
+ * (1.0 - terminals) * self.__discount_factor
+ else:
+ target_rewards = rewards + nd.choose_element_0index(q_max_val, nd.argmax_channel(q_max_val))\
+ * (1.0 - terminals) * self.__discount_factor
+
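+ # Start from the online network's current predictions and overwrite only the entries
+ # of the actions that were actually taken, so all other actions contribute zero loss.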
+ target_qval = self.__qnet(states)
+ for t in range(target_rewards.shape[0]):
+ target_qval[t][actions[t]] = target_rewards[t]
+
+ return target_qval
+
+ def __train_q_net_step(self, trainer):
+ states, actions, rewards, next_states, terminals = self.__sample_from_memory()
+ target_qval = self.__determine_target_q_values(states, actions, rewards, next_states, terminals)
+ with autograd.record():
+ q_values = self.__qnet(states)
+ loss = self.__loss_function(q_values, target_qval)
+ loss.backward()
+ trainer.step(self.__minibatch_size)
+ return loss
+
+ def __do_snapshot_if_in_interval(self, episode):
+ do_snapshot = (episode != 0 and (episode % self.__snapshot_interval == 0))
+ if do_snapshot:
+ self.save_parameters(episode=episode)
+ self.__evaluate()
+
+ def __do_target_update_if_in_interval(self, total_steps):
+ do_target_update = (self.__use_fix_target and total_steps % self.__target_update_interval == 0)
+ if do_target_update:
+ self.__logger.info('Target network is updated after {} steps'.format(total_steps))
+ self.__target_qnet = copy_net(self.__qnet, self.__state_dim, self.__ctx)
+
+ def train(self, episodes=None):
+ self.__logger.info("--- Start training ---")
+ trainer = gluon.Trainer(self.__qnet.collect_params(), self.__optimizer, self.__adjust_optimizer_params(self.__optimizer_params))
+ episodes = episodes if episodes is not None else self.__training_episodes
+
+ resume = (self.__current_episode > 0)
+ if resume:
+ self.__logger.info("Training session resumed")
+ self.__logger.info("Starting from episode {}".format(self.__current_episode))
+ else:
+ self.__training_stats = util.TrainingStats(self.__logger, episodes, self.__live_plot)
+
+ # Implementation of Deep Q-Learning as described by Mnih et al. in "Playing Atari with Deep Reinforcement Learning"
+ while self.__current_episode < episodes:
+ # Check interrupt flag
+ if self.__interrupt_flag:
+ self.__interrupt_flag = False
+ self.__interrupt_training()
+ return False
+
+ step = 0
+ episode_reward = 0
+ start = time.time()
+ state = self.__environment.reset()
+ episode_loss = 0
+ training_steps = 0
+ while step < self.__max_episode_step:
+ #1. Choose an action based on current game state and policy
+ q_values = self.__qnet(nd.array([state], ctx=self.__ctx))
+ action = self.__policy.select_action(q_values[0])
+
+ #2. Play the game for a single step
+ next_state, reward, terminal, _ = self.__environment.step(action)
+
+ #3. Store transition in replay memory
+ self.__memory.append(state, action, reward, next_state, terminal)
+
+ #4. Train the network if in interval
+ do_training = (self.__total_steps % self.__train_interval == 0\
+ and self.__memory.is_sample_possible(self.__minibatch_size))
+ if do_training:
+ loss = self.__train_q_net_step(trainer)
+ loss_sum = sum(loss).asnumpy()[0]
+ episode_loss += float(loss_sum)/float(self.__minibatch_size)
+ training_steps += 1
+
+ # Update target network if in interval
+ self.__do_target_update_if_in_interval(self.__total_steps)
+
+ step += 1
+ self.__total_steps += 1
+ episode_reward += reward
+ state = next_state
+
+ if terminal:
+ episode_loss = episode_loss if training_steps > 0 else None
+ _, _, avg_reward = self.__training_stats.log_episode(self.__current_episode, start, training_steps,
+ episode_loss, self.__policy.cur_eps, episode_reward)
+ break
+
+ self.__do_snapshot_if_in_interval(self.__current_episode)
+ self.__policy.decay()
+
+ if self.__is_target_reached(avg_reward):
+ self.__logger.info('Average target score reached; stopping training')
+ break
+
+ self.__current_episode += 1
+
+ self.__evaluate()
+ training_stats_file = os.path.join(self.__output_directory, 'training_stats.pdf')
+ self.__training_stats.save_stats(training_stats_file)
+ self.__logger.info('--------- Training finished ---------')
+ return True
+
+ def __save_net(self, net, filename, filedir=None):
+ filedir = self.__output_directory if filedir is None else filedir
+ filename = os.path.join(filedir, filename + '.params')
+ net.save_parameters(filename)
+
+
+ def save_parameters(self, episode=None, filename='dqn-agent-params'):
+ assert self.__output_directory
+ self.__make_output_directory_if_not_exist()
+
+ if episode is not None:
+ self.__logger.info('Saving model parameters after episode %d' % episode)
+ filename = filename + '-ep{}'.format(episode)
+ else:
+ self.__logger.info('Saving model parameters')
+ self.__save_net(self.__qnet, filename)
+
+ def evaluate(self, target=None, sample_games=100, verbose=True):
+ target = self.__target_score if target is None else target
+ if target:
+ target_achieved = 0
+ total_reward = 0
+
+ for g in range(sample_games):
+ state = self.__environment.reset()
+ step = 0
+ game_reward = 0
+ while step < self.__max_episode_step:
+ action = self.get_next_action(state)
+ state, reward, terminal, _ = self.__environment.step(action)
+ game_reward += reward
+
+ if terminal:
+ if verbose:
+ info = 'Game %d: Reward %f' % (g,game_reward)
+ self.__logger.debug(info)
+ if target:
+ if game_reward >= target:
+ target_achieved += 1
+ total_reward += game_reward
+ break
+
+ step += 1
+
+ avg_reward = float(total_reward)/float(sample_games)
+ info = 'Avg. Reward: %f' % avg_reward
+ if target:
+ target_achieved_ratio = int((float(target_achieved)/float(sample_games))*100)
+ info += '; Target Achieved in %d%% of games' % (target_achieved_ratio)
+
+ if verbose:
+ self.__logger.info(info)
+ return avg_reward
+
+ def __evaluate(self, verbose=True):
+ sample_games = 100
+ avg_reward = self.evaluate(sample_games=sample_games, verbose=False)
+ info = 'Evaluation -> Average Reward in {} games: {}'.format(sample_games, avg_reward)
+
+ if self.__best_avg_score is None or self.__best_avg_score <= avg_reward:
+ self.__best_net = copy_net(self.__qnet, self.__state_dim, self.__ctx)
+ self.__best_avg_score = avg_reward
+ info += ' (NEW BEST)'
+
+ if verbose:
+ self.__logger.info(info)
+
+
+
+ def play(self, update_frame=1, with_best=False):
+ step = 0
+ state = self.__environment.reset()
+ total_reward = 0
+ while step < self.__max_episode_step:
+ action = self.get_next_action(state, with_best=with_best)
+ state, reward, terminal, _ = self.__environment.step(action)
+ total_reward += reward
+ do_update_frame = (step % update_frame == 0)
+ if do_update_frame:
+ self.__environment.render()
+ time.sleep(.100)
+
+ if terminal:
+ break
+
+ step += 1
+ return total_reward
+
+ def save_best_network(self, path, epoch=0):
+ self.__logger.info('Saving best network with average reward of {}'.format(self.__best_avg_score))
+ self.__best_net.export(path, epoch=epoch)
+
+ def __make_config_dict(self):
+ config = dict()
+ config['discount_factor'] = self.__discount_factor
+ config['optimizer'] = self.__optimizer
+ config['optimizer_params'] = self.__optimizer_params
+ config['policy_params'] = self.__policy_params
+ config['replay_memory_params'] = self.__replay_memory_params
+ config['loss_function'] = self.__loss_function_str
+        config['training_episodes'] = self.__training_episodes
+        config['train_interval'] = self.__train_interval
+        config['use_fix_target'] = self.__use_fix_target
+        config['double_dqn'] = self.__double_dqn
+        config['target_update_interval'] = self.__target_update_interval
+        config['snapshot_interval'] = self.__snapshot_interval
+ config['agent_name'] = self.__agent_name
+ config['max_episode_step'] = self.__max_episode_step
+ config['output_directory'] = self.__user_given_directory
+ config['verbose'] = self.__verbose
+ config['live_plot'] = self.__live_plot
+ config['make_logfile'] = self.__make_logfile
+ config['target_score'] = self.__target_score
+ return config
+
+ def save_config_file(self):
+ import json
+ self.__make_output_directory_if_not_exist()
+ filename = os.path.join(self.__output_directory, 'config.json')
+ config = self.__make_config_dict()
+ with open(filename, mode='w') as fp:
+ json.dump(config, fp, indent=4)
\ No newline at end of file
diff --git a/src/test/resources/target_code/ReinforcementConfig1/reinforcement_learning/environment.py b/src/test/resources/target_code/ReinforcementConfig1/reinforcement_learning/environment.py
new file mode 100644
index 0000000000000000000000000000000000000000..405f1b0c2312a2fbcc01a179e61d31725caec967
--- /dev/null
+++ b/src/test/resources/target_code/ReinforcementConfig1/reinforcement_learning/environment.py
@@ -0,0 +1,144 @@
+import abc
+import logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+import reward_rewardFunction_executor
+
+class RewardFunction(object):
+ def __init__(self):
+ self.__reward_wrapper = reward_rewardFunction_executor.reward_rewardFunction_executor()
+ self.__reward_wrapper.init()
+
+ def reward(self, state, terminal):
+ inp = reward_rewardFunction_executor.reward_rewardFunction_input()
+ inp.state = state
+ inp.isTerminal = terminal
+ output = self.__reward_wrapper.execute(inp)
+ return output.reward
+
+
+
+class Environment:
+ __metaclass__ = abc.ABCMeta
+
+ def __init__(self):
+ self._reward_function = RewardFunction()
+
+ @abc.abstractmethod
+ def reset(self):
+ pass
+
+ @abc.abstractmethod
+ def step(self, action):
+ pass
+
+ @abc.abstractmethod
+ def close(self):
+ pass
+
+import rospy
+import thread
+import numpy as np
+import time
+from std_msgs.msg import Float32MultiArray, Bool, Int32
+
+class RosEnvironment(Environment):
+ def __init__(self,
+ ros_node_name='RosTrainingAgent',
+ timeout_in_s=3,
+ state_topic='state',
+ action_topic='action',
+ reset_topic='reset',
+ terminal_state_topic='terminal',
+ meta_topic='meta',
+ greeting_topic='greeting'):
+ super(RosEnvironment, self).__init__()
+ self.__timeout_in_s = timeout_in_s
+
+ self.__waiting_for_state_update = False
+ self.__waiting_for_terminal_update = False
+ self.__last_received_state = 0
+ self.__last_received_terminal = 0
+
+ rospy.loginfo("Initialize node {0}".format(ros_node_name))
+
+ self.__step_publisher = rospy.Publisher(action_topic, Int32, queue_size=1)
+ rospy.loginfo('Step Publisher initialized with topic {}'.format(action_topic))
+
+ self.__reset_publisher = rospy.Publisher(reset_topic, Bool, queue_size=1)
+ rospy.loginfo('Reset Publisher initialized with topic {}'.format(reset_topic))
+
+ rospy.init_node(ros_node_name, anonymous=True)
+
+ self.__state_subscriber = rospy.Subscriber(state_topic, Float32MultiArray, self.__state_callback)
+ rospy.loginfo('State Subscriber registered with topic {}'.format(state_topic))
+
+ self.__terminal_state_subscriber = rospy.Subscriber(terminal_state_topic, Bool, self.__terminal_state_callback)
+ rospy.loginfo('Terminal State Subscriber registered with topic {}'.format(terminal_state_topic))
+
+ rate = rospy.Rate(10)
+
+ thread.start_new_thread(rospy.spin, ())
+ time.sleep(2)
+
+ def reset(self):
+ time.sleep(0.5)
+ reset_message = Bool()
+ reset_message.data = True
+ self.__waiting_for_state_update = True
+ self.__reset_publisher.publish(reset_message)
+ while self.__last_received_terminal:
+ self.__wait_for_new_state(self.__reset_publisher, reset_message)
+ return self.__last_received_state
+
+ def step(self, action):
+ action_rospy = Int32()
+ action_rospy.data = action
+
+ logger.debug('Send action: {}'.format(action))
+
+ self.__waiting_for_state_update = True
+ self.__waiting_for_terminal_update = True
+ self.__step_publisher.publish(action_rospy)
+ self.__wait_for_new_state(self.__step_publisher, action_rospy)
+ next_state = self.__last_received_state
+ terminal = self.__last_received_terminal
+ reward = self.__calc_reward(next_state, terminal)
+ rospy.logdebug('Calculated reward: {}'.format(reward))
+
+ return next_state, reward, terminal, 0
+
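+    # Note: __wait_for_new_state implements a simple publish/acknowledge handshake.
+    # The given message is re-published until the state and terminal callbacks have
+    # fired; after three consecutive timeouts the application is terminated.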
+ def __wait_for_new_state(self, publisher, msg):
+ time_of_timeout = time.time() + self.__timeout_in_s
+ timeout_counter = 0
+ while(self.__waiting_for_state_update or self.__waiting_for_terminal_update):
+ is_timeout = (time.time() > time_of_timeout)
+ if (is_timeout):
+ if timeout_counter < 3:
+ rospy.logwarn("Timeout occured: Retry message")
+ publisher.publish(msg)
+ timeout_counter += 1
+ time_of_timeout = time.time() + self.__timeout_in_s
+ else:
+ rospy.logerr("Timeout 3 times in a row: Terminate application")
+ exit()
+            time.sleep(0.1)  # 100 ms; 100/1000 evaluates to 0 under Python 2 integer division
+
+ def close(self):
+ rospy.signal_shutdown('Program ended!')
+
+
+ def __state_callback(self, data):
+ self.__last_received_state = np.array(data.data, dtype='double')
+ rospy.logdebug('Received state: {}'.format(self.__last_received_state))
+ self.__waiting_for_state_update = False
+
+ def __terminal_state_callback(self, data):
+ self.__last_received_terminal = data.data
+ rospy.logdebug('Received terminal flag: {}'.format(self.__last_received_terminal))
+ logger.debug('Received terminal: {}'.format(self.__last_received_terminal))
+ self.__waiting_for_terminal_update = False
+
+ def __calc_reward(self, state, terminal):
+ # C++ Wrapper call
+ return self._reward_function.reward(state, terminal)
diff --git a/src/test/resources/target_code/ReinforcementConfig1/reinforcement_learning/replay_memory.py b/src/test/resources/target_code/ReinforcementConfig1/reinforcement_learning/replay_memory.py
new file mode 100644
index 0000000000000000000000000000000000000000..e66cd9350cab02144f994cab706249ef5a1e4288
--- /dev/null
+++ b/src/test/resources/target_code/ReinforcementConfig1/reinforcement_learning/replay_memory.py
@@ -0,0 +1,155 @@
+import numpy as np
+
+class ReplayMemoryBuilder(object):
+ def __init__(self):
+ self.__supported_methods = ['online', 'buffer', 'combined']
+
+ def build_by_params(self,
+ state_dim,
+ method='online',
+ state_dtype='float32',
+ action_dtype='uint8',
+ rewards_dtype='float32',
+ memory_size=1000,
+ sample_size=32):
+ assert state_dim is not None
+ assert method in self.__supported_methods
+
+ if method == 'online':
+ return self.build_online_memory(state_dim=state_dim, state_dtype=state_dtype,
+ action_dtype=action_dtype, rewards_dtype=rewards_dtype)
+ else:
+ assert memory_size is not None and memory_size > 0
+ assert sample_size is not None and sample_size > 0
+ if method == 'buffer':
+ return self.build_buffered_memory(state_dim=state_dim, sample_size=sample_size,
+ memory_size=memory_size, state_dtype=state_dtype, action_dtype=action_dtype,
+ rewards_dtype=rewards_dtype)
+ else:
+ return self.build_combined_memory(state_dim=state_dim, sample_size=sample_size,
+ memory_size=memory_size, state_dtype=state_dtype, action_dtype=action_dtype,
+ rewards_dtype=rewards_dtype)
+
+ def build_buffered_memory(self, state_dim, memory_size=1000, sample_size=1, state_dtype='float32',
+ action_dtype='uint8', rewards_dtype='float32'):
+ assert memory_size > 0
+ assert sample_size > 0
+ return ReplayMemory(state_dim, size=memory_size, sample_size=sample_size,
+ state_dtype=state_dtype, action_dtype=action_dtype, rewards_dtype=rewards_dtype)
+
+ def build_combined_memory(self, state_dim, memory_size=1000, sample_size=1, state_dtype='float32',
+ action_dtype='uint8', rewards_dtype='float32'):
+ assert memory_size > 0
+ assert sample_size > 0
+ return CombinedReplayMemory(state_dim, size=memory_size, sample_size=sample_size,
+ state_dtype=state_dtype, action_dtype=action_dtype, rewards_dtype=rewards_dtype)
+
+ def build_online_memory(self, state_dim, state_dtype='float32', action_dtype='uint8',
+ rewards_dtype='float32'):
+ return OnlineReplayMemory(state_dim, state_dtype=state_dtype, action_dtype=action_dtype,
+ rewards_dtype=rewards_dtype)
+
+class ReplayMemory(object):
+ def __init__(self, state_dim, sample_size, size=1000, state_dtype='uint8', action_dtype='uint8', rewards_dtype='float32'):
+ assert size > 0, "Size must be greater than zero"
+ assert type(state_dim) is tuple, "State dimension must be a tuple"
+ assert sample_size > 0
+ self._size = size
+ self._sample_size = sample_size
+ self._cur_size = 0
+ self._pointer = 0
+ self._state_dim = state_dim
+ self._state_dtype = state_dtype
+ self._action_dtype = action_dtype
+ self._rewards_dtype = rewards_dtype
+ self._states = np.zeros((self._size,) + state_dim, dtype=state_dtype)
+ self._actions = np.array([0] * self._size, dtype=action_dtype)
+ self._rewards = np.array([0] * self._size, dtype=rewards_dtype)
+ self._next_states = np.zeros((self._size,) + state_dim, dtype=state_dtype)
+ self._terminals = np.array([0] * self._size, dtype='bool')
+
+ @property
+ def sample_size(self):
+ return self._sample_size
+
+ def append(self, state, action, reward, next_state, terminal):
+ self._states[self._pointer] = state
+ self._actions[self._pointer] = action
+ self._rewards[self._pointer] = reward
+ self._next_states[self._pointer] = next_state
+ self._terminals[self._pointer] = terminal
+
+ self._pointer = self._pointer + 1
+ if self._pointer == self._size:
+ self._pointer = 0
+
+ self._cur_size = min(self._size, self._cur_size + 1)
+
+ def at(self, index):
+ return self._states[index],\
+ self._actions[index],\
+ self._rewards[index],\
+ self._next_states[index],\
+ self._terminals[index]
+
+ def is_sample_possible(self, batch_size=None):
+ batch_size = batch_size if batch_size is not None else self._sample_size
+ return self._cur_size >= batch_size
+
+ def sample(self, batch_size=None):
+ batch_size = batch_size if batch_size is not None else self._sample_size
+        assert self._cur_size >= batch_size, "Replay memory must contain at least batch_size transitions"
+        i = 0
+ states = np.zeros((batch_size,)+self._state_dim, dtype=self._state_dtype)
+ actions = np.zeros(batch_size, dtype=self._action_dtype)
+ rewards = np.zeros(batch_size, dtype=self._rewards_dtype)
+ next_states = np.zeros((batch_size,)+self._state_dim, dtype=self._state_dtype)
+ terminals = np.zeros(batch_size, dtype='bool')
+
+ while i < batch_size:
+ rnd_index = np.random.randint(low=0, high=self._cur_size)
+ states[i] = self._states.take(rnd_index, axis=0)
+ actions[i] = self._actions.take(rnd_index, axis=0)
+ rewards[i] = self._rewards.take(rnd_index, axis=0)
+ next_states[i] = self._next_states.take(rnd_index, axis=0)
+ terminals[i] = self._terminals.take(rnd_index, axis=0)
+ i += 1
+
+ return states, actions, rewards, next_states, terminals
+
+
+class OnlineReplayMemory(ReplayMemory):
+ def __init__(self, state_dim, state_dtype='float32', action_dtype='uint8', rewards_dtype='float32'):
+ super(OnlineReplayMemory, self).__init__(state_dim, sample_size=1, size=1,
+ state_dtype=state_dtype, action_dtype=action_dtype, rewards_dtype=rewards_dtype)
+
+
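+# CombinedReplayMemory implements combined experience replay: each sampled batch
+# consists of (sample_size - 1) uniformly drawn transitions plus the transition
+# stored most recently, which is kept in the _last_* fields below.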
+class CombinedReplayMemory(ReplayMemory):
+ def __init__(self, state_dim, sample_size, size=1000,
+ state_dtype='uint8', action_dtype='uint8', rewards_dtype='float32'):
+ super(CombinedReplayMemory, self).__init__(state_dim, sample_size=(sample_size - 1), size=size,
+ state_dtype=state_dtype, action_dtype=action_dtype, rewards_dtype=rewards_dtype)
+
+ self._last_state = np.zeros((1,) + state_dim, dtype=state_dtype)
+ self._last_action = np.array([0], dtype=action_dtype)
+ self._last_reward = np.array([0], dtype=rewards_dtype)
+ self._last_next_state = np.zeros((1,) + state_dim, dtype=state_dtype)
+ self._last_terminal = np.array([0], dtype='bool')
+
+ def append(self, state, action, reward, next_state, terminal):
+ super(CombinedReplayMemory, self).append(state, action, reward, next_state, terminal)
+ self._last_state = state
+ self._last_action = action
+ self._last_reward = reward
+ self._last_next_state = next_state
+ self._last_terminal = terminal
+
+ def sample(self, batch_size=None):
+ batch_size = (batch_size-1) if batch_size is not None else self._sample_size
+ states, actions, rewards, next_states, terminals = super(CombinedReplayMemory, self).sample(batch_size=batch_size)
+ states = np.append(states, [self._last_state], axis=0)
+ actions = np.append(actions, [self._last_action], axis=0)
+ rewards = np.append(rewards, [self._last_reward], axis=0)
+ next_states = np.append(next_states, [self._last_next_state], axis=0)
+ terminals = np.append(terminals, [self._last_terminal], axis=0)
+ return states, actions, rewards, next_states, terminals
\ No newline at end of file
diff --git a/src/test/resources/target_code/ReinforcementConfig1/reinforcement_learning/util.py b/src/test/resources/target_code/ReinforcementConfig1/reinforcement_learning/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..58578932a13f7d12f1aa840d06af1951bebc3f0f
--- /dev/null
+++ b/src/test/resources/target_code/ReinforcementConfig1/reinforcement_learning/util.py
@@ -0,0 +1,140 @@
+import signal
+import sys
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib import style
+import time
+import os
+import mxnet
+from mxnet import gluon, nd
+
+
+LOSS_FUNCTIONS = {
+ 'l1': gluon.loss.L1Loss(),
+ 'euclidean': gluon.loss.L2Loss(),
+ 'huber_loss': gluon.loss.HuberLoss(),
+ 'softmax_cross_entropy': gluon.loss.SoftmaxCrossEntropyLoss(),
+ 'sigmoid_cross_entropy': gluon.loss.SigmoidBinaryCrossEntropyLoss()}
+
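+# copy_net produces an independent copy of a Gluon network by saving its
+# parameters to a temporary file and loading them into a fresh instance of the
+# same class; the dummy forward pass forces shape initialization of the copy.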
+def copy_net(net, input_state_dim, ctx, tmp_filename='tmp.params'):
+ assert isinstance(net, gluon.HybridBlock)
+ assert type(net.__class__) is type
+ net.save_parameters(tmp_filename)
+ net2 = net.__class__()
+ net2.load_parameters(tmp_filename, ctx=ctx)
+ os.remove(tmp_filename)
+ net2.hybridize()
+ net2(nd.ones((1,) + input_state_dim, ctx=ctx))
+ return net2
+
+def get_loss_function(loss_function_name):
+ if loss_function_name not in LOSS_FUNCTIONS:
+ raise ValueError('Loss function does not exist')
+ return LOSS_FUNCTIONS[loss_function_name]
+
+
+class AgentSignalHandler(object):
+ def __init__(self):
+ signal.signal(signal.SIGINT, self.interrupt_training)
+ self.__agent = None
+ self.__times_interrupted = 0
+
+ def register_agent(self, agent):
+ self.__agent = agent
+
+ def interrupt_training(self, sig, frame):
+ self.__times_interrupted = self.__times_interrupted + 1
+ if self.__times_interrupted <= 3:
+ if self.__agent:
+ self.__agent.set_interrupt_flag(True)
+ else:
+ print('Interrupt called three times: Force quit')
+ sys.exit(1)
+
+style.use('fivethirtyeight')
+class TrainingStats(object):
+ def __init__(self, logger, max_episodes, live_plot=True):
+ self.__logger = logger
+ self.__max_episodes = max_episodes
+ self.__all_avg_loss = np.zeros((max_episodes,))
+ self.__all_total_rewards = np.zeros((max_episodes,))
+ self.__all_eps = np.zeros((max_episodes,))
+ self.__all_time = np.zeros((max_episodes,))
+ self.__all_mean_reward_last_100_episodes = np.zeros((max_episodes,))
+ self.__live_plot = live_plot
+
+ @property
+ def logger(self):
+ return self.__logger
+
+ @logger.setter
+ def logger(self, logger):
+ self.__logger = logger
+
+ @logger.deleter
+ def logger(self):
+ self.__logger = None
+
+ def add_avg_loss(self, episode, avg_loss):
+ self.__all_avg_loss[episode] = avg_loss
+
+ def add_total_reward(self, episode, total_reward):
+ self.__all_total_rewards[episode] = total_reward
+
+ def add_eps(self, episode, eps):
+ self.__all_eps[episode] = eps
+
+ def add_time(self, episode, time):
+ self.__all_time[episode] = time
+
+ def add_mean_reward_last_100(self, episode, mean_reward):
+ self.__all_mean_reward_last_100_episodes[episode] = mean_reward
+
+ def log_episode(self, episode, start_time, training_steps, loss, eps, reward):
+ self.add_eps(episode, eps)
+ self.add_total_reward(episode, reward)
+ end = time.time()
+ if training_steps == 0:
+ avg_loss = 0
+ else:
+ avg_loss = float(loss)/float(training_steps)
+
+ mean_reward_last_100 = self.mean_of_reward(episode, last=100)
+
+ time_elapsed = end - start_time
+ info = "Episode: %d, Total Reward: %.3f, Avg. Reward Last 100 Episodes: %.3f, Avg Loss: %.3f, Time: %.3f, Training Steps: %d, Eps: %.3f"\
+ % (episode, reward, mean_reward_last_100, avg_loss, time_elapsed, training_steps, eps)
+ self.__logger.info(info)
+ self.add_avg_loss(episode, avg_loss)
+ self.add_time(episode, time_elapsed)
+ self.add_mean_reward_last_100(episode, mean_reward_last_100)
+
+ return avg_loss, time_elapsed, mean_reward_last_100
+
+ def mean_of_reward(self, cur_episode, last=100):
+ if cur_episode > 0:
+ reward_last_100 = self.__all_total_rewards[max(0, cur_episode-last):cur_episode]
+ return np.mean(reward_last_100)
+ else:
+ return self.__all_total_rewards[0]
+
+ def save_stats(self, file):
+ fig = plt.figure(figsize=(20,20))
+
+ sub_rewards = fig.add_subplot(221)
+ sub_rewards.set_title('Total Rewards per episode')
+ sub_rewards.plot(np.arange(self.__max_episodes), self.__all_total_rewards)
+
+ sub_loss = fig.add_subplot(222)
+ sub_loss.set_title('Avg. Loss per episode')
+ sub_loss.plot(np.arange(self.__max_episodes), self.__all_avg_loss)
+
+ sub_eps = fig.add_subplot(223)
+ sub_eps.set_title('Epsilon per episode')
+ sub_eps.plot(np.arange(self.__max_episodes), self.__all_eps)
+
+ sub_rewards = fig.add_subplot(224)
+ sub_rewards.set_title('Avg. mean reward of last 100 episodes')
+ sub_rewards.plot(np.arange(self.__max_episodes), self.__all_mean_reward_last_100_episodes)
+
+ plt.savefig(file)
\ No newline at end of file
diff --git a/src/test/resources/target_code/ReinforcementConfig1/start_training.sh b/src/test/resources/target_code/ReinforcementConfig1/start_training.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d04ec2bd6217424a17ca9e35e740683c2a05358f
--- /dev/null
+++ b/src/test/resources/target_code/ReinforcementConfig1/start_training.sh
@@ -0,0 +1,2 @@
+#!/bin/bash
+python CNNTrainer_reinforcementConfig1.py
\ No newline at end of file
diff --git a/src/test/resources/target_code/ReinforcementConfig2/CNNTrainer_reinforcementConfig2.py b/src/test/resources/target_code/ReinforcementConfig2/CNNTrainer_reinforcementConfig2.py
new file mode 100644
index 0000000000000000000000000000000000000000..afcfdbb0d785c5a2a58800e21aa24d72819eaeee
--- /dev/null
+++ b/src/test/resources/target_code/ReinforcementConfig2/CNNTrainer_reinforcementConfig2.py
@@ -0,0 +1,106 @@
+from reinforcement_learning.agent import DqnAgent
+from reinforcement_learning.util import AgentSignalHandler
+import reinforcement_learning.environment
+import CNNCreator_reinforcementConfig2
+
+import os
+import sys
+import re
+import logging
+import mxnet as mx
+
+session_output_dir = 'session'
+agent_name='reinforcement_agent'
+session_param_output = os.path.join(session_output_dir, agent_name)
+
+def resume_session():
+ session_param_output = os.path.join(session_output_dir, agent_name)
+ resume_session = False
+ resume_directory = None
+ if os.path.isdir(session_output_dir) and os.path.isdir(session_param_output):
+ regex = re.compile(r'\d\d\d\d-\d\d-\d\d-\d\d-\d\d')
+ dir_content = os.listdir(session_param_output)
+ session_files = filter(regex.search, dir_content)
+ session_files.sort(reverse=True)
+ for d in session_files:
+ interrupted_session_dir = os.path.join(session_param_output, d, '.interrupted_session')
+ if os.path.isdir(interrupted_session_dir):
+ resume = raw_input('Interrupted session from {} found. Do you want to resume? (y/n) '.format(d))
+ if resume == 'y':
+ resume_session = True
+ resume_directory = interrupted_session_dir
+ break
+ return resume_session, resume_directory
+
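+# Entry point: construct the Gym environment and the generated network, resume an
+# interrupted training session if one exists, train the DQN agent and export the
+# best network found during training.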
+if __name__ == "__main__":
+ env = reinforcement_learning.environment.GymEnvironment('CartPole-v1')
+ context = mx.cpu()
+ net_creator = CNNCreator_reinforcementConfig2.CNNCreator_reinforcementConfig2()
+ net_creator.construct(context)
+
+ replay_memory_params = {
+ 'method':'buffer',
+ 'memory_size':10000,
+ 'sample_size':32,
+ 'state_dtype':'float32',
+ 'action_dtype':'uint8',
+ 'rewards_dtype':'float32'
+ }
+
+ policy_params = {
+ 'method':'epsgreedy',
+ 'epsilon': 1,
+ 'min_epsilon': 0.01,
+ 'epsilon_decay_method': 'linear',
+ 'epsilon_decay': 0.0001,
+ }
+
+ resume_session, resume_directory = resume_session()
+
+ if resume_session:
+ agent = DqnAgent.resume_from_session(resume_directory, net_creator.net, env)
+ else:
+ agent = DqnAgent(
+ network = net_creator.net,
+ environment=env,
+ replay_memory_params=replay_memory_params,
+ policy_params=policy_params,
+ state_dim=net_creator.get_input_shapes()[0],
+ discount_factor=0.999,
+ loss_function='euclidean',
+ optimizer='rmsprop',
+ optimizer_params={
+ 'weight_decay': 0.01,
+ 'centered': True,
+ 'gamma2': 0.9,
+ 'gamma1': 0.9,
+ 'clip_weights': 10.0,
+ 'learning_rate_decay': 0.9,
+ 'epsilon': 1.0E-6,
+ 'rescale_grad': 1.1,
+ 'clip_gradient': 10.0,
+ 'learning_rate_minimum': 1.0E-5,
+ 'learning_rate_policy': 'step',
+ 'learning_rate': 0.001,
+ 'step_size': 1000 },
+ training_episodes=200,
+ train_interval=1,
+ use_fix_target=False,
+ double_dqn = False,
+ snapshot_interval=20,
+ agent_name=agent_name,
+ max_episode_step=250,
+ output_directory=session_output_dir,
+ verbose=True,
+ live_plot = True,
+ make_logfile=True,
+ target_score=185.5
+ )
+
+ signal_handler = AgentSignalHandler()
+ signal_handler.register_agent(agent)
+
+ train_successful = agent.train()
+
+ if train_successful:
+ agent.save_best_network(net_creator._model_dir_ + net_creator._model_prefix_ + '_newest', epoch=0)
\ No newline at end of file
diff --git a/src/test/resources/target_code/ReinforcementConfig2/reinforcement_learning/__init__.py b/src/test/resources/target_code/ReinforcementConfig2/reinforcement_learning/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/test/resources/target_code/ReinforcementConfig2/reinforcement_learning/action_policy.py b/src/test/resources/target_code/ReinforcementConfig2/reinforcement_learning/action_policy.py
new file mode 100644
index 0000000000000000000000000000000000000000..f43a211fe353f5fcb95fb8dd38d7e412fd1d1ab4
--- /dev/null
+++ b/src/test/resources/target_code/ReinforcementConfig2/reinforcement_learning/action_policy.py
@@ -0,0 +1,73 @@
+import numpy as np
+
+class ActionPolicyBuilder(object):
+ def __init__(self):
+ pass
+
+ def build_by_params(self,
+ method='epsgreedy',
+ epsilon=0.5,
+ min_epsilon=0.05,
+ epsilon_decay_method='no',
+ epsilon_decay=0.0,
+ action_dim=None):
+
+ if epsilon_decay_method == 'linear':
+ decay = LinearDecay(eps_decay=epsilon_decay, min_eps=min_epsilon)
+ else:
+ decay = NoDecay()
+
+ if method == 'epsgreedy':
+ assert action_dim is not None
+ assert len(action_dim) == 1
+ return EpsilonGreedyActionPolicy(eps=epsilon,
+ number_of_actions=action_dim[0], decay=decay)
+ else:
+ assert action_dim is not None
+ assert len(action_dim) == 1
+ return GreedyActionPolicy()
+
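+# Epsilon-greedy exploration: with probability cur_eps a random action is chosen,
+# otherwise the greedy action argmax_a Q(s, a). With linear decay, cur_eps is
+# reduced after each episode via cur_eps = max(cur_eps - eps_decay, min_eps).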
+class EpsilonGreedyActionPolicy(object):
+ def __init__(self, eps, number_of_actions, decay):
+ self.eps = eps
+ self.cur_eps = eps
+ self.__number_of_actions = number_of_actions
+ self.__decay_method = decay
+
+ def select_action(self, values):
+ do_exploration = (np.random.rand() < self.cur_eps)
+ if do_exploration:
+ action = np.random.randint(low=0, high=self.__number_of_actions)
+ else:
+ action = values.asnumpy().argmax()
+ return action
+
+ def decay(self):
+ self.cur_eps = self.__decay_method.decay(self.cur_eps)
+
+
+class GreedyActionPolicy(object):
+ def __init__(self):
+ pass
+
+ def select_action(self, values):
+ return values.asnumpy().argmax()
+
+ def decay(self):
+ pass
+
+
+class NoDecay(object):
+ def __init__(self):
+ pass
+
+ def decay(self, cur_eps):
+ return cur_eps
+
+class LinearDecay(object):
+ def __init__(self, eps_decay, min_eps=0):
+ self.eps_decay = eps_decay
+ self.min_eps = min_eps
+
+ def decay(self, cur_eps):
+ return max(cur_eps - self.eps_decay, self.min_eps)
\ No newline at end of file
diff --git a/src/test/resources/target_code/ReinforcementConfig2/reinforcement_learning/agent.py b/src/test/resources/target_code/ReinforcementConfig2/reinforcement_learning/agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..acbc974d5295431ea6a88a562cff2fbd7d951066
--- /dev/null
+++ b/src/test/resources/target_code/ReinforcementConfig2/reinforcement_learning/agent.py
@@ -0,0 +1,506 @@
+import mxnet as mx
+import numpy as np
+import time
+import os
+import logging
+import sys
+import util
+import matplotlib.pyplot as plt
+from replay_memory import ReplayMemoryBuilder
+from action_policy import ActionPolicyBuilder
+from util import copy_net, get_loss_function
+from mxnet import nd, gluon, autograd
+
+class DqnAgent(object):
+ def __init__(self,
+ network,
+ environment,
+ replay_memory_params,
+ policy_params,
+ state_dim,
+ ctx=None,
+ discount_factor=.9,
+ loss_function='euclidean',
+ optimizer='rmsprop',
+ optimizer_params = {'learning_rate':0.09},
+ training_episodes=50,
+ train_interval=1,
+ use_fix_target=False,
+ double_dqn = False,
+ target_update_interval=10,
+ snapshot_interval=200,
+ agent_name='Dqn_agent',
+ max_episode_step=99999,
+ output_directory='model_parameters',
+ verbose=True,
+ live_plot = True,
+ make_logfile=True,
+ target_score=None):
+ assert 0 < discount_factor <= 1
+ assert train_interval > 0
+ assert target_update_interval > 0
+ assert snapshot_interval > 0
+ assert max_episode_step > 0
+ assert training_episodes > 0
+ assert replay_memory_params is not None
+ assert type(state_dim) is tuple
+
+ self.__ctx = mx.gpu() if ctx == 'gpu' else mx.cpu()
+ self.__qnet = network
+
+ self.__environment = environment
+ self.__discount_factor = discount_factor
+ self.__training_episodes = training_episodes
+ self.__train_interval = train_interval
+ self.__verbose = verbose
+ self.__state_dim = state_dim
+ self.__action_dim = self.__qnet(nd.random_normal(shape=((1,) + self.__state_dim), ctx=self.__ctx)).shape[1:]
+
+ replay_memory_params['state_dim'] = state_dim
+ self.__replay_memory_params = replay_memory_params
+ rm_builder = ReplayMemoryBuilder()
+ self.__memory = rm_builder.build_by_params(**replay_memory_params)
+ self.__minibatch_size = self.__memory.sample_size
+
+ policy_params['action_dim'] = self.__action_dim
+ self.__policy_params = policy_params
+ p_builder = ActionPolicyBuilder()
+ self.__policy = p_builder.build_by_params(**policy_params)
+
+ self.__target_update_interval = target_update_interval
+ self.__target_qnet = copy_net(self.__qnet, self.__state_dim, ctx=self.__ctx)
+ self.__loss_function_str = loss_function
+ self.__loss_function = get_loss_function(loss_function)
+ self.__agent_name = agent_name
+ self.__snapshot_interval = snapshot_interval
+ self.__creation_time = time.time()
+ self.__max_episode_step = max_episode_step
+ self.__optimizer = optimizer
+ self.__optimizer_params = optimizer_params
+ self.__make_logfile = make_logfile
+ self.__double_dqn = double_dqn
+ self.__use_fix_target = use_fix_target
+ self.__live_plot = live_plot
+ self.__user_given_directory = output_directory
+ self.__target_score = target_score
+
+ self.__interrupt_flag = False
+
+ # Training Context
+ self.__current_episode = 0
+ self.__total_steps = 0
+
+ # Initialize best network
+ self.__best_net = copy_net(self.__qnet, self.__state_dim, self.__ctx)
+ self.__best_avg_score = None
+
+ # Gluon Trainer definition
+ self.__training_stats = None
+
+ # Prepare output directory and logger
+ self.__output_directory = output_directory\
+ + '/' + self.__agent_name\
+ + '/' + time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(self.__creation_time))
+ self.__logger = self.__setup_logging()
+ self.__logger.info('Agent created with following parameters: {}'.format(self.__make_config_dict()))
+
+ @classmethod
+ def from_config_file(cls, network, environment, config_file_path, ctx=None):
+ import json
+ # Load config
+ with open(config_file_path, 'r') as config_file:
+ config_dict = json.load(config_file)
+ return cls(network, environment, ctx=ctx, **config_dict)
+
+ @classmethod
+ def resume_from_session(cls, session_dir, net, environment):
+ import pickle
+ if not os.path.exists(session_dir):
+ raise ValueError('Session directory does not exist')
+
+ files = dict()
+ files['agent'] = os.path.join(session_dir, 'agent.p')
+ files['best_net_params'] = os.path.join(session_dir, 'best_net.params')
+ files['q_net_params'] = os.path.join(session_dir, 'qnet.params')
+ files['target_net_params'] = os.path.join(session_dir, 'target_net.params')
+
+ for file in files.values():
+ if not os.path.exists(file):
+ raise ValueError('Session directory is not complete: {} is missing'.format(file))
+
+ with open(files['agent'], 'rb') as f:
+ agent = pickle.load(f)
+
+ agent.__environment = environment
+ agent.__qnet = net
+ agent.__qnet.load_parameters(files['q_net_params'], agent.__ctx)
+ agent.__qnet.hybridize()
+ agent.__qnet(nd.random_normal(shape=((1,) + agent.__state_dim), ctx=agent.__ctx))
+ agent.__best_net = copy_net(agent.__qnet, agent.__state_dim, agent.__ctx)
+ agent.__best_net.load_parameters(files['best_net_params'], agent.__ctx)
+ agent.__target_qnet = copy_net(agent.__qnet, agent.__state_dim, agent.__ctx)
+ agent.__target_qnet.load_parameters(files['target_net_params'], agent.__ctx)
+
+ agent.__logger = agent.__setup_logging(append=True)
+ agent.__training_stats.logger = agent.__logger
+ agent.__logger.info('Agent was retrieved; Training can be continued')
+
+ return agent
+
+ def __interrupt_training(self):
+ import pickle
+ self.__logger.info('Training interrupted; Store state for resuming')
+ session_dir = os.path.join(self.__output_directory, '.interrupted_session')
+ if not os.path.exists(session_dir):
+ os.mkdir(session_dir)
+
+ del self.__training_stats.logger
+ logger = self.__logger
+ self.__logger = None
+ self.__environment.close()
+ self.__environment = None
+
+ self.__save_net(self.__qnet, 'qnet', session_dir)
+ self.__qnet = None
+ self.__save_net(self.__best_net, 'best_net', session_dir)
+ self.__best_net = None
+ self.__save_net(self.__target_qnet, 'target_net', session_dir)
+ self.__target_qnet = None
+
+ agent_session_file = os.path.join(session_dir, 'agent.p')
+
+ with open(agent_session_file, 'wb') as f:
+ pickle.dump(self, f)
+ self.__logger = logger
+ logger.info('State successfully stored')
+
+ @property
+ def current_episode(self):
+ return self.__current_episode
+
+ @property
+ def environment(self):
+ return self.__environment
+
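+    # Translates the generator's optimizer parameter names into the keyword
+    # arguments expected by mxnet.gluon.Trainer, e.g. 'weight_decay' -> 'wd',
+    # and turns the learning-rate decay settings into an mx.lr_scheduler.FactorScheduler.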
+ def __adjust_optimizer_params(self, optimizer_params):
+ if 'weight_decay' in optimizer_params:
+ optimizer_params['wd'] = optimizer_params['weight_decay']
+ del optimizer_params['weight_decay']
+ if 'learning_rate_decay' in optimizer_params:
+ min_learning_rate = 1e-8
+ if 'learning_rate_minimum' in optimizer_params:
+ min_learning_rate = optimizer_params['learning_rate_minimum']
+ del optimizer_params['learning_rate_minimum']
+            optimizer_params['lr_scheduler'] = mx.lr_scheduler.FactorScheduler(
+ optimizer_params['step_size'],
+ factor=optimizer_params['learning_rate_decay'],
+ stop_factor_lr=min_learning_rate)
+ del optimizer_params['step_size']
+ del optimizer_params['learning_rate_decay']
+
+ return optimizer_params
+
+ def set_interrupt_flag(self, interrupt):
+ self.__interrupt_flag = interrupt
+
+
+ def __make_output_directory_if_not_exist(self):
+ assert self.__output_directory
+ if not os.path.exists(self.__output_directory):
+ os.makedirs(self.__output_directory)
+
+ def __setup_logging(self, append=False):
+ assert self.__output_directory
+ assert self.__agent_name
+
+ output_level = logging.DEBUG if self.__verbose else logging.WARNING
+ filemode = 'a' if append else 'w'
+
+ logformat = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+ dateformat = '%d-%b-%y %H:%M:%S'
+ formatter = logging.Formatter(fmt=logformat, datefmt=dateformat)
+
+ logger = logging.getLogger('DQNAgent')
+ logger.setLevel(output_level)
+
+ stream_handler = logging.StreamHandler(sys.stdout)
+ stream_handler.setLevel(output_level)
+ stream_handler.setFormatter(formatter)
+ logger.addHandler(stream_handler)
+
+ if self.__make_logfile:
+ self.__make_output_directory_if_not_exist()
+ log_file = os.path.join(self.__output_directory, self.__agent_name + '.log')
+ file_handler = logging.FileHandler(log_file, mode=filemode)
+ file_handler.setLevel(output_level)
+ file_handler.setFormatter(formatter)
+ logger.addHandler(file_handler)
+
+ return logger
+
+ def __is_target_reached(self, avg_reward):
+ return self.__target_score is not None\
+ and avg_reward > self.__target_score
+
+
+ def get_q_values(self, state, with_best=False):
+ return self.get_batch_q_values(nd.array([state], ctx=self.__ctx), with_best=with_best)[0]
+
+ def get_batch_q_values(self, state_batch, with_best=False):
+ return self.__best_net(state_batch) if with_best else self.__qnet(state_batch)
+
+ def get_next_action(self, state, with_best=False):
+ q_values = self.get_q_values(state, with_best=with_best)
+        return q_values.asnumpy().argmax()
+
+ def __sample_from_memory(self):
+ states, actions, rewards, next_states, terminals\
+ = self.__memory.sample(batch_size=self.__minibatch_size)
+ states = nd.array(states, ctx=self.__ctx)
+ actions = nd.array(actions, ctx=self.__ctx)
+ rewards = nd.array(rewards, ctx=self.__ctx)
+ next_states = nd.array(next_states, ctx=self.__ctx)
+ terminals = nd.array(terminals, ctx=self.__ctx)
+ return states, actions, rewards, next_states, terminals
+
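+    # Computes the Bellman targets for a sampled batch. Standard DQN uses
+    #   y = r + gamma * (1 - terminal) * max_a' Q_eval(s', a')
+    # while double DQN selects the action with the online network and evaluates
+    # it with Q_eval:
+    #   y = r + gamma * (1 - terminal) * Q_eval(s', argmax_a' Q(s', a'))
+    # Q_eval is the fixed target network if use_fix_target is set, otherwise Q itself.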
+ def __determine_target_q_values(self, states, actions, rewards, next_states, terminals):
+ if self.__use_fix_target:
+ q_max_val = self.__target_qnet(next_states)
+ else:
+ q_max_val = self.__qnet(next_states)
+
+ if self.__double_dqn:
+ q_values_next_states = self.__qnet(next_states)
+ target_rewards = rewards + nd.choose_element_0index(q_max_val, nd.argmax_channel(q_values_next_states))\
+ * (1.0 - terminals) * self.__discount_factor
+ else:
+ target_rewards = rewards + nd.choose_element_0index(q_max_val, nd.argmax_channel(q_max_val))\
+ * (1.0 - terminals) * self.__discount_factor
+
+ target_qval = self.__qnet(states)
+ for t in range(target_rewards.shape[0]):
+ target_qval[t][actions[t]] = target_rewards[t]
+
+ return target_qval
+
+ def __train_q_net_step(self, trainer):
+ states, actions, rewards, next_states, terminals = self.__sample_from_memory()
+ target_qval = self.__determine_target_q_values(states, actions, rewards, next_states, terminals)
+ with autograd.record():
+ q_values = self.__qnet(states)
+ loss = self.__loss_function(q_values, target_qval)
+ loss.backward()
+ trainer.step(self.__minibatch_size)
+ return loss
+
+ def __do_snapshot_if_in_interval(self, episode):
+ do_snapshot = (episode != 0 and (episode % self.__snapshot_interval == 0))
+ if do_snapshot:
+ self.save_parameters(episode=episode)
+ self.__evaluate()
+
+ def __do_target_update_if_in_interval(self, total_steps):
+ do_target_update = (self.__use_fix_target and total_steps % self.__target_update_interval == 0)
+ if do_target_update:
+ self.__logger.info('Target network is updated after {} steps'.format(total_steps))
+ self.__target_qnet = copy_net(self.__qnet, self.__state_dim, self.__ctx)
+
+ def train(self, episodes=None):
+ self.__logger.info("--- Start training ---")
+ trainer = gluon.Trainer(self.__qnet.collect_params(), self.__optimizer, self.__adjust_optimizer_params(self.__optimizer_params))
+        episodes = episodes if episodes is not None else self.__training_episodes
+
+ resume = (self.__current_episode > 0)
+ if resume:
+ self.__logger.info("Training session resumed")
+ self.__logger.info("Starting from episode {}".format(self.__current_episode))
+ else:
+ self.__training_stats = util.TrainingStats(self.__logger, episodes, self.__live_plot)
+
+        # Implementation of Deep Q-Learning as described by Mnih et al. in "Playing Atari with Deep Reinforcement Learning"
+ while self.__current_episode < episodes:
+ # Check interrupt flag
+ if self.__interrupt_flag:
+ self.__interrupt_flag = False
+ self.__interrupt_training()
+ return False
+
+ step = 0
+ episode_reward = 0
+ start = time.time()
+ state = self.__environment.reset()
+ episode_loss = 0
+            training_steps = 0
+            # Keeps avg_reward defined if the step limit is reached without a terminal state
+            avg_reward = 0
+ while step < self.__max_episode_step:
+ #1. Choose an action based on current game state and policy
+ q_values = self.__qnet(nd.array([state], ctx=self.__ctx))
+ action = self.__policy.select_action(q_values[0])
+
+ #2. Play the game for a single step
+ next_state, reward, terminal, _ = self.__environment.step(action)
+
+ #3. Store transition in replay memory
+ self.__memory.append(state, action, reward, next_state, terminal)
+
+ #4. Train the network if in interval
+ do_training = (self.__total_steps % self.__train_interval == 0\
+ and self.__memory.is_sample_possible(self.__minibatch_size))
+ if do_training:
+ loss = self.__train_q_net_step(trainer)
+ loss_sum = sum(loss).asnumpy()[0]
+ episode_loss += float(loss_sum)/float(self.__minibatch_size)
+ training_steps += 1
+
+ # Update target network if in interval
+ self.__do_target_update_if_in_interval(self.__total_steps)
+
+ step += 1
+ self.__total_steps += 1
+ episode_reward += reward
+ state = next_state
+
+ if terminal:
+ episode_loss = episode_loss if training_steps > 0 else None
+ _, _, avg_reward = self.__training_stats.log_episode(self.__current_episode, start, training_steps,
+ episode_loss, self.__policy.cur_eps, episode_reward)
+ break
+
+ self.__do_snapshot_if_in_interval(self.__current_episode)
+ self.__policy.decay()
+
+ if self.__is_target_reached(avg_reward):
+                self.__logger.info('Target score is reached on average; Training is stopped')
+ break
+
+ self.__current_episode += 1
+
+ self.__evaluate()
+ training_stats_file = os.path.join(self.__output_directory, 'training_stats.pdf')
+ self.__training_stats.save_stats(training_stats_file)
+ self.__logger.info('--------- Training finished ---------')
+ return True
+
+ def __save_net(self, net, filename, filedir=None):
+ filedir = self.__output_directory if filedir is None else filedir
+ filename = os.path.join(filedir, filename + '.params')
+ net.save_parameters(filename)
+
+
+ def save_parameters(self, episode=None, filename='dqn-agent-params'):
+ assert self.__output_directory
+ self.__make_output_directory_if_not_exist()
+
+        if episode is not None:
+ self.__logger.info('Saving model parameters after episode %d' % episode)
+ filename = filename + '-ep{}'.format(episode)
+ else:
+ self.__logger.info('Saving model parameters')
+ self.__save_net(self.__qnet, filename)
+
+ def evaluate(self, target=None, sample_games=100, verbose=True):
+ target = self.__target_score if target is None else target
+ if target:
+ target_achieved = 0
+ total_reward = 0
+
+ for g in range(sample_games):
+ state = self.__environment.reset()
+ step = 0
+ game_reward = 0
+ while step < self.__max_episode_step:
+ action = self.get_next_action(state)
+ state, reward, terminal, _ = self.__environment.step(action)
+ game_reward += reward
+
+ if terminal:
+ if verbose:
+ info = 'Game %d: Reward %f' % (g,game_reward)
+ self.__logger.debug(info)
+ if target:
+ if game_reward >= target:
+ target_achieved += 1
+ total_reward += game_reward
+ break
+
+ step += 1
+
+ avg_reward = float(total_reward)/float(sample_games)
+ info = 'Avg. Reward: %f' % avg_reward
+ if target:
+ target_achieved_ratio = int((float(target_achieved)/float(sample_games))*100)
+ info += '; Target Achieved in %d%% of games' % (target_achieved_ratio)
+
+ if verbose:
+ self.__logger.info(info)
+ return avg_reward
+
+ def __evaluate(self, verbose=True):
+ sample_games = 100
+ avg_reward = self.evaluate(sample_games=sample_games, verbose=False)
+ info = 'Evaluation -> Average Reward in {} games: {}'.format(sample_games, avg_reward)
+
+ if self.__best_avg_score is None or self.__best_avg_score <= avg_reward:
+ self.__best_net = copy_net(self.__qnet, self.__state_dim, self.__ctx)
+ self.__best_avg_score = avg_reward
+ info += ' (NEW BEST)'
+
+ if verbose:
+ self.__logger.info(info)
+
+
+
+ def play(self, update_frame=1, with_best=False):
+ step = 0
+ state = self.__environment.reset()
+ total_reward = 0
+ while step < self.__max_episode_step:
+ action = self.get_next_action(state, with_best=with_best)
+ state, reward, terminal, _ = self.__environment.step(action)
+ total_reward += reward
+ do_update_frame = (step % update_frame == 0)
+ if do_update_frame:
+ self.__environment.render()
+ time.sleep(.100)
+
+ if terminal:
+ break
+
+ step += 1
+ return total_reward
+
+ def save_best_network(self, path, epoch=0):
+ self.__logger.info('Saving best network with average reward of {}'.format(self.__best_avg_score))
+ self.__best_net.export(path, epoch=epoch)
+
+ def __make_config_dict(self):
+ config = dict()
+ config['discount_factor'] = self.__discount_factor
+ config['optimizer'] = self.__optimizer
+ config['optimizer_params'] = self.__optimizer_params
+ config['policy_params'] = self.__policy_params
+ config['replay_memory_params'] = self.__replay_memory_params
+ config['loss_function'] = self.__loss_function_str
+        config['training_episodes'] = self.__training_episodes
+        config['train_interval'] = self.__train_interval
+        config['use_fix_target'] = self.__use_fix_target
+        config['double_dqn'] = self.__double_dqn
+        config['target_update_interval'] = self.__target_update_interval
+        config['snapshot_interval'] = self.__snapshot_interval
+ config['agent_name'] = self.__agent_name
+ config['max_episode_step'] = self.__max_episode_step
+ config['output_directory'] = self.__user_given_directory
+ config['verbose'] = self.__verbose
+ config['live_plot'] = self.__live_plot
+ config['make_logfile'] = self.__make_logfile
+ config['target_score'] = self.__target_score
+ return config
+
+ def save_config_file(self):
+ import json
+ self.__make_output_directory_if_not_exist()
+ filename = os.path.join(self.__output_directory, 'config.json')
+ config = self.__make_config_dict()
+ with open(filename, mode='w') as fp:
+ json.dump(config, fp, indent=4)
\ No newline at end of file
diff --git a/src/test/resources/target_code/ReinforcementConfig2/reinforcement_learning/environment.py b/src/test/resources/target_code/ReinforcementConfig2/reinforcement_learning/environment.py
new file mode 100644
index 0000000000000000000000000000000000000000..47787a862653f56d0b5a69c2816647e691a257d8
--- /dev/null
+++ b/src/test/resources/target_code/ReinforcementConfig2/reinforcement_learning/environment.py
@@ -0,0 +1,71 @@
+import abc
+import logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+class Environment:
+ __metaclass__ = abc.ABCMeta
+
+ def __init__(self):
+ pass
+
+ @abc.abstractmethod
+ def reset(self):
+ pass
+
+ @abc.abstractmethod
+ def step(self, action):
+ pass
+
+ @abc.abstractmethod
+ def close(self):
+ pass
+
+import gym
+class GymEnvironment(Environment):
+ def __init__(self, env_name, **kwargs):
+ super(GymEnvironment, self).__init__(**kwargs)
+ self.__seed = 42
+ self.__env = gym.make(env_name)
+ self.__env.seed(self.__seed)
+
+ @property
+ def state_dim(self):
+ return self.__env.observation_space.shape
+
+ @property
+ def state_dtype(self):
+ return 'float32'
+
+ @property
+ def action_dtype(self):
+ return 'uint8'
+
+ @property
+ def number_of_actions(self):
+ return self.__env.action_space.n
+
+ @property
+ def rewards_dtype(self):
+ return 'float32'
+
+ def reset(self):
+ return self.__env.reset()
+
+ def step(self, action):
+ return self.__env.step(action)
+
+ def close(self):
+ self.__env.close()
+
+    def action_space(self):
+        return self.__env.action_space
+
+ def is_in_action_space(self, action):
+ return self.__env.action_space.contains(action)
+
+ def sample_action(self):
+ return self.__env.action_space.sample()
+
+ def render(self):
+ self.__env.render()
diff --git a/src/test/resources/target_code/ReinforcementConfig2/reinforcement_learning/replay_memory.py b/src/test/resources/target_code/ReinforcementConfig2/reinforcement_learning/replay_memory.py
new file mode 100644
index 0000000000000000000000000000000000000000..e66cd9350cab02144f994cab706249ef5a1e4288
--- /dev/null
+++ b/src/test/resources/target_code/ReinforcementConfig2/reinforcement_learning/replay_memory.py
@@ -0,0 +1,155 @@
+import numpy as np
+
+class ReplayMemoryBuilder(object):
+ def __init__(self):
+ self.__supported_methods = ['online', 'buffer', 'combined']
+
+ def build_by_params(self,
+ state_dim,
+ method='online',
+ state_dtype='float32',
+ action_dtype='uint8',
+ rewards_dtype='float32',
+ memory_size=1000,
+ sample_size=32):
+ assert state_dim is not None
+ assert method in self.__supported_methods
+
+ if method == 'online':
+ return self.build_online_memory(state_dim=state_dim, state_dtype=state_dtype,
+ action_dtype=action_dtype, rewards_dtype=rewards_dtype)
+ else:
+ assert memory_size is not None and memory_size > 0
+ assert sample_size is not None and sample_size > 0
+ if method == 'buffer':
+ return self.build_buffered_memory(state_dim=state_dim, sample_size=sample_size,
+ memory_size=memory_size, state_dtype=state_dtype, action_dtype=action_dtype,
+ rewards_dtype=rewards_dtype)
+ else:
+ return self.build_combined_memory(state_dim=state_dim, sample_size=sample_size,
+ memory_size=memory_size, state_dtype=state_dtype, action_dtype=action_dtype,
+ rewards_dtype=rewards_dtype)
+
+ def build_buffered_memory(self, state_dim, memory_size=1000, sample_size=1, state_dtype='float32',
+ action_dtype='uint8', rewards_dtype='float32'):
+ assert memory_size > 0
+ assert sample_size > 0
+ return ReplayMemory(state_dim, size=memory_size, sample_size=sample_size,
+ state_dtype=state_dtype, action_dtype=action_dtype, rewards_dtype=rewards_dtype)
+
+ def build_combined_memory(self, state_dim, memory_size=1000, sample_size=1, state_dtype='float32',
+ action_dtype='uint8', rewards_dtype='float32'):
+ assert memory_size > 0
+ assert sample_size > 0
+ return CombinedReplayMemory(state_dim, size=memory_size, sample_size=sample_size,
+ state_dtype=state_dtype, action_dtype=action_dtype, rewards_dtype=rewards_dtype)
+
+ def build_online_memory(self, state_dim, state_dtype='float32', action_dtype='uint8',
+ rewards_dtype='float32'):
+ return OnlineReplayMemory(state_dim, state_dtype=state_dtype, action_dtype=action_dtype,
+ rewards_dtype=rewards_dtype)
+
+class ReplayMemory(object):
+ def __init__(self, state_dim, sample_size, size=1000, state_dtype='uint8', action_dtype='uint8', rewards_dtype='float32'):
+ assert size > 0, "Size must be greater than zero"
+ assert type(state_dim) is tuple, "State dimension must be a tuple"
+ assert sample_size > 0
+ self._size = size
+ self._sample_size = sample_size
+ self._cur_size = 0
+ self._pointer = 0
+ self._state_dim = state_dim
+ self._state_dtype = state_dtype
+ self._action_dtype = action_dtype
+ self._rewards_dtype = rewards_dtype
+ self._states = np.zeros((self._size,) + state_dim, dtype=state_dtype)
+ self._actions = np.array([0] * self._size, dtype=action_dtype)
+ self._rewards = np.array([0] * self._size, dtype=rewards_dtype)
+ self._next_states = np.zeros((self._size,) + state_dim, dtype=state_dtype)
+ self._terminals = np.array([0] * self._size, dtype='bool')
+
+ @property
+ def sample_size(self):
+ return self._sample_size
+
+ def append(self, state, action, reward, next_state, terminal):
+ self._states[self._pointer] = state
+ self._actions[self._pointer] = action
+ self._rewards[self._pointer] = reward
+ self._next_states[self._pointer] = next_state
+ self._terminals[self._pointer] = terminal
+
+ self._pointer = self._pointer + 1
+ if self._pointer == self._size:
+ self._pointer = 0
+
+ self._cur_size = min(self._size, self._cur_size + 1)
+
+ def at(self, index):
+ return self._states[index],\
+ self._actions[index],\
+ self._rewards[index],\
+ self._next_states[index],\
+ self._terminals[index]
+
+ def is_sample_possible(self, batch_size=None):
+ batch_size = batch_size if batch_size is not None else self._sample_size
+ return self._cur_size >= batch_size
+
+ def sample(self, batch_size=None):
+ batch_size = batch_size if batch_size is not None else self._sample_size
+        assert self._cur_size >= batch_size, "Replay memory must contain at least batch_size transitions"
+        i = 0
+ states = np.zeros((batch_size,)+self._state_dim, dtype=self._state_dtype)
+ actions = np.zeros(batch_size, dtype=self._action_dtype)
+ rewards = np.zeros(batch_size, dtype=self._rewards_dtype)
+ next_states = np.zeros((batch_size,)+self._state_dim, dtype=self._state_dtype)
+ terminals = np.zeros(batch_size, dtype='bool')
+
+ while i < batch_size:
+ rnd_index = np.random.randint(low=0, high=self._cur_size)
+ states[i] = self._states.take(rnd_index, axis=0)
+ actions[i] = self._actions.take(rnd_index, axis=0)
+ rewards[i] = self._rewards.take(rnd_index, axis=0)
+ next_states[i] = self._next_states.take(rnd_index, axis=0)
+ terminals[i] = self._terminals.take(rnd_index, axis=0)
+ i += 1
+
+ return states, actions, rewards, next_states, terminals
+
+
+class OnlineReplayMemory(ReplayMemory):
+ def __init__(self, state_dim, state_dtype='float32', action_dtype='uint8', rewards_dtype='float32'):
+ super(OnlineReplayMemory, self).__init__(state_dim, sample_size=1, size=1,
+ state_dtype=state_dtype, action_dtype=action_dtype, rewards_dtype=rewards_dtype)
+
+
+class CombinedReplayMemory(ReplayMemory):
+ def __init__(self, state_dim, sample_size, size=1000,
+ state_dtype='uint8', action_dtype='uint8', rewards_dtype='float32'):
+ super(CombinedReplayMemory, self).__init__(state_dim, sample_size=(sample_size - 1), size=size,
+ state_dtype=state_dtype, action_dtype=action_dtype, rewards_dtype=rewards_dtype)
+
+ self._last_state = np.zeros((1,) + state_dim, dtype=state_dtype)
+ self._last_action = np.array([0], dtype=action_dtype)
+ self._last_reward = np.array([0], dtype=rewards_dtype)
+ self._last_next_state = np.zeros((1,) + state_dim, dtype=state_dtype)
+ self._last_terminal = np.array([0], dtype='bool')
+
+ def append(self, state, action, reward, next_state, terminal):
+ super(CombinedReplayMemory, self).append(state, action, reward, next_state, terminal)
+ self._last_state = state
+ self._last_action = action
+ self._last_reward = reward
+ self._last_next_state = next_state
+ self._last_terminal = terminal
+
+ def sample(self, batch_size=None):
+ batch_size = (batch_size-1) if batch_size is not None else self._sample_size
+ states, actions, rewards, next_states, terminals = super(CombinedReplayMemory, self).sample(batch_size=batch_size)
+ states = np.append(states, [self._last_state], axis=0)
+ actions = np.append(actions, [self._last_action], axis=0)
+ rewards = np.append(rewards, [self._last_reward], axis=0)
+ next_states = np.append(next_states, [self._last_next_state], axis=0)
+ terminals = np.append(terminals, [self._last_terminal], axis=0)
+ return states, actions, rewards, next_states, terminals
\ No newline at end of file
diff --git a/src/test/resources/target_code/ReinforcementConfig2/reinforcement_learning/util.py b/src/test/resources/target_code/ReinforcementConfig2/reinforcement_learning/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..58578932a13f7d12f1aa840d06af1951bebc3f0f
--- /dev/null
+++ b/src/test/resources/target_code/ReinforcementConfig2/reinforcement_learning/util.py
@@ -0,0 +1,140 @@
+import signal
+import sys
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib import style
+import time
+import os
+import mxnet
+from mxnet import gluon, nd
+
+
+LOSS_FUNCTIONS = {
+ 'l1': gluon.loss.L1Loss(),
+ 'euclidean': gluon.loss.L2Loss(),
+ 'huber_loss': gluon.loss.HuberLoss(),
+ 'softmax_cross_entropy': gluon.loss.SoftmaxCrossEntropyLoss(),
+ 'sigmoid_cross_entropy': gluon.loss.SigmoidBinaryCrossEntropyLoss()}
+
+def copy_net(net, input_state_dim, ctx, tmp_filename='tmp.params'):
+ assert isinstance(net, gluon.HybridBlock)
+ assert type(net.__class__) is type
+ net.save_parameters(tmp_filename)
+ net2 = net.__class__()
+ net2.load_parameters(tmp_filename, ctx=ctx)
+ os.remove(tmp_filename)
+ net2.hybridize()
+ net2(nd.ones((1,) + input_state_dim, ctx=ctx))
+ return net2
+
+def get_loss_function(loss_function_name):
+ if loss_function_name not in LOSS_FUNCTIONS:
+ raise ValueError('Loss function does not exist')
+ return LOSS_FUNCTIONS[loss_function_name]
+
+
+class AgentSignalHandler(object):
+ def __init__(self):
+ signal.signal(signal.SIGINT, self.interrupt_training)
+ self.__agent = None
+ self.__times_interrupted = 0
+
+ def register_agent(self, agent):
+ self.__agent = agent
+
+ def interrupt_training(self, sig, frame):
+ self.__times_interrupted = self.__times_interrupted + 1
+ if self.__times_interrupted <= 3:
+ if self.__agent:
+ self.__agent.set_interrupt_flag(True)
+ else:
+ print('Interrupt called three times: Force quit')
+ sys.exit(1)
+
+style.use('fivethirtyeight')
+class TrainingStats(object):
+ def __init__(self, logger, max_episodes, live_plot=True):
+ self.__logger = logger
+ self.__max_episodes = max_episodes
+ self.__all_avg_loss = np.zeros((max_episodes,))
+ self.__all_total_rewards = np.zeros((max_episodes,))
+ self.__all_eps = np.zeros((max_episodes,))
+ self.__all_time = np.zeros((max_episodes,))
+ self.__all_mean_reward_last_100_episodes = np.zeros((max_episodes,))
+ self.__live_plot = live_plot
+
+ @property
+ def logger(self):
+ return self.__logger
+
+ @logger.setter
+ def logger(self, logger):
+ self.__logger = logger
+
+ @logger.deleter
+ def logger(self):
+ self.__logger = None
+
+ def add_avg_loss(self, episode, avg_loss):
+ self.__all_avg_loss[episode] = avg_loss
+
+ def add_total_reward(self, episode, total_reward):
+ self.__all_total_rewards[episode] = total_reward
+
+ def add_eps(self, episode, eps):
+ self.__all_eps[episode] = eps
+
+ def add_time(self, episode, time):
+ self.__all_time[episode] = time
+
+ def add_mean_reward_last_100(self, episode, mean_reward):
+ self.__all_mean_reward_last_100_episodes[episode] = mean_reward
+
+ def log_episode(self, episode, start_time, training_steps, loss, eps, reward):
+ self.add_eps(episode, eps)
+ self.add_total_reward(episode, reward)
+ end = time.time()
+ if training_steps == 0:
+ avg_loss = 0
+ else:
+ avg_loss = float(loss)/float(training_steps)
+
+ mean_reward_last_100 = self.mean_of_reward(episode, last=100)
+
+ time_elapsed = end - start_time
+ info = "Episode: %d, Total Reward: %.3f, Avg. Reward Last 100 Episodes: %.3f, Avg Loss: %.3f, Time: %.3f, Training Steps: %d, Eps: %.3f"\
+ % (episode, reward, mean_reward_last_100, avg_loss, time_elapsed, training_steps, eps)
+ self.__logger.info(info)
+ self.add_avg_loss(episode, avg_loss)
+ self.add_time(episode, time_elapsed)
+ self.add_mean_reward_last_100(episode, mean_reward_last_100)
+
+ return avg_loss, time_elapsed, mean_reward_last_100
+
+ def mean_of_reward(self, cur_episode, last=100):
+ if cur_episode > 0:
+ reward_last_100 = self.__all_total_rewards[max(0, cur_episode-last):cur_episode]
+ return np.mean(reward_last_100)
+ else:
+ return self.__all_total_rewards[0]
+
+ def save_stats(self, file):
+ fig = plt.figure(figsize=(20,20))
+
+ sub_rewards = fig.add_subplot(221)
+ sub_rewards.set_title('Total Rewards per episode')
+ sub_rewards.plot(np.arange(self.__max_episodes), self.__all_total_rewards)
+
+ sub_loss = fig.add_subplot(222)
+ sub_loss.set_title('Avg. Loss per episode')
+ sub_loss.plot(np.arange(self.__max_episodes), self.__all_avg_loss)
+
+ sub_eps = fig.add_subplot(223)
+ sub_eps.set_title('Epsilon per episode')
+ sub_eps.plot(np.arange(self.__max_episodes), self.__all_eps)
+
+ sub_rewards = fig.add_subplot(224)
+ sub_rewards.set_title('Avg. mean reward of last 100 episodes')
+ sub_rewards.plot(np.arange(self.__max_episodes), self.__all_mean_reward_last_100_episodes)
+
+ plt.savefig(file)
\ No newline at end of file
diff --git a/src/test/resources/target_code/ReinforcementConfig2/start_training.sh b/src/test/resources/target_code/ReinforcementConfig2/start_training.sh
new file mode 100644
index 0000000000000000000000000000000000000000..3662d39e9c1140b18812c763a6b3e49df584c23f
--- /dev/null
+++ b/src/test/resources/target_code/ReinforcementConfig2/start_training.sh
@@ -0,0 +1,2 @@
+#!/bin/bash
+python CNNTrainer_reinforcementConfig2.py
\ No newline at end of file
diff --git a/src/test/resources/target_code/reinforcement_learning/__init__.py b/src/test/resources/target_code/reinforcement_learning/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/test/resources/target_code/reinforcement_learning/action_policy.py b/src/test/resources/target_code/reinforcement_learning/action_policy.py
new file mode 100644
index 0000000000000000000000000000000000000000..f43a211fe353f5fcb95fb8dd38d7e412fd1d1ab4
--- /dev/null
+++ b/src/test/resources/target_code/reinforcement_learning/action_policy.py
@@ -0,0 +1,73 @@
+import numpy as np
+
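+# Builds the action-selection policy (epsilon-greedy or greedy) together with its
+# epsilon decay schedule from the generated configuration parameters.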
+class ActionPolicyBuilder(object):
+ def __init__(self):
+ pass
+
+ def build_by_params(self,
+ method='epsgreedy',
+ epsilon=0.5,
+ min_epsilon=0.05,
+ epsilon_decay_method='no',
+ epsilon_decay=0.0,
+ action_dim=None):
+
+ if epsilon_decay_method == 'linear':
+ decay = LinearDecay(eps_decay=epsilon_decay, min_eps=min_epsilon)
+ else:
+ decay = NoDecay()
+
+ if method == 'epsgreedy':
+ assert action_dim is not None
+ assert len(action_dim) == 1
+ return EpsilonGreedyActionPolicy(eps=epsilon,
+ number_of_actions=action_dim[0], decay=decay)
+ else:
+ assert action_dim is not None
+ assert len(action_dim) == 1
+ return GreedyActionPolicy()
+
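+# Epsilon-greedy action selection: with probability cur_eps a random action is
+# chosen, otherwise the action with the highest predicted Q-value; decay()
+# anneals cur_eps according to the configured schedule.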
+class EpsilonGreedyActionPolicy(object):
+ def __init__(self, eps, number_of_actions, decay):
+ self.eps = eps
+ self.cur_eps = eps
+ self.__number_of_actions = number_of_actions
+ self.__decay_method = decay
+
+ def select_action(self, values):
+ do_exploration = (np.random.rand() < self.cur_eps)
+ if do_exploration:
+ action = np.random.randint(low=0, high=self.__number_of_actions)
+ else:
+ action = values.asnumpy().argmax()
+ return action
+
+ def decay(self):
+ self.cur_eps = self.__decay_method.decay(self.cur_eps)
+
+
+class GreedyActionPolicy(object):
+    def __init__(self):
+        # cur_eps is read by the agent when logging episode statistics
+        self.cur_eps = 0
+
+ def select_action(self, values):
+ return values.asnumpy().argmax()
+
+ def decay(self):
+ pass
+
+
+class NoDecay(object):
+ def __init__(self):
+ pass
+
+ def decay(self, cur_eps):
+ return cur_eps
+
+class LinearDecay(object):
+ def __init__(self, eps_decay, min_eps=0):
+ self.eps_decay = eps_decay
+ self.min_eps = min_eps
+
+ def decay(self, cur_eps):
+ return max(cur_eps - self.eps_decay, self.min_eps)
\ No newline at end of file
diff --git a/src/test/resources/target_code/reinforcement_learning/agent.py b/src/test/resources/target_code/reinforcement_learning/agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b53a2be2c1a684c2ed25830672dd3d210d6d840
--- /dev/null
+++ b/src/test/resources/target_code/reinforcement_learning/agent.py
@@ -0,0 +1,503 @@
+import mxnet as mx
+import numpy as np
+import time
+import os
+import logging
+import sys
+import util
+import matplotlib.pyplot as plt
+from replay_memory import ReplayMemoryBuilder
+from action_policy import ActionPolicyBuilder
+from util import copy_net, get_loss_function
+from mxnet import nd, gluon, autograd
+
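+# Deep Q-Network agent: interacts with an environment, stores transitions in a
+# replay memory and trains the given Gluon network with the DQN update rule
+# (optionally with a fixed target network and double DQN).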
+class DqnAgent(object):
+ def __init__(self,
+ network,
+ environment,
+ replay_memory_params,
+ policy_params,
+ state_dim,
+ ctx=None,
+ discount_factor=.9,
+ loss_function='euclidean',
+ optimizer='rmsprop',
+ optimizer_params = {'learning_rate':0.09},
+ training_episodes=50,
+ train_interval=1,
+ use_fix_target=False,
+ double_dqn = False,
+ target_update_interval=10,
+ snapshot_interval=200,
+ agent_name='Dqn_agent',
+ max_episode_step=99999,
+ output_directory='model_parameters',
+ verbose=True,
+ live_plot = True,
+ make_logfile=True,
+ target_score=None):
+ assert 0 < discount_factor <= 1
+ assert train_interval > 0
+ assert target_update_interval > 0
+ assert snapshot_interval > 0
+ assert max_episode_step > 0
+ assert training_episodes > 0
+ assert replay_memory_params is not None
+ assert type(state_dim) is tuple
+
+ self.__ctx = mx.gpu() if ctx == 'gpu' else mx.cpu()
+ self.__qnet = network
+
+ self.__environment = environment
+ self.__discount_factor = discount_factor
+ self.__training_episodes = training_episodes
+ self.__train_interval = train_interval
+ self.__verbose = verbose
+ self.__state_dim = state_dim
+ self.__action_dim = self.__qnet(nd.random_normal(shape=((1,) + self.__state_dim), ctx=self.__ctx)).shape[1:]
+
+ replay_memory_params['state_dim'] = state_dim
+ self.__replay_memory_params = replay_memory_params
+ rm_builder = ReplayMemoryBuilder()
+ self.__memory = rm_builder.build_by_params(**replay_memory_params)
+ self.__minibatch_size = self.__memory.sample_size
+
+ policy_params['action_dim'] = self.__action_dim
+ self.__policy_params = policy_params
+ p_builder = ActionPolicyBuilder()
+ self.__policy = p_builder.build_by_params(**policy_params)
+
+ self.__target_update_interval = target_update_interval
+ self.__target_qnet = copy_net(self.__qnet, self.__state_dim, ctx=self.__ctx)
+ self.__loss_function_str = loss_function
+ self.__loss_function = get_loss_function(loss_function)
+ self.__agent_name = agent_name
+ self.__snapshot_interval = snapshot_interval
+ self.__creation_time = time.time()
+ self.__max_episode_step = max_episode_step
+ self.__optimizer = optimizer
+ self.__optimizer_params = optimizer_params
+ self.__make_logfile = make_logfile
+ self.__double_dqn = double_dqn
+ self.__use_fix_target = use_fix_target
+ self.__live_plot = live_plot
+ self.__user_given_directory = output_directory
+ self.__target_score = target_score
+
+ self.__interrupt_flag = False
+
+ # Training Context
+ self.__current_episode = 0
+ self.__total_steps = 0
+
+ # Initialize best network
+ self.__best_net = copy_net(self.__qnet, self.__state_dim, self.__ctx)
+ self.__best_avg_score = None
+
+ # Gluon Trainer definition
+ self.__training_stats = None
+
+ # Prepare output directory and logger
+ self.__output_directory = output_directory\
+ + '/' + self.__agent_name\
+ + '/' + time.strftime('%d-%m-%Y-%H-%M-%S', time.localtime(self.__creation_time))
+ self.__logger = self.__setup_logging()
+ self.__logger.info('Agent created with following parameters: {}'.format(self.__make_config_dict()))
+
+ @classmethod
+ def from_config_file(cls, network, environment, config_file_path, ctx=None):
+ import json
+ # Load config
+ with open(config_file_path, 'r') as config_file:
+ config_dict = json.load(config_file)
+ return cls(network, environment, ctx=ctx, **config_dict)
+
+ @classmethod
+ def resume_from_session(cls, session_dir, network_type):
+ import pickle
+ session_dir = os.path.join(session_dir, '.interrupted_session')
+ if not os.path.exists(session_dir):
+ raise ValueError('Session directory does not exist')
+
+ files = dict()
+ files['agent'] = os.path.join(session_dir, 'agent.p')
+ files['best_net_params'] = os.path.join(session_dir, 'best_net.params')
+ files['q_net_params'] = os.path.join(session_dir, 'qnet.params')
+ files['target_net_params'] = os.path.join(session_dir, 'target_net.params')
+
+ for file in files.values():
+ if not os.path.exists(file):
+ raise ValueError('Session directory is not complete: {} is missing'.format(file))
+
+ with open(files['agent'], 'rb') as f:
+ agent = pickle.load(f)
+
+ agent.__qnet = network_type()
+ agent.__qnet.load_parameters(files['q_net_params'], agent.__ctx)
+ agent.__qnet.hybridize()
+        agent.__qnet(nd.ones((1,) + agent.__environment.state_dim, ctx=agent.__ctx))
+ agent.__best_net = network_type()
+ agent.__best_net.load_parameters(files['best_net_params'], agent.__ctx)
+ agent.__target_qnet = network_type()
+ agent.__target_qnet.load_parameters(files['target_net_params'], agent.__ctx)
+
+ agent.__logger = agent.__setup_logging(append=True)
+ agent.__training_stats.logger = agent.__logger
+ agent.__logger.info('Agent was retrieved; Training can be continued')
+
+ return agent
+
+ def __interrupt_training(self):
+ import pickle
+ self.__logger.info('Training interrupted; Store state for resuming')
+ session_dir = os.path.join(self.__output_directory, '.interrupted_session')
+ if not os.path.exists(session_dir):
+ os.mkdir(session_dir)
+
+ del self.__training_stats.logger
+ logger = self.__logger
+ self.__logger = None
+
+ self.__save_net(self.__qnet, 'qnet', session_dir)
+ self.__qnet = None
+ self.__save_net(self.__best_net, 'best_net', session_dir)
+ self.__best_net = None
+ self.__save_net(self.__target_qnet, 'target_net', session_dir)
+ self.__target_qnet = None
+
+ agent_session_file = os.path.join(session_dir, 'agent.p')
+
+ with open(agent_session_file, 'wb') as f:
+ pickle.dump(self, f)
+
+ logger.info('State successfully stored')
+
+ @property
+ def current_episode(self):
+ return self.__current_episode
+
+ @property
+ def environment(self):
+ return self.__environment
+
+    def __adjust_optimizer_params(self, optimizer_params):
+        # Work on a copy so the caller's dict (and the shared default) is not mutated
+        optimizer_params = dict(optimizer_params)
+        if 'weight_decay' in optimizer_params:
+ optimizer_params['wd'] = optimizer_params['weight_decay']
+ del optimizer_params['weight_decay']
+ if 'learning_rate_decay' in optimizer_params:
+ min_learning_rate = 1e-8
+ if 'learning_rate_minimum' in optimizer_params:
+ min_learning_rate = optimizer_params['learning_rate_minimum']
+ del optimizer_params['learning_rate_minimum']
+            optimizer_params['lr_scheduler'] = mx.lr_scheduler.FactorScheduler(
+ optimizer_params['step_size'],
+ factor=optimizer_params['learning_rate_decay'],
+ stop_factor_lr=min_learning_rate)
+ del optimizer_params['step_size']
+ del optimizer_params['learning_rate_decay']
+
+ return optimizer_params
+
+ def set_interrupt_flag(self, interrupt):
+ self.__interrupt_flag = interrupt
+
+
+ def __make_output_directory_if_not_exist(self):
+ assert self.__output_directory
+ if not os.path.exists(self.__output_directory):
+ os.makedirs(self.__output_directory)
+
+ def __setup_logging(self, append=False):
+ assert self.__output_directory
+ assert self.__agent_name
+
+ output_level = logging.DEBUG if self.__verbose else logging.WARNING
+ filemode = 'a' if append else 'w'
+
+ logformat = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+ dateformat = '%d-%b-%y %H:%M:%S'
+ formatter = logging.Formatter(fmt=logformat, datefmt=dateformat)
+
+ logger = logging.getLogger('DQNAgent')
+ logger.setLevel(output_level)
+
+ stream_handler = logging.StreamHandler(sys.stdout)
+ stream_handler.setLevel(output_level)
+ stream_handler.setFormatter(formatter)
+ logger.addHandler(stream_handler)
+
+ if self.__make_logfile:
+ self.__make_output_directory_if_not_exist()
+ log_file = os.path.join(self.__output_directory, self.__agent_name + '.log')
+ file_handler = logging.FileHandler(log_file, mode=filemode)
+ file_handler.setLevel(output_level)
+ file_handler.setFormatter(formatter)
+ logger.addHandler(file_handler)
+
+ return logger
+
+ def __is_target_reached(self, avg_reward):
+ return self.__target_score is not None\
+ and avg_reward > self.__target_score
+
+
+ def get_q_values(self, state, with_best=False):
+ return self.get_batch_q_values(nd.array([state], ctx=self.__ctx), with_best=with_best)[0]
+
+ def get_batch_q_values(self, state_batch, with_best=False):
+ return self.__best_net(state_batch) if with_best else self.__qnet(state_batch)
+
+ def get_next_action(self, state, with_best=False):
+ q_values = self.get_q_values(state, with_best=with_best)
+        return q_values.asnumpy().argmax()
+
+ def __sample_from_memory(self):
+ states, actions, rewards, next_states, terminals\
+ = self.__memory.sample(batch_size=self.__minibatch_size)
+ states = nd.array(states, ctx=self.__ctx)
+ actions = nd.array(actions, ctx=self.__ctx)
+ rewards = nd.array(rewards, ctx=self.__ctx)
+ next_states = nd.array(next_states, ctx=self.__ctx)
+ terminals = nd.array(terminals, ctx=self.__ctx)
+ return states, actions, rewards, next_states, terminals
+
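+    # Compute Bellman targets for the sampled batch: the chosen action's target is
+    # r + discount * Q(s', a') (zero for terminal states); with double DQN the action
+    # a' is selected by the online network and evaluated on q_max_val, otherwise the
+    # maximum of q_max_val is used. All other actions keep the current predictions.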
+ def __determine_target_q_values(self, states, actions, rewards, next_states, terminals):
+ if self.__use_fix_target:
+ q_max_val = self.__target_qnet(next_states)
+ else:
+ q_max_val = self.__qnet(next_states)
+
+ if self.__double_dqn:
+ q_values_next_states = self.__qnet(next_states)
+ target_rewards = rewards + nd.choose_element_0index(q_max_val, nd.argmax_channel(q_values_next_states))\
+ * (1.0 - terminals) * self.__discount_factor
+ else:
+ target_rewards = rewards + nd.choose_element_0index(q_max_val, nd.argmax_channel(q_max_val))\
+ * (1.0 - terminals) * self.__discount_factor
+
+ target_qval = self.__qnet(states)
+ for t in range(target_rewards.shape[0]):
+ target_qval[t][actions[t]] = target_rewards[t]
+
+ return target_qval
+
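+    # Sample a minibatch from the replay memory, compute the targets and perform
+    # one gradient step on the Q-network.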
+ def __train_q_net_step(self, trainer):
+ states, actions, rewards, next_states, terminals = self.__sample_from_memory()
+ target_qval = self.__determine_target_q_values(states, actions, rewards, next_states, terminals)
+ with autograd.record():
+ q_values = self.__qnet(states)
+ loss = self.__loss_function(q_values, target_qval)
+ loss.backward()
+ trainer.step(self.__minibatch_size)
+ return loss
+
+ def __do_snapshot_if_in_interval(self, episode):
+ do_snapshot = (episode % self.__snapshot_interval == 0)
+ if do_snapshot:
+ self.save_parameters(episode=episode)
+ self.__evaluate()
+
+ def __do_target_update_if_in_interval(self, total_steps):
+ do_target_update = (self.__use_fix_target and total_steps % self.__target_update_interval == 0)
+ if do_target_update:
+ self.__logger.info('Target network is updated after {} steps'.format(total_steps))
+ self.__target_qnet = copy_net(self.__qnet, self.__state_dim, self.__ctx)
+
+ def train(self, episodes=None):
+ self.__logger.info("--- Start training ---")
+ trainer = gluon.Trainer(self.__qnet.collect_params(), self.__optimizer, self.__adjust_optimizer_params(self.__optimizer_params))
+        episodes = episodes if episodes is not None else self.__training_episodes
+
+ resume = (self.__current_episode > 0)
+ if resume:
+ self.__logger.info("Training session resumed")
+ self.__logger.info("Starting from episode {}".format(self.__current_episode))
+ else:
+ self.__training_stats = util.TrainingStats(self.__logger, episodes, self.__live_plot)
+
+ # Implementation Deep Q Learning described by Mnih et. al. in Playing Atari with Deep Reinforcement Learning
+ while self.__current_episode < episodes:
+ if self.__interrupt_flag:
+ self.__interrupt_flag = False
+ self.__interrupt_training()
+ return False
+
+ step = 0
+ episode_reward = 0
+ start = time.time()
+ state = self.__environment.reset()
+ episode_loss = 0
+ training_steps = 0
+ while step < self.__max_episode_step:
+ #1. Choose an action based on current game state and policy
+ q_values = self.__qnet(nd.array([state], ctx=self.__ctx))
+ action = self.__policy.select_action(q_values[0])
+
+ #2. Play the game for a single step
+ next_state, reward, terminal, _ = self.__environment.step(action)
+
+ #3. Store transition in replay memory
+ self.__memory.append(state, action, reward, next_state, terminal)
+
+ #4. Train the network if in interval
+ do_training = (self.__total_steps % self.__train_interval == 0\
+ and self.__memory.is_sample_possible(self.__minibatch_size))
+ if do_training:
+ loss = self.__train_q_net_step(trainer)
+                    loss_sum = loss.sum().asnumpy()[0]
+ episode_loss += float(loss_sum)/float(self.__minibatch_size)
+ training_steps += 1
+
+ # Update target network if in interval
+ self.__do_target_update_if_in_interval(self.__total_steps)
+
+ step += 1
+ self.__total_steps += 1
+ episode_reward += reward
+ state = next_state
+
+                if terminal:
+                    break
+
+            # Log the episode even when the step limit is reached without a terminal
+            # state, so that avg_reward is defined for the target-score check below
+            episode_loss = episode_loss if training_steps > 0 else None
+            _, _, avg_reward = self.__training_stats.log_episode(self.__current_episode, start, training_steps,
+                episode_loss, self.__policy.cur_eps, episode_reward)
+
+ self.__do_snapshot_if_in_interval(self.__current_episode)
+ self.__policy.decay()
+
+ if self.__is_target_reached(avg_reward):
+ self.__logger.info('Target score is reached in average; Training is stopped')
+ break
+
+ self.__current_episode += 1
+
+ self.__evaluate()
+ training_stats_file = os.path.join(self.__output_directory, 'training_stats.pdf')
+ self.__training_stats.save_stats(training_stats_file)
+ self.__logger.info('--------- Training finished ---------')
+ return True
+
+ def __save_net(self, net, filename, filedir=None):
+ filedir = self.__output_directory if filedir is None else filedir
+ filename = os.path.join(filedir, filename + '.params')
+ net.save_parameters(filename)
+
+
+ def save_parameters(self, episode=None, filename='dqn-agent-params'):
+ assert self.__output_directory
+ self.__make_output_directory_if_not_exist()
+
+        if episode is not None:
+ self.__logger.info('Saving model parameters after episode %d' % episode)
+ filename = filename + '-ep{}'.format(episode)
+ else:
+ self.__logger.info('Saving model parameters')
+ self.__save_net(self.__qnet, filename)
+
+ def evaluate(self, target=None, sample_games=100, verbose=True):
+ target = self.__target_score if target is None else target
+ if target:
+ target_achieved = 0
+ total_reward = 0
+
+ for g in range(sample_games):
+ state = self.__environment.reset()
+ step = 0
+ game_reward = 0
+ while step < self.__max_episode_step:
+ action = self.get_next_action(state)
+ state, reward, terminal, _ = self.__environment.step(action)
+ game_reward += reward
+
+ if terminal:
+ if verbose:
+ info = 'Game %d: Reward %f' % (g,game_reward)
+ self.__logger.debug(info)
+ if target:
+ if game_reward >= target:
+ target_achieved += 1
+ total_reward += game_reward
+ break
+
+ step += 1
+
+ avg_reward = float(total_reward)/float(sample_games)
+ info = 'Avg. Reward: %f' % avg_reward
+ if target:
+ target_achieved_ratio = int((float(target_achieved)/float(sample_games))*100)
+ info += '; Target Achieved in %d%% of games' % (target_achieved_ratio)
+
+ if verbose:
+ self.__logger.info(info)
+ return avg_reward
+
+ def __evaluate(self, verbose=True):
+ sample_games = 100
+ avg_reward = self.evaluate(sample_games=sample_games, verbose=False)
+ info = 'Evaluation -> Average Reward in {} games: {}'.format(sample_games, avg_reward)
+
+ if self.__best_avg_score is None or self.__best_avg_score <= avg_reward:
+ self.__best_net = copy_net(self.__qnet, self.__state_dim, self.__ctx)
+ self.__best_avg_score = avg_reward
+ info += ' (NEW BEST)'
+
+ if verbose:
+ self.__logger.info(info)
+
+
+
+ def play(self, update_frame=1, with_best=False):
+ step = 0
+ state = self.__environment.reset()
+ total_reward = 0
+ while step < self.__max_episode_step:
+ action = self.get_next_action(state, with_best=with_best)
+ state, reward, terminal, _ = self.__environment.step(action)
+ total_reward += reward
+ do_update_frame = (step % update_frame == 0)
+ if do_update_frame:
+ self.__environment.render()
+                time.sleep(0.1)
+
+ if terminal:
+ break
+
+ step += 1
+ return total_reward
+
+ def save_best_network(self, path, epoch=0):
+ self.__logger.info('Saving best network with average reward of {}'.format(self.__best_avg_score))
+ self.__best_net.export(path, epoch=epoch)
+
+ def __make_config_dict(self):
+ config = dict()
+ config['discount_factor'] = self.__discount_factor
+ config['optimizer'] = self.__optimizer
+ config['optimizer_params'] = self.__optimizer_params
+ config['policy_params'] = self.__policy_params
+ config['replay_memory_params'] = self.__replay_memory_params
+ config['loss_function'] = self.__loss_function_str
+ config['optimizer'] = self.__optimizer
+ config['training_episodes'] = self.__training_episodes
+ config['train_interval'] = self.__train_interval
+ config['use_fix_target'] = self.__use_fix_target
+ config['double_dqn'] = self.__double_dqn
+ config['target_update_interval'] = self.__target_update_interval
+ config['snapshot_interval']= self.__snapshot_interval
+ config['agent_name'] = self.__agent_name
+ config['max_episode_step'] = self.__max_episode_step
+ config['output_directory'] = self.__user_given_directory
+ config['verbose'] = self.__verbose
+ config['live_plot'] = self.__live_plot
+ config['make_logfile'] = self.__make_logfile
+ config['target_score'] = self.__target_score
+ return config
+
+ def save_config_file(self):
+ import json
+ self.__make_output_directory_if_not_exist()
+ filename = os.path.join(self.__output_directory, 'config.json')
+ config = self.__make_config_dict()
+ with open(filename, mode='w') as fp:
+ json.dump(config, fp, indent=4)
\ No newline at end of file
diff --git a/src/test/resources/target_code/reinforcement_learning/environment.py b/src/test/resources/target_code/reinforcement_learning/environment.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d9a342f2cef88b184b2c4bb4976c7987f5cc1bf
--- /dev/null
+++ b/src/test/resources/target_code/reinforcement_learning/environment.py
@@ -0,0 +1,67 @@
+import abc
+import logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
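+# Abstract environment interface; concrete environments implement reset() and step(action).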
+class Environment:
+ __metaclass__ = abc.ABCMeta
+
+ def __init__(self):
+ pass
+
+ @abc.abstractmethod
+ def reset(self):
+ pass
+
+ @abc.abstractmethod
+ def step(self, action):
+ pass
+
+import gym
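+# Adapter that exposes an OpenAI Gym environment through the Environment
+# interface used by the generated agent; the environment is seeded with a
+# fixed seed (42) for reproducibility.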
+class GymEnvironment(Environment):
+ def __init__(self, env_name, **kwargs):
+ super(GymEnvironment, self).__init__(**kwargs)
+ self.__seed = 42
+ self.__env = gym.make(env_name)
+ self.__env.seed(self.__seed)
+
+ @property
+ def state_dim(self):
+ return self.__env.observation_space.shape
+
+ @property
+ def state_dtype(self):
+ return 'float32'
+
+ @property
+ def action_dtype(self):
+ return 'uint8'
+
+ @property
+ def number_of_actions(self):
+ return self.__env.action_space.n
+
+ @property
+ def rewards_dtype(self):
+ return 'float32'
+
+ def reset(self):
+ return self.__env.reset()
+
+ def step(self, action):
+ return self.__env.step(action)
+
+ def close(self):
+ self.__env.close()
+
+ def action_space(self):
+        return self.__env.action_space
+
+ def is_in_action_space(self, action):
+ return self.__env.action_space.contains(action)
+
+ def sample_action(self):
+ return self.__env.action_space.sample()
+
+ def render(self):
+ self.__env.render()
diff --git a/src/test/resources/target_code/reinforcement_learning/replay_memory.py b/src/test/resources/target_code/reinforcement_learning/replay_memory.py
new file mode 100644
index 0000000000000000000000000000000000000000..e66cd9350cab02144f994cab706249ef5a1e4288
--- /dev/null
+++ b/src/test/resources/target_code/reinforcement_learning/replay_memory.py
@@ -0,0 +1,155 @@
+import numpy as np
+
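+# Builds the configured replay memory variant: 'online' (capacity and sample size 1),
+# 'buffer' (uniform ring buffer) or 'combined' (buffer plus the latest transition).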
+class ReplayMemoryBuilder(object):
+ def __init__(self):
+ self.__supported_methods = ['online', 'buffer', 'combined']
+
+ def build_by_params(self,
+ state_dim,
+ method='online',
+ state_dtype='float32',
+ action_dtype='uint8',
+ rewards_dtype='float32',
+ memory_size=1000,
+ sample_size=32):
+ assert state_dim is not None
+ assert method in self.__supported_methods
+
+ if method == 'online':
+ return self.build_online_memory(state_dim=state_dim, state_dtype=state_dtype,
+ action_dtype=action_dtype, rewards_dtype=rewards_dtype)
+ else:
+ assert memory_size is not None and memory_size > 0
+ assert sample_size is not None and sample_size > 0
+ if method == 'buffer':
+ return self.build_buffered_memory(state_dim=state_dim, sample_size=sample_size,
+ memory_size=memory_size, state_dtype=state_dtype, action_dtype=action_dtype,
+ rewards_dtype=rewards_dtype)
+ else:
+ return self.build_combined_memory(state_dim=state_dim, sample_size=sample_size,
+ memory_size=memory_size, state_dtype=state_dtype, action_dtype=action_dtype,
+ rewards_dtype=rewards_dtype)
+
+ def build_buffered_memory(self, state_dim, memory_size=1000, sample_size=1, state_dtype='float32',
+ action_dtype='uint8', rewards_dtype='float32'):
+ assert memory_size > 0
+ assert sample_size > 0
+ return ReplayMemory(state_dim, size=memory_size, sample_size=sample_size,
+ state_dtype=state_dtype, action_dtype=action_dtype, rewards_dtype=rewards_dtype)
+
+ def build_combined_memory(self, state_dim, memory_size=1000, sample_size=1, state_dtype='float32',
+ action_dtype='uint8', rewards_dtype='float32'):
+ assert memory_size > 0
+ assert sample_size > 0
+ return CombinedReplayMemory(state_dim, size=memory_size, sample_size=sample_size,
+ state_dtype=state_dtype, action_dtype=action_dtype, rewards_dtype=rewards_dtype)
+
+ def build_online_memory(self, state_dim, state_dtype='float32', action_dtype='uint8',
+ rewards_dtype='float32'):
+ return OnlineReplayMemory(state_dim, state_dtype=state_dtype, action_dtype=action_dtype,
+ rewards_dtype=rewards_dtype)
+
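+# Fixed-size ring buffer of transitions; sample() draws a minibatch uniformly at
+# random (with replacement) from the currently filled part of the buffer.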
+class ReplayMemory(object):
+ def __init__(self, state_dim, sample_size, size=1000, state_dtype='uint8', action_dtype='uint8', rewards_dtype='float32'):
+ assert size > 0, "Size must be greater than zero"
+ assert type(state_dim) is tuple, "State dimension must be a tuple"
+ assert sample_size > 0
+ self._size = size
+ self._sample_size = sample_size
+ self._cur_size = 0
+ self._pointer = 0
+ self._state_dim = state_dim
+ self._state_dtype = state_dtype
+ self._action_dtype = action_dtype
+ self._rewards_dtype = rewards_dtype
+ self._states = np.zeros((self._size,) + state_dim, dtype=state_dtype)
+ self._actions = np.array([0] * self._size, dtype=action_dtype)
+ self._rewards = np.array([0] * self._size, dtype=rewards_dtype)
+ self._next_states = np.zeros((self._size,) + state_dim, dtype=state_dtype)
+ self._terminals = np.array([0] * self._size, dtype='bool')
+
+ @property
+ def sample_size(self):
+ return self._sample_size
+
+ def append(self, state, action, reward, next_state, terminal):
+ self._states[self._pointer] = state
+ self._actions[self._pointer] = action
+ self._rewards[self._pointer] = reward
+ self._next_states[self._pointer] = next_state
+ self._terminals[self._pointer] = terminal
+
+ self._pointer = self._pointer + 1
+ if self._pointer == self._size:
+ self._pointer = 0
+
+ self._cur_size = min(self._size, self._cur_size + 1)
+
+ def at(self, index):
+ return self._states[index],\
+ self._actions[index],\
+ self._rewards[index],\
+ self._next_states[index],\
+ self._terminals[index]
+
+ def is_sample_possible(self, batch_size=None):
+ batch_size = batch_size if batch_size is not None else self._sample_size
+ return self._cur_size >= batch_size
+
+ def sample(self, batch_size=None):
+ batch_size = batch_size if batch_size is not None else self._sample_size
+ assert self._cur_size >= batch_size, "Size of replay memory must be larger than batch size"
+ i=0
+ states = np.zeros((batch_size,)+self._state_dim, dtype=self._state_dtype)
+ actions = np.zeros(batch_size, dtype=self._action_dtype)
+ rewards = np.zeros(batch_size, dtype=self._rewards_dtype)
+ next_states = np.zeros((batch_size,)+self._state_dim, dtype=self._state_dtype)
+ terminals = np.zeros(batch_size, dtype='bool')
+
+ while i < batch_size:
+ rnd_index = np.random.randint(low=0, high=self._cur_size)
+ states[i] = self._states.take(rnd_index, axis=0)
+ actions[i] = self._actions.take(rnd_index, axis=0)
+ rewards[i] = self._rewards.take(rnd_index, axis=0)
+ next_states[i] = self._next_states.take(rnd_index, axis=0)
+ terminals[i] = self._terminals.take(rnd_index, axis=0)
+ i += 1
+
+ return states, actions, rewards, next_states, terminals
+
+
+class OnlineReplayMemory(ReplayMemory):
+ def __init__(self, state_dim, state_dtype='float32', action_dtype='uint8', rewards_dtype='float32'):
+ super(OnlineReplayMemory, self).__init__(state_dim, sample_size=1, size=1,
+ state_dtype=state_dtype, action_dtype=action_dtype, rewards_dtype=rewards_dtype)
+
+
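+# Replay memory that reserves one slot in every sampled minibatch for the most
+# recent transition, so the newest experience is always trained on.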
+class CombinedReplayMemory(ReplayMemory):
+ def __init__(self, state_dim, sample_size, size=1000,
+ state_dtype='uint8', action_dtype='uint8', rewards_dtype='float32'):
+ super(CombinedReplayMemory, self).__init__(state_dim, sample_size=(sample_size - 1), size=size,
+ state_dtype=state_dtype, action_dtype=action_dtype, rewards_dtype=rewards_dtype)
+
+ self._last_state = np.zeros((1,) + state_dim, dtype=state_dtype)
+ self._last_action = np.array([0], dtype=action_dtype)
+ self._last_reward = np.array([0], dtype=rewards_dtype)
+ self._last_next_state = np.zeros((1,) + state_dim, dtype=state_dtype)
+ self._last_terminal = np.array([0], dtype='bool')
+
+ def append(self, state, action, reward, next_state, terminal):
+ super(CombinedReplayMemory, self).append(state, action, reward, next_state, terminal)
+ self._last_state = state
+ self._last_action = action
+ self._last_reward = reward
+ self._last_next_state = next_state
+ self._last_terminal = terminal
+
+ def sample(self, batch_size=None):
+ batch_size = (batch_size-1) if batch_size is not None else self._sample_size
+ states, actions, rewards, next_states, terminals = super(CombinedReplayMemory, self).sample(batch_size=batch_size)
+ states = np.append(states, [self._last_state], axis=0)
+ actions = np.append(actions, [self._last_action], axis=0)
+ rewards = np.append(rewards, [self._last_reward], axis=0)
+ next_states = np.append(next_states, [self._last_next_state], axis=0)
+ terminals = np.append(terminals, [self._last_terminal], axis=0)
+ return states, actions, rewards, next_states, terminals
\ No newline at end of file
diff --git a/src/test/resources/target_code/reinforcement_learning/util.py b/src/test/resources/target_code/reinforcement_learning/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3f86a03b7240c51ab957b0170e08174efe2bc2e
--- /dev/null
+++ b/src/test/resources/target_code/reinforcement_learning/util.py
@@ -0,0 +1,134 @@
+import signal
+import sys
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib import style
+import time
+import os
+import mxnet
+from mxnet import gluon, nd
+
+
+LOSS_FUNCTIONS = {
+ 'l1': gluon.loss.L1Loss(),
+ 'euclidean': gluon.loss.L2Loss(),
+ 'huber_loss': gluon.loss.HuberLoss(),
+ 'softmax_cross_entropy': gluon.loss.SoftmaxCrossEntropyLoss(),
+ 'sigmoid_cross_entropy': gluon.loss.SigmoidBinaryCrossEntropyLoss()}
+
+def copy_net(net, input_state_dim, ctx, tmp_filename='tmp.params'):
+ assert isinstance(net, gluon.HybridBlock)
+ assert type(net.__class__) is type
+ net.save_parameters(tmp_filename)
+ net2 = net.__class__()
+ net2.load_parameters(tmp_filename, ctx=ctx)
+ os.remove(tmp_filename)
+ net2.hybridize()
+ net2(nd.ones((1,) + input_state_dim, ctx=ctx))
+ return net2
+
+def get_loss_function(loss_function_name):
+ if loss_function_name not in LOSS_FUNCTIONS:
+ raise ValueError('Loss function does not exist')
+ return LOSS_FUNCTIONS[loss_function_name]
+
+
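+# Installs a SIGINT handler that asks the registered agent to interrupt training
+# gracefully instead of killing the process immediately.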
+class AgentSignalHandler(object):
+ def __init__(self):
+ signal.signal(signal.SIGINT, self.interrupt_training)
+ self.__agent = None
+
+ def register_agent(self, agent):
+ self.__agent = agent
+
+ def interrupt_training(self, sig, frame):
+ if self.__agent:
+ self.__agent.set_interrupt_flag(True)
+
+style.use('fivethirtyeight')
+class TrainingStats(object):
+ def __init__(self, logger, max_episodes, live_plot=True):
+ self.__logger = logger
+ self.__max_episodes = max_episodes
+ self.__all_avg_loss = np.zeros((max_episodes,))
+ self.__all_total_rewards = np.zeros((max_episodes,))
+ self.__all_eps = np.zeros((max_episodes,))
+ self.__all_time = np.zeros((max_episodes,))
+ self.__all_mean_reward_last_100_episodes = np.zeros((max_episodes,))
+ self.__live_plot = live_plot
+
+ @property
+ def logger(self):
+ return self.__logger
+
+ @logger.setter
+ def logger(self, logger):
+ self.__logger = logger
+
+ @logger.deleter
+ def logger(self):
+ self.__logger = None
+
+ def add_avg_loss(self, episode, avg_loss):
+ self.__all_avg_loss[episode] = avg_loss
+
+ def add_total_reward(self, episode, total_reward):
+ self.__all_total_rewards[episode] = total_reward
+
+ def add_eps(self, episode, eps):
+ self.__all_eps[episode] = eps
+
+ def add_time(self, episode, time):
+ self.__all_time[episode] = time
+
+ def add_mean_reward_last_100(self, episode, mean_reward):
+ self.__all_mean_reward_last_100_episodes[episode] = mean_reward
+
+ def log_episode(self, episode, start_time, training_steps, loss, eps, reward):
+ self.add_eps(episode, eps)
+ self.add_total_reward(episode, reward)
+ end = time.time()
+ if training_steps == 0:
+ avg_loss = 0
+ else:
+ avg_loss = float(loss)/float(training_steps)
+
+ mean_reward_last_100 = self.mean_of_reward(episode, last=100)
+
+ time_elapsed = end - start_time
+ info = "Episode: %d, Total Reward: %.3f, Avg. Reward Last 100 Episodes: %.3f, Avg Loss: %.3f, Time: %.3f, Training Steps: %d, Eps: %.3f"\
+ % (episode, reward, mean_reward_last_100, avg_loss, time_elapsed, training_steps, eps)
+ self.__logger.info(info)
+ self.add_avg_loss(episode, avg_loss)
+ self.add_time(episode, time_elapsed)
+ self.add_mean_reward_last_100(episode, mean_reward_last_100)
+
+ return avg_loss, time_elapsed, mean_reward_last_100
+
+ def mean_of_reward(self, cur_episode, last=100):
+ if cur_episode > 0:
+ reward_last_100 = self.__all_total_rewards[max(0, cur_episode-last):cur_episode]
+ return np.mean(reward_last_100)
+ else:
+ return self.__all_total_rewards[0]
+
+ def save_stats(self, file):
+ fig = plt.figure(figsize=(20,20))
+
+ sub_rewards = fig.add_subplot(221)
+ sub_rewards.set_title('Total Rewards per episode')
+ sub_rewards.plot(np.arange(self.__max_episodes), self.__all_total_rewards)
+
+ sub_loss = fig.add_subplot(222)
+ sub_loss.set_title('Avg. Loss per episode')
+ sub_loss.plot(np.arange(self.__max_episodes), self.__all_avg_loss)
+
+ sub_eps = fig.add_subplot(223)
+ sub_eps.set_title('Epsilon per episode')
+ sub_eps.plot(np.arange(self.__max_episodes), self.__all_eps)
+
+ sub_rewards = fig.add_subplot(224)
+ sub_rewards.set_title('Avg. mean reward of last 100 episodes')
+ sub_rewards.plot(np.arange(self.__max_episodes), self.__all_mean_reward_last_100_episodes)
+
+ plt.savefig(file)
\ No newline at end of file
diff --git a/src/test/resources/valid_tests/ReinforcementConfig1.cnnt b/src/test/resources/valid_tests/ReinforcementConfig1.cnnt
new file mode 100644
index 0000000000000000000000000000000000000000..101d98820434305f7a2627fcaf4433115ff83d5d
--- /dev/null
+++ b/src/test/resources/valid_tests/ReinforcementConfig1.cnnt
@@ -0,0 +1,45 @@
+configuration ReinforcementConfig1 {
+ learning_method : reinforcement
+
+ environment : ros_interface {
+ state_topic : "/environment/state"
+ action_topic : "/environment/action"
+ reset_topic : "/environment/reset"
+ }
+
+ agent_name : "reinforcement_agent"
+
+ reward_function : reward.rewardFunction
+
+ num_episodes : 1000
+ target_score : 35000
+ discount_factor : 0.99999
+ num_max_steps : 10000
+ training_interval : 1
+
+ use_fix_target_network : true
+ target_network_update_interval : 500
+
+ snapshot_interval : 500
+
+ use_double_dqn : true
+
+ loss : huber_loss
+
+ replay_memory : buffer{
+ memory_size : 1000000
+ sample_size : 64
+ }
+
+ action_selection : epsgreedy{
+ epsilon : 1.0
+ min_epsilon : 0.02
+ epsilon_decay_method: linear
+ epsilon_decay : 0.0001
+ }
+
+ optimizer : adam{
+ learning_rate : 0.001
+ }
+
+}
\ No newline at end of file
diff --git a/src/test/resources/valid_tests/ReinforcementConfig2.cnnt b/src/test/resources/valid_tests/ReinforcementConfig2.cnnt
new file mode 100644
index 0000000000000000000000000000000000000000..61deb55674e7a4c23913a60bd90760755cb211da
--- /dev/null
+++ b/src/test/resources/valid_tests/ReinforcementConfig2.cnnt
@@ -0,0 +1,50 @@
+configuration ReinforcementConfig2 {
+ learning_method : reinforcement
+
+ environment : gym { name:"CartPole-v1" }
+
+ agent_name : "reinforcement_agent"
+
+ num_episodes : 200
+ target_score : 185.5
+ discount_factor : 0.999
+ num_max_steps : 250
+ training_interval : 1
+
+ use_fix_target_network : false
+
+ snapshot_interval : 20
+
+ use_double_dqn : false
+
+ loss : euclidean
+
+ replay_memory : buffer{
+ memory_size : 10000
+ sample_size : 32
+ }
+
+ action_selection : epsgreedy{
+ epsilon : 1.0
+ min_epsilon : 0.01
+ epsilon_decay_method: linear
+ epsilon_decay : 0.0001
+ }
+
+ optimizer : rmsprop{
+ learning_rate : 0.001
+ learning_rate_minimum : 0.00001
+ weight_decay : 0.01
+ learning_rate_decay : 0.9
+ learning_rate_policy : step
+ step_size : 1000
+ rescale_grad : 1.1
+ clip_gradient : 10
+ gamma1 : 0.9
+ gamma2 : 0.9
+ epsilon : 0.000001
+ centered : true
+ clip_weights : 10
+ }
+
+}
\ No newline at end of file
diff --git a/src/test/resources/valid_tests/reward/RewardFunction.emadl b/src/test/resources/valid_tests/reward/RewardFunction.emadl
new file mode 100644
index 0000000000000000000000000000000000000000..0e9b97d92e6995082a995a953b1a87fbb4df3148
--- /dev/null
+++ b/src/test/resources/valid_tests/reward/RewardFunction.emadl
@@ -0,0 +1,15 @@
+package reward;
+
+component RewardFunction {
+ ports
+ in Q^{16} state,
+ in B isTerminal,
+ out Q reward;
+
+ implementation Math {
+ Q speed = state(15);
+ Q angle = state(1);
+
+ reward = speed * cos(angle);
+ }
+}
\ No newline at end of file