From 3027c51ddd8ea401655939d5adf12ff2d6036fee Mon Sep 17 00:00:00 2001 From: Nicola Gatto <nicola.gatto@rwth-aachen.de> Date: Tue, 23 Jul 2019 18:19:57 +0200 Subject: [PATCH] Generate reward function when generate is called --- .../cnnarch/gluongenerator/CNNTrain2Gluon.java | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNTrain2Gluon.java b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNTrain2Gluon.java index 2960d094..e358c2c3 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNTrain2Gluon.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNTrain2Gluon.java @@ -52,13 +52,6 @@ public class CNNTrain2Gluon extends CNNTrainGenerator { @Override public ConfigurationSymbol getConfigurationSymbol(Path modelsDirPath, String rootModelName) { ConfigurationSymbol configurationSymbol = super.getConfigurationSymbol(modelsDirPath, rootModelName); - - // Generate Reward function if necessary - if (configurationSymbol.getLearningMethod().equals(LearningMethod.REINFORCEMENT) - && configurationSymbol.getRlRewardFunction().isPresent()) { - generateRewardFunction(configurationSymbol.getRlRewardFunction().get(), modelsDirPath); - } - return configurationSymbol; } @@ -155,6 +148,11 @@ public class CNNTrain2Gluon extends CNNTrainGenerator { ftlContext.put("criticInstanceName", criticInstanceName); } + // Generate Reward function if necessary + if (configuration.getRlRewardFunction().isPresent()) { + generateRewardFunction(configuration.getRlRewardFunction().get(), Paths.get(rootProjectModelsDir)); + } + ftlContext.put("trainerName", trainerName); Map<String, String> rlFrameworkContentMap = constructReinforcementLearningFramework(templateConfiguration, ftlContext, rlAlgorithm); fileContentMap.putAll(rlFrameworkContentMap); -- GitLab