From 247b1b173efc5cdfa10eb95f3a3ab927c6883edd Mon Sep 17 00:00:00 2001 From: Nicola Gatto <nicola.gatto@rwth-aachen.de> Date: Mon, 15 Jul 2019 15:40:16 +0200 Subject: [PATCH] Adapt new loss names --- .../resources/templates/gluon/reinforcement/agent/Agent.ftl | 2 +- .../templates/gluon/reinforcement/params/DqnAgentParams.ftl | 2 +- .../resources/templates/gluon/reinforcement/util/Util.ftl | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/main/resources/templates/gluon/reinforcement/agent/Agent.ftl b/src/main/resources/templates/gluon/reinforcement/agent/Agent.ftl index 21ecbdaa..dff669c4 100644 --- a/src/main/resources/templates/gluon/reinforcement/agent/Agent.ftl +++ b/src/main/resources/templates/gluon/reinforcement/agent/Agent.ftl @@ -1003,7 +1003,7 @@ class DqnAgent(Agent): action_dim, ctx=None, discount_factor=.9, - loss_function='euclidean', + loss_function='l2', optimizer='rmsprop', optimizer_params={'learning_rate': 0.09}, training_episodes=50, diff --git a/src/main/resources/templates/gluon/reinforcement/params/DqnAgentParams.ftl b/src/main/resources/templates/gluon/reinforcement/params/DqnAgentParams.ftl index ccf84701..ff16ea6d 100644 --- a/src/main/resources/templates/gluon/reinforcement/params/DqnAgentParams.ftl +++ b/src/main/resources/templates/gluon/reinforcement/params/DqnAgentParams.ftl @@ -6,7 +6,7 @@ 'use_fix_target': False, </#if> <#if (config.configuration.loss)??> - 'loss': '${config.lossName}', + 'loss_function': '${config.lossName}', <#if (config.lossParams)??> 'loss_params': { <#list config.lossParams?keys as param> diff --git a/src/main/resources/templates/gluon/reinforcement/util/Util.ftl b/src/main/resources/templates/gluon/reinforcement/util/Util.ftl index 78e7795f..1d8b38c4 100644 --- a/src/main/resources/templates/gluon/reinforcement/util/Util.ftl +++ b/src/main/resources/templates/gluon/reinforcement/util/Util.ftl @@ -11,8 +11,8 @@ import cnnarch_logger LOSS_FUNCTIONS = { 'l1': gluon.loss.L1Loss(), - 'euclidean': gluon.loss.L2Loss(), - 'huber_loss': gluon.loss.HuberLoss(), + 'l2': gluon.loss.L2Loss(), + 'huber': gluon.loss.HuberLoss(), 'softmax_cross_entropy': gluon.loss.SoftmaxCrossEntropyLoss(), 'sigmoid_cross_entropy': gluon.loss.SigmoidBinaryCrossEntropyLoss()} -- GitLab