From 247b1b173efc5cdfa10eb95f3a3ab927c6883edd Mon Sep 17 00:00:00 2001
From: Nicola Gatto <nicola.gatto@rwth-aachen.de>
Date: Mon, 15 Jul 2019 15:40:16 +0200
Subject: [PATCH] Adapt new loss names

---
 .../resources/templates/gluon/reinforcement/agent/Agent.ftl   | 2 +-
 .../templates/gluon/reinforcement/params/DqnAgentParams.ftl   | 2 +-
 .../resources/templates/gluon/reinforcement/util/Util.ftl     | 4 ++--
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/main/resources/templates/gluon/reinforcement/agent/Agent.ftl b/src/main/resources/templates/gluon/reinforcement/agent/Agent.ftl
index 21ecbdaa..dff669c4 100644
--- a/src/main/resources/templates/gluon/reinforcement/agent/Agent.ftl
+++ b/src/main/resources/templates/gluon/reinforcement/agent/Agent.ftl
@@ -1003,7 +1003,7 @@ class DqnAgent(Agent):
         action_dim,
         ctx=None,
         discount_factor=.9,
-        loss_function='euclidean',
+        loss_function='l2',
         optimizer='rmsprop',
         optimizer_params={'learning_rate': 0.09},
         training_episodes=50,
diff --git a/src/main/resources/templates/gluon/reinforcement/params/DqnAgentParams.ftl b/src/main/resources/templates/gluon/reinforcement/params/DqnAgentParams.ftl
index ccf84701..ff16ea6d 100644
--- a/src/main/resources/templates/gluon/reinforcement/params/DqnAgentParams.ftl
+++ b/src/main/resources/templates/gluon/reinforcement/params/DqnAgentParams.ftl
@@ -6,7 +6,7 @@
         'use_fix_target': False,
 </#if>
 <#if (config.configuration.loss)??>
-        'loss': '${config.lossName}',
+        'loss_function': '${config.lossName}',
 <#if (config.lossParams)??>
         'loss_params': {
 <#list config.lossParams?keys as param>
diff --git a/src/main/resources/templates/gluon/reinforcement/util/Util.ftl b/src/main/resources/templates/gluon/reinforcement/util/Util.ftl
index 78e7795f..1d8b38c4 100644
--- a/src/main/resources/templates/gluon/reinforcement/util/Util.ftl
+++ b/src/main/resources/templates/gluon/reinforcement/util/Util.ftl
@@ -11,8 +11,8 @@ import cnnarch_logger
 
 LOSS_FUNCTIONS = {
         'l1': gluon.loss.L1Loss(),
-        'euclidean': gluon.loss.L2Loss(),
-        'huber_loss': gluon.loss.HuberLoss(),
+        'l2': gluon.loss.L2Loss(),
+        'huber': gluon.loss.HuberLoss(),
         'softmax_cross_entropy': gluon.loss.SoftmaxCrossEntropyLoss(),
         'sigmoid_cross_entropy': gluon.loss.SigmoidBinaryCrossEntropyLoss()}
 
-- 
GitLab