From aa7af2ba4b8b504b42950961dfa3b505759192eb Mon Sep 17 00:00:00 2001 From: Nicola Gatto Date: Thu, 25 Apr 2019 17:39:48 +0200 Subject: [PATCH] Add python wrapper to CNNArch2Gluon for reward function generation --- .../monticar/emadl/generator/Backend.java | 2 +- .../emadl/generator/EMADLGenerator.java | 6 +++ .../generator/RewardFunctionCppGenerator.java | 41 +++++++++++++++++++ .../lang/monticar/emadl/GenerationTest.java | 10 +++++ .../reinforcementModel/cartpole/Master.emadl | 17 ++++++++ .../reinforcementModel/cartpole/Master.tag | 7 ++++ .../cartpole/agent/CartPoleDQN.cnnt | 40 ++++++++++++++++++ .../cartpole/agent/CartPoleDQN.emadl | 17 ++++++++ .../cartpole/agent/reward/Reward.emadl | 12 ++++++ .../cartpole/policy/Greedy.emadl | 21 ++++++++++ 10 files changed, 172 insertions(+), 1 deletion(-) create mode 100644 src/main/java/de/monticore/lang/monticar/emadl/generator/RewardFunctionCppGenerator.java create mode 100644 src/test/resources/models/reinforcementModel/cartpole/Master.emadl create mode 100644 src/test/resources/models/reinforcementModel/cartpole/Master.tag create mode 100644 src/test/resources/models/reinforcementModel/cartpole/agent/CartPoleDQN.cnnt create mode 100644 src/test/resources/models/reinforcementModel/cartpole/agent/CartPoleDQN.emadl create mode 100644 src/test/resources/models/reinforcementModel/cartpole/agent/reward/Reward.emadl create mode 100644 src/test/resources/models/reinforcementModel/cartpole/policy/Greedy.emadl diff --git a/src/main/java/de/monticore/lang/monticar/emadl/generator/Backend.java b/src/main/java/de/monticore/lang/monticar/emadl/generator/Backend.java index f08096e..9675093 100644 --- a/src/main/java/de/monticore/lang/monticar/emadl/generator/Backend.java +++ b/src/main/java/de/monticore/lang/monticar/emadl/generator/Backend.java @@ -40,7 +40,7 @@ public enum Backend { } @Override public CNNTrainGenerator getCNNTrainGenerator() { - return new CNNTrain2Gluon(); + return new CNNTrain2Gluon(new 
RewardFunctionCppGenerator()); } }; diff --git a/src/main/java/de/monticore/lang/monticar/emadl/generator/EMADLGenerator.java b/src/main/java/de/monticore/lang/monticar/emadl/generator/EMADLGenerator.java index 2d3586d..24bf8ec 100644 --- a/src/main/java/de/monticore/lang/monticar/emadl/generator/EMADLGenerator.java +++ b/src/main/java/de/monticore/lang/monticar/emadl/generator/EMADLGenerator.java @@ -30,6 +30,8 @@ import de.monticore.lang.math._symboltable.MathStatementsSymbol; import de.monticore.lang.monticar.cnnarch.CNNArchGenerator; import de.monticore.lang.monticar.cnnarch.DataPathConfigParser; import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol; +import de.monticore.lang.monticar.cnnarch.gluongenerator.CNNArch2Gluon; +import de.monticore.lang.monticar.cnnarch.gluongenerator.CNNTrain2Gluon; import de.monticore.lang.monticar.cnntrain.CNNTrainGenerator; import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol; import de.monticore.lang.monticar.emadl._cocos.EMADLCocos; @@ -495,6 +497,10 @@ public class EMADLGenerator { String trainConfigFilename = getConfigFilename(mainComponentName, component.getFullName(), component.getName()); //should be removed when CNNTrain supports packages + cnnTrainGenerator.setGenerationTargetPath(getGenerationTargetPath()); + if (cnnTrainGenerator instanceof CNNTrain2Gluon) { + ((CNNTrain2Gluon) cnnTrainGenerator).setRootProjectModelsDir(getModelsPath()); + } List names = Splitter.on("/").splitToList(trainConfigFilename); trainConfigFilename = names.get(names.size()-1); Path modelPath = Paths.get(getModelsPath() + Joiner.on("/").join(names.subList(0,names.size()-1))); diff --git a/src/main/java/de/monticore/lang/monticar/emadl/generator/RewardFunctionCppGenerator.java b/src/main/java/de/monticore/lang/monticar/emadl/generator/RewardFunctionCppGenerator.java new file mode 100644 index 0000000..7b56824 --- /dev/null +++ 
b/src/main/java/de/monticore/lang/monticar/emadl/generator/RewardFunctionCppGenerator.java @@ -0,0 +1,54 @@
+package de.monticore.lang.monticar.emadl.generator;
+
+import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.instanceStructure.EMAComponentInstanceSymbol;
+import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionSourceGenerator;
+import de.monticore.lang.monticar.generator.cpp.GeneratorEMAMOpt2CPP;
+import de.monticore.lang.tagging._symboltable.TaggingResolver;
+import de.se_rwth.commons.logging.Log;
+
+import java.io.IOException;
+import java.util.Optional;
+
+/**
+ * Generates the C++ sources of a reward-function component by running the
+ * EMAM-to-C++ generator with the Armadillo backend.
+ */
+public class RewardFunctionCppGenerator implements RewardFunctionSourceGenerator {
+    public RewardFunctionCppGenerator() {
+    }
+
+    /**
+     * Resolves {@code rootModel} as a component instance in the models found under
+     * {@code modelPath} and generates its code into {@code targetPath}.
+     * Failures are reported through {@code Log.error}; nothing is generated in that case.
+     *
+     * @param modelPath  path of the models used to build the symbol table
+     * @param rootModel  name of the component instance to resolve
+     * @param targetPath generation target path handed to the C++ generator
+     */
+    @Override
+    public void generate(String modelPath, String rootModel, String targetPath) {
+        GeneratorEMAMOpt2CPP generator = new GeneratorEMAMOpt2CPP();
+        generator.useArmadilloBackend();
+
+        TaggingResolver taggingResolver = EMADLAbstractSymtab.createSymTabAndTaggingResolver(modelPath);
+        Optional<EMAComponentInstanceSymbol> instanceSymbol = taggingResolver
+                .resolve(rootModel, EMAComponentInstanceSymbol.KIND);
+
+        if (!instanceSymbol.isPresent()) {
+            Log.error("Generation of reward function is not possible: Cannot resolve component instance "
+                    + rootModel);
+            // Log.error can return normally (e.g. with fail-quick disabled), so stop
+            // here instead of calling get() on an empty Optional below.
+            return;
+        }
+
+        generator.setGenerationTargetPath(targetPath);
+
+        try {
+            generator.generate(instanceSymbol.get(), taggingResolver);
+        } catch (IOException e) {
+            // Hand the exception to the logger as well so the cause is not lost.
+            Log.error("Generation of reward function is not possible: " + e.getMessage(), e);
+        }
+    }
+}
diff --git a/src/test/java/de/monticore/lang/monticar/emadl/GenerationTest.java b/src/test/java/de/monticore/lang/monticar/emadl/GenerationTest.java index c6f2e78..f6b0a19 100644 --- a/src/test/java/de/monticore/lang/monticar/emadl/GenerationTest.java +++ b/src/test/java/de/monticore/lang/monticar/emadl/GenerationTest.java @@ -23,6 +23,7 @@ package de.monticore.lang.monticar.emadl; import de.monticore.lang.monticar.emadl.generator.Backend;
import de.monticore.lang.monticar.emadl.generator.EMADLGenerator; import de.monticore.lang.monticar.emadl.generator.EMADLGeneratorCli; +import de.se_rwth.commons.logging.Finding; import de.se_rwth.commons.logging.Log; import freemarker.template.TemplateException; import org.junit.Before; @@ -35,6 +36,7 @@ import java.nio.file.Path; import java.nio.file.Paths; import java.util.Arrays; import java.util.List; +import java.util.stream.Collectors; import static junit.framework.TestCase.assertTrue; import static org.junit.Assert.assertFalse; @@ -187,6 +189,14 @@ public class GenerationTest extends AbstractSymtabTest { "mnist_mnistClassifier_net.h")); } + @Test + public void testGluonReinforcementModel() { + Log.getFindings().clear(); + String[] args = {"-m", "src/test/resources/models/reinforcementModel", "-r", "cartpole.Master", "-b", "GLUON", "-f", "n", "-c", "n"}; + EMADLGeneratorCli.main(args); + assertTrue(Log.getFindings().stream().filter(Finding::isError).collect(Collectors.toList()).isEmpty()); + } + @Test public void testHashFunction() { EMADLGenerator tester = new EMADLGenerator(Backend.MXNET); diff --git a/src/test/resources/models/reinforcementModel/cartpole/Master.emadl b/src/test/resources/models/reinforcementModel/cartpole/Master.emadl new file mode 100644 index 0000000..a086a75 --- /dev/null +++ b/src/test/resources/models/reinforcementModel/cartpole/Master.emadl @@ -0,0 +1,17 @@ +package cartpole; + +import cartpole.agent.*; +import cartpole.policy.*; + +component Master { + ports + in Q^{4} state, + out Z action; + + instance CartPoleDQN dqn; + instance Greedy policy; + + connect state -> dqn.state; + connect dqn.qvalues -> policy.values; + connect policy.action -> action; +} \ No newline at end of file diff --git a/src/test/resources/models/reinforcementModel/cartpole/Master.tag b/src/test/resources/models/reinforcementModel/cartpole/Master.tag new file mode 100644 index 0000000..1e3520c --- /dev/null +++ 
b/src/test/resources/models/reinforcementModel/cartpole/Master.tag @@ -0,0 +1,7 @@ +package cartpole; +conforms to de.monticore.lang.monticar.generator.roscpp.RosToEmamTagSchema; + +tags Master { + tag master.state with RosConnection = {topic=(/state, std_msgs/Float32MultiArray)}; + tag master.action with RosConnection = {topic=(/step, std_msgs/Int32)}; +} \ No newline at end of file diff --git a/src/test/resources/models/reinforcementModel/cartpole/agent/CartPoleDQN.cnnt b/src/test/resources/models/reinforcementModel/cartpole/agent/CartPoleDQN.cnnt new file mode 100644 index 0000000..b4cb800 --- /dev/null +++ b/src/test/resources/models/reinforcementModel/cartpole/agent/CartPoleDQN.cnnt @@ -0,0 +1,40 @@ +configuration CartPoleDQN { + learning_method : reinforcement + + environment : gym { name : "CartPole-v0" } + + agent_name : "reinforcement_agent" + + reward_function : cartpole.agent.reward.reward + + num_episodes : 1000 + target_score : 185.5 + discount_factor : 0.999 + num_max_steps : 500 + training_interval : 1 + + use_fix_target_network : true + target_network_update_interval : 200 + + snapshot_interval : 50 + + use_double_dqn : true + + loss : huber_loss + + replay_memory : buffer{ + memory_size : 20000 + sample_size : 32 + } + + action_selection : epsgreedy{ + epsilon : 1.0 + min_epsilon : 0.01 + epsilon_decay_method: linear + epsilon_decay : 0.001 + } + + optimizer : rmsprop{ + learning_rate : 0.001 + } +} \ No newline at end of file diff --git a/src/test/resources/models/reinforcementModel/cartpole/agent/CartPoleDQN.emadl b/src/test/resources/models/reinforcementModel/cartpole/agent/CartPoleDQN.emadl new file mode 100644 index 0000000..acf83e6 --- /dev/null +++ b/src/test/resources/models/reinforcementModel/cartpole/agent/CartPoleDQN.emadl @@ -0,0 +1,17 @@ +package cartpole.agent; + +component CartPoleDQN { + ports + in Q^{4} state, + out Q(-oo:oo)^{2} qvalues; + + implementation CNN { + state -> + FullyConnected(units=256) -> + Relu() -> + 
FullyConnected(units=128) -> + Relu() -> + FullyConnected(units=2) -> + qvalues + } +} \ No newline at end of file diff --git a/src/test/resources/models/reinforcementModel/cartpole/agent/reward/Reward.emadl b/src/test/resources/models/reinforcementModel/cartpole/agent/reward/Reward.emadl new file mode 100644 index 0000000..74866f9 --- /dev/null +++ b/src/test/resources/models/reinforcementModel/cartpole/agent/reward/Reward.emadl @@ -0,0 +1,12 @@ +package cartpole.agent.reward; + +component Reward { + ports + in Q^{4} state, + out Q reward; + + implementation Math { + Q rew = state(1); + reward = rew; + } +} \ No newline at end of file diff --git a/src/test/resources/models/reinforcementModel/cartpole/policy/Greedy.emadl b/src/test/resources/models/reinforcementModel/cartpole/policy/Greedy.emadl new file mode 100644 index 0000000..7318dbe --- /dev/null +++ b/src/test/resources/models/reinforcementModel/cartpole/policy/Greedy.emadl @@ -0,0 +1,25 @@
+package cartpole.policy;
+
+// Greedy policy: outputs the zero-based index of the largest of the two Q-values.
+component Greedy {
+    ports
+        in Q(-oo:oo)^{2} values,
+        out Z action;
+
+    implementation Math {
+        // Elements are addressed as values(1)..values(2) (one-based) while actions are
+        // zero-based, hence the i-1 below. The first element is the initial maximum.
+        Q maxQValue = values(1);
+        Z maxValueAction = 0;
+
+        for i = 2:2
+            // Replace the current best only when the candidate is strictly larger.
+            if values(i) > maxQValue
+                maxQValue = values(i);
+                maxValueAction = i-1;
+            end
+        end
+
+        action = maxValueAction;
+    }
+}
\ No newline at end of file
-- GitLab