Commit 3027c51d authored by Nicola Gatto's avatar Nicola Gatto
Browse files

Generate reward function when generate is called

parent ceb58e8c
......@@ -52,13 +52,6 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
@Override
public ConfigurationSymbol getConfigurationSymbol(Path modelsDirPath, String rootModelName) {
ConfigurationSymbol configurationSymbol = super.getConfigurationSymbol(modelsDirPath, rootModelName);
// Generate Reward function if necessary
if (configurationSymbol.getLearningMethod().equals(LearningMethod.REINFORCEMENT)
&& configurationSymbol.getRlRewardFunction().isPresent()) {
generateRewardFunction(configurationSymbol.getRlRewardFunction().get(), modelsDirPath);
}
return configurationSymbol;
}
......@@ -155,6 +148,11 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
ftlContext.put("criticInstanceName", criticInstanceName);
}
// Generate Reward function if necessary
if (configuration.getRlRewardFunction().isPresent()) {
generateRewardFunction(configuration.getRlRewardFunction().get(), Paths.get(rootProjectModelsDir));
}
ftlContext.put("trainerName", trainerName);
Map<String, String> rlFrameworkContentMap = constructReinforcementLearningFramework(templateConfiguration, ftlContext, rlAlgorithm);
fileContentMap.putAll(rlFrameworkContentMap);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment