From 3027c51ddd8ea401655939d5adf12ff2d6036fee Mon Sep 17 00:00:00 2001
From: Nicola Gatto <nicola.gatto@rwth-aachen.de>
Date: Tue, 23 Jul 2019 18:19:57 +0200
Subject: [PATCH] Generate reward function when generate is called

---
 .../cnnarch/gluongenerator/CNNTrain2Gluon.java       | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNTrain2Gluon.java b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNTrain2Gluon.java
index 2960d094..e358c2c3 100644
--- a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNTrain2Gluon.java
+++ b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNTrain2Gluon.java
@@ -52,13 +52,6 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
     @Override
     public ConfigurationSymbol getConfigurationSymbol(Path modelsDirPath, String rootModelName) {
         ConfigurationSymbol configurationSymbol = super.getConfigurationSymbol(modelsDirPath, rootModelName);
-
-        // Generate Reward function if necessary
-        if (configurationSymbol.getLearningMethod().equals(LearningMethod.REINFORCEMENT)
-                && configurationSymbol.getRlRewardFunction().isPresent()) {
-            generateRewardFunction(configurationSymbol.getRlRewardFunction().get(), modelsDirPath);
-        }
-
         return configurationSymbol;
     }
 
@@ -155,6 +148,11 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
                 ftlContext.put("criticInstanceName", criticInstanceName);
             }
 
+            // Generate Reward function if necessary
+            if (configuration.getRlRewardFunction().isPresent()) {
+                generateRewardFunction(configuration.getRlRewardFunction().get(), Paths.get(rootProjectModelsDir));
+            }
+
             ftlContext.put("trainerName", trainerName);
             Map<String, String> rlFrameworkContentMap = constructReinforcementLearningFramework(templateConfiguration, ftlContext, rlAlgorithm);
             fileContentMap.putAll(rlFrameworkContentMap);
-- 
GitLab