Commit aa7af2ba authored by Nicola Gatto's avatar Nicola Gatto

Add python wrapper to CNNArch2Gluon for reward function generation

parent 74cd8aa9
......@@ -40,7 +40,7 @@ public enum Backend {
}
@Override
public CNNTrainGenerator getCNNTrainGenerator() {
return new CNNTrain2Gluon();
return new CNNTrain2Gluon(new RewardFunctionCppGenerator());
}
};
......
......@@ -30,6 +30,8 @@ import de.monticore.lang.math._symboltable.MathStatementsSymbol;
import de.monticore.lang.monticar.cnnarch.CNNArchGenerator;
import de.monticore.lang.monticar.cnnarch.DataPathConfigParser;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch.gluongenerator.CNNArch2Gluon;
import de.monticore.lang.monticar.cnnarch.gluongenerator.CNNTrain2Gluon;
import de.monticore.lang.monticar.cnntrain.CNNTrainGenerator;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.emadl._cocos.EMADLCocos;
......@@ -495,6 +497,10 @@ public class EMADLGenerator {
String trainConfigFilename = getConfigFilename(mainComponentName, component.getFullName(), component.getName());
//should be removed when CNNTrain supports packages
cnnTrainGenerator.setGenerationTargetPath(getGenerationTargetPath());
if (cnnTrainGenerator instanceof CNNTrain2Gluon) {
((CNNTrain2Gluon) cnnTrainGenerator).setRootProjectModelsDir(getModelsPath());
}
List<String> names = Splitter.on("/").splitToList(trainConfigFilename);
trainConfigFilename = names.get(names.size()-1);
Path modelPath = Paths.get(getModelsPath() + Joiner.on("/").join(names.subList(0,names.size()-1)));
......
package de.monticore.lang.monticar.emadl.generator;
import de.monticore.lang.embeddedmontiarc.embeddedmontiarc._symboltable.instanceStructure.EMAComponentInstanceSymbol;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionSourceGenerator;
import de.monticore.lang.monticar.generator.cpp.GeneratorEMAMOpt2CPP;
import de.monticore.lang.tagging._symboltable.TaggingResolver;
import de.se_rwth.commons.logging.Log;
import java.io.IOException;
import java.util.Optional;
/**
 * Generates C++ code (Armadillo backend) for an EMAM reward-function component,
 * so the reward can be evaluated by the Gluon reinforcement-learning trainer.
 */
public class RewardFunctionCppGenerator implements RewardFunctionSourceGenerator {
    public RewardFunctionCppGenerator() {
    }

    /**
     * Generates the reward-function sources.
     *
     * @param modelPath  root directory of the EMADL/EMAM models
     * @param rootModel  fully qualified name of the reward component instance to resolve
     * @param targetPath output directory for the generated C++ code
     */
    @Override
    public void generate(String modelPath, String rootModel, String targetPath) {
        GeneratorEMAMOpt2CPP generator = new GeneratorEMAMOpt2CPP();
        generator.useArmadilloBackend();

        TaggingResolver taggingResolver = EMADLAbstractSymtab.createSymTabAndTaggingResolver(modelPath);
        Optional<EMAComponentInstanceSymbol> instanceSymbol = taggingResolver
                .<EMAComponentInstanceSymbol>resolve(rootModel, EMAComponentInstanceSymbol.KIND);

        if (!instanceSymbol.isPresent()) {
            Log.error("Generation of reward function is not possible: Cannot resolve component instance "
                    + rootModel);
            // Bail out explicitly: without this return the get() below throws
            // NoSuchElementException whenever Log.error does not fail fast,
            // hiding the real (logged) cause of the failure.
            return;
        }

        generator.setGenerationTargetPath(targetPath);
        try {
            generator.generate(instanceSymbol.get(), taggingResolver);
        } catch (IOException e) {
            Log.error("Generation of reward function is not possible: " + e.getMessage());
        }
    }
}
......@@ -23,6 +23,7 @@ package de.monticore.lang.monticar.emadl;
import de.monticore.lang.monticar.emadl.generator.Backend;
import de.monticore.lang.monticar.emadl.generator.EMADLGenerator;
import de.monticore.lang.monticar.emadl.generator.EMADLGeneratorCli;
import de.se_rwth.commons.logging.Finding;
import de.se_rwth.commons.logging.Log;
import freemarker.template.TemplateException;
import org.junit.Before;
......@@ -35,6 +36,7 @@ import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import static junit.framework.TestCase.assertTrue;
import static org.junit.Assert.assertFalse;
......@@ -187,6 +189,14 @@ public class GenerationTest extends AbstractSymtabTest {
"mnist_mnistClassifier_net.h"));
}
@Test
public void testGluonReinforcementModel() {
    // Generates the cartpole reinforcement-learning model with the GLUON
    // backend and asserts that generation produced no error findings.
    Log.getFindings().clear();
    String[] args = {"-m", "src/test/resources/models/reinforcementModel", "-r", "cartpole.Master", "-b", "GLUON", "-f", "n", "-c", "n"};
    EMADLGeneratorCli.main(args);
    // noneMatch avoids materializing an intermediate list just to test emptiness.
    assertTrue(Log.getFindings().stream().noneMatch(Finding::isError));
}
@Test
public void testHashFunction() {
EMADLGenerator tester = new EMADLGenerator(Backend.MXNET);
......
package cartpole;

import cartpole.agent.*;
import cartpole.policy.*;

// Root component: wires the CartPole observation through the DQN network and
// a greedy action-selection policy to produce the discrete action.
component Master {
ports
in Q^{4} state, // 4-dimensional environment observation
out Z action; // discrete action index sent back to the environment

instance CartPoleDQN dqn; // Q-network producing one Q-value per action
instance Greedy policy; // argmax over the Q-values

connect state -> dqn.state;
connect dqn.qvalues -> policy.values;
connect policy.action -> action;
}
\ No newline at end of file
package cartpole;

conforms to de.monticore.lang.monticar.generator.roscpp.RosToEmamTagSchema;

// ROS middleware bindings for the Master component's ports.
tags Master {
tag master.state with RosConnection = {topic=(/state, std_msgs/Float32MultiArray)}; // incoming observation
tag master.action with RosConnection = {topic=(/step, std_msgs/Int32)}; // outgoing action
}
\ No newline at end of file
// DQN training configuration for the CartPole agent, consumed by CNNTrain2Gluon.
configuration CartPoleDQN {
learning_method : reinforcement
environment : gym { name : "CartPole-v0" } // OpenAI Gym environment id
agent_name : "reinforcement_agent"
reward_function : cartpole.agent.reward.reward // EMAM component computing the scalar reward
num_episodes : 1000
target_score : 185.5 // presumably stops training once this score is reached -- TODO confirm semantics
discount_factor : 0.999
num_max_steps : 500 // per-episode step limit
training_interval : 1
use_fix_target_network : true
target_network_update_interval : 200
snapshot_interval : 50
use_double_dqn : true
loss : huber_loss
// Experience replay buffer parameters.
replay_memory : buffer{
memory_size : 20000
sample_size : 32
}
// Epsilon-greedy exploration with linear decay.
action_selection : epsgreedy{
epsilon : 1.0
min_epsilon : 0.01
epsilon_decay_method: linear
epsilon_decay : 0.001
}
optimizer : rmsprop{
learning_rate : 0.001
}
}
\ No newline at end of file
package cartpole.agent;

// Q-network: maps the 4-dimensional CartPole state to one Q-value per action
// via a small fully-connected MLP (256 -> 128 -> 2).
component CartPoleDQN {
ports
in Q^{4} state, // environment observation
out Q(-oo:oo)^{2} qvalues; // unbounded Q-value estimate per action

implementation CNN {
state ->
FullyConnected(units=256) ->
Relu() ->
FullyConnected(units=128) ->
Relu() ->
FullyConnected(units=2) -> // one output unit per action
qvalues
}
}
\ No newline at end of file
package cartpole.agent.reward;

// Reward component referenced by the CartPoleDQN training configuration
// (reward_function : cartpole.agent.reward.reward).
component Reward {
ports
in Q^{4} state, // environment observation
out Q reward; // scalar reward fed to the trainer

implementation Math {
// NOTE(review): the reward is taken directly from one component of the
// state vector; whether state(1) is the first or second element depends
// on the Math language's index base -- confirm against EMAM conventions.
Q rew = state(1);
reward = rew;
}
}
\ No newline at end of file
package cartpole.policy;

// Greedy (argmax) policy: outputs the index of the largest Q-value.
component Greedy {
ports
in Q(-oo:oo)^{2} values, // Q-value per action
out Z action; // index of the selected (best) action

implementation Math {
Q maxQValue = values(0);
Z maxValueAction = 0;
// NOTE(review): the index base looks inconsistent (values(0) above vs.
// for i = 1:2 over a 2-element vector) -- confirm the EMAM Math indexing
// convention before relying on this at runtime.
for i = 1:2
// Fixed: the original condition was maxQValue > values(i), which updated
// the running "max" whenever the candidate was SMALLER, i.e. it selected
// the minimum Q-value instead of the maximum a greedy policy requires.
if values(i) > maxQValue
maxQValue = values(i);
maxValueAction = i-1;
end
end
action = maxValueAction;
}
}
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment