Skip to content
Snippets Groups Projects
Commit 3027c51d authored by Nicola Gatto's avatar Nicola Gatto
Browse files

Generate reward function when generate is called

parent ceb58e8c
Branches
No related tags found
3 merge requests!20Implemented layer variables and RNN layer,!19Integrate TD3 Algorithm and Gaussian Noise,!18Integrate TD3 Algorithm and Gaussian Noise
......@@ -52,13 +52,6 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
@Override
public ConfigurationSymbol getConfigurationSymbol(Path modelsDirPath, String rootModelName) {
ConfigurationSymbol configurationSymbol = super.getConfigurationSymbol(modelsDirPath, rootModelName);
// Generate Reward function if necessary
if (configurationSymbol.getLearningMethod().equals(LearningMethod.REINFORCEMENT)
&& configurationSymbol.getRlRewardFunction().isPresent()) {
generateRewardFunction(configurationSymbol.getRlRewardFunction().get(), modelsDirPath);
}
return configurationSymbol;
}
......@@ -155,6 +148,11 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
ftlContext.put("criticInstanceName", criticInstanceName);
}
// Generate Reward function if necessary
if (configuration.getRlRewardFunction().isPresent()) {
generateRewardFunction(configuration.getRlRewardFunction().get(), Paths.get(rootProjectModelsDir));
}
ftlContext.put("trainerName", trainerName);
Map<String, String> rlFrameworkContentMap = constructReinforcementLearningFramework(templateConfiguration, ftlContext, rlAlgorithm);
fileContentMap.putAll(rlFrameworkContentMap);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment