From ed0c66c860b3eacb71d68e30cb77503535a0776d Mon Sep 17 00:00:00 2001
From: Nicola Gatto <nicola.gatto@rwth-aachen.de>
Date: Tue, 16 Jul 2019 00:53:35 +0200
Subject: [PATCH] Add average Q-values for TD3 algorithm

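So far the per-episode average Q-value was never actually computed:
episode_avg_q_value was reset to zero on every training step and then
divided by training_steps at the end of the episode, so the logged
value was always 0. With this change the Q-value of the current policy
is accumulated on every delayed policy update (i.e. whenever
self._total_steps % self._policy_delay == 0), using the copied critic
evaluated on the minibatch, and the per-episode average is taken over
the number of actor updates instead of the number of training steps.

The following is a minimal, self-contained sketch of the bookkeeping
pattern introduced here. It uses plain numpy stand-ins for the actor
and critic; actor, critic, minibatches, policy_delay and
minibatch_size are illustrative names only, not the generated Gluon
code:

    import numpy as np

    # Toy stand-ins (assumptions) for the Gluon networks the generated
    # agent actually uses (tmp_critic and self._actor on NDArrays).
    def actor(states):
        return np.tanh(states)

    def critic(states, actions):
        return -np.sum((states - actions) ** 2, axis=1)

    minibatch_size = 4
    policy_delay = 2
    minibatches = [np.random.randn(minibatch_size, 3) for _ in range(10)]

    episode_avg_q_value = 0.0
    actor_updates = 0

    for total_steps, states in enumerate(minibatches, start=1):
        if total_steps % policy_delay == 0:  # delayed policy update
            # Accumulate the mean Q(s, pi(s)) over the minibatch.
            q_values = critic(states, actor(states))
            episode_avg_q_value += np.sum(q_values) / minibatch_size
            actor_updates += 1

    # Average over the number of actor updates, guarding against zero.
    episode_avg_q_value = 0 if actor_updates == 0 \
        else episode_avg_q_value / actor_updates
    print(episode_avg_q_value)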
---
 .../templates/gluon/reinforcement/agent/Agent.ftl         | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/main/resources/templates/gluon/reinforcement/agent/Agent.ftl b/src/main/resources/templates/gluon/reinforcement/agent/Agent.ftl
index ec9ea30a..aa0d968b 100644
--- a/src/main/resources/templates/gluon/reinforcement/agent/Agent.ftl
+++ b/src/main/resources/templates/gluon/reinforcement/agent/Agent.ftl
@@ -916,6 +916,9 @@ class TwinDelayedDdpgAgent(DdpgAgent):
 
                     if self._total_steps % self._policy_delay == 0:
                         tmp_critic = self._copy_critic()
+                        episode_avg_q_value +=\
+                            np.sum(tmp_critic(
+                                states, self._actor(states)).asnumpy()) / self._minibatch_size
                         with autograd.record():
                             actor_loss = -tmp_critic(
                                 states, self._actor(states)).mean()
@@ -942,7 +945,6 @@ class TwinDelayedDdpgAgent(DdpgAgent):
                         np.sum(critic_loss.asnumpy()) / self._minibatch_size
                     episode_actor_loss += 0 if actor_updates == 0 else\
                         np.sum(actor_loss.asnumpy()[0])
-                    episode_avg_q_value = 0
 
                     training_steps += 1
 
@@ -961,8 +963,8 @@ class TwinDelayedDdpgAgent(DdpgAgent):
                 else (episode_actor_loss / actor_updates)
             episode_critic_loss = 0 if training_steps == 0\
                 else (episode_critic_loss / training_steps)
-            episode_avg_q_value = 0 if training_steps == 0\
-                else (episode_avg_q_value / training_steps)
+            episode_avg_q_value = 0 if actor_updates == 0\
+                else (episode_avg_q_value / actor_updates)
 
             avg_reward = self._training_stats.log_episode(
                 self._current_episode, start, training_steps,
-- 
GitLab