diff --git a/src/main/resources/templates/gluon/reinforcement/agent/Agent.ftl b/src/main/resources/templates/gluon/reinforcement/agent/Agent.ftl index 21ecbdaa604a6f2b9b4ee5ee14305646cd3ff1f8..dff669c45520e46cb33965ab90667f9b9b7b123e 100644 --- a/src/main/resources/templates/gluon/reinforcement/agent/Agent.ftl +++ b/src/main/resources/templates/gluon/reinforcement/agent/Agent.ftl @@ -1003,7 +1003,7 @@ class DqnAgent(Agent): action_dim, ctx=None, discount_factor=.9, - loss_function='euclidean', + loss_function='l2', optimizer='rmsprop', optimizer_params={'learning_rate': 0.09}, training_episodes=50, diff --git a/src/main/resources/templates/gluon/reinforcement/params/DqnAgentParams.ftl b/src/main/resources/templates/gluon/reinforcement/params/DqnAgentParams.ftl index ccf847018bd004f9aeab146b84c1c2717d4c3361..ff16ea6d49251ede96c17884632eeb473a22589b 100644 --- a/src/main/resources/templates/gluon/reinforcement/params/DqnAgentParams.ftl +++ b/src/main/resources/templates/gluon/reinforcement/params/DqnAgentParams.ftl @@ -6,7 +6,7 @@ 'use_fix_target': False, </#if> <#if (config.configuration.loss)??> - 'loss': '${config.lossName}', + 'loss_function': '${config.lossName}', <#if (config.lossParams)??> 'loss_params': { <#list config.lossParams?keys as param> diff --git a/src/main/resources/templates/gluon/reinforcement/util/Util.ftl b/src/main/resources/templates/gluon/reinforcement/util/Util.ftl index 78e7795fd290f99980c689ad1c76f274a821f830..1d8b38c4a32a1cc6c0e364054e746e9115d7d05b 100644 --- a/src/main/resources/templates/gluon/reinforcement/util/Util.ftl +++ b/src/main/resources/templates/gluon/reinforcement/util/Util.ftl @@ -11,8 +11,8 @@ import cnnarch_logger LOSS_FUNCTIONS = { 'l1': gluon.loss.L1Loss(), - 'euclidean': gluon.loss.L2Loss(), - 'huber_loss': gluon.loss.HuberLoss(), + 'l2': gluon.loss.L2Loss(), + 'huber': gluon.loss.HuberLoss(), 'softmax_cross_entropy': gluon.loss.SoftmaxCrossEntropyLoss(), 'sigmoid_cross_entropy': gluon.loss.SigmoidBinaryCrossEntropyLoss()}