From ed0c66c860b3eacb71d68e30cb77503535a0776d Mon Sep 17 00:00:00 2001
From: Nicola Gatto <nicola.gatto@rwth-aachen.de>
Date: Tue, 16 Jul 2019 00:53:35 +0200
Subject: [PATCH] Add average Q-values for TD3 algorithm

---
 .../templates/gluon/reinforcement/agent/Agent.ftl | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/main/resources/templates/gluon/reinforcement/agent/Agent.ftl b/src/main/resources/templates/gluon/reinforcement/agent/Agent.ftl
index ec9ea30a..aa0d968b 100644
--- a/src/main/resources/templates/gluon/reinforcement/agent/Agent.ftl
+++ b/src/main/resources/templates/gluon/reinforcement/agent/Agent.ftl
@@ -916,6 +916,9 @@ class TwinDelayedDdpgAgent(DdpgAgent):
 
                     if self._total_steps % self._policy_delay == 0:
                         tmp_critic = self._copy_critic()
+                        episode_avg_q_value +=\
+                            np.sum(tmp_critic(
+                                states, self._actor(states)).asnumpy()) / self._minibatch_size
                         with autograd.record():
                             actor_loss = -tmp_critic(
                                 states, self._actor(states)).mean()
@@ -942,7 +945,6 @@ class TwinDelayedDdpgAgent(DdpgAgent):
                         np.sum(critic_loss.asnumpy()) / self._minibatch_size
                     episode_actor_loss += 0 if actor_updates == 0 else\
                         np.sum(actor_loss.asnumpy()[0])
-                    episode_avg_q_value = 0
 
                     training_steps += 1
 
@@ -961,8 +963,8 @@ class TwinDelayedDdpgAgent(DdpgAgent):
                 else (episode_actor_loss / actor_updates)
             episode_critic_loss = 0 if training_steps == 0\
                 else (episode_critic_loss / training_steps)
-            episode_avg_q_value = 0 if training_steps == 0\
-                else (episode_avg_q_value / training_steps)
+            episode_avg_q_value = 0 if actor_updates == 0\
+                else (episode_avg_q_value / actor_updates)
 
             avg_reward = self._training_stats.log_episode(
                 self._current_episode, start, training_steps,
--
GitLab
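
For context, a minimal standalone sketch of the bookkeeping this patch introduces: in TD3 the actor (and here the Q-value measurement) is only updated every policy-delay steps, so the episode average must be normalised by the number of actor updates rather than by training_steps. The names critic_q, POLICY_DELAY, MINIBATCH_SIZE and run_episode below are illustrative placeholders and not part of the generated Agent.ftl code.

# Sketch only, assuming a placeholder critic; mirrors the patch's accumulation
# and normalisation logic, not the generated MXNet Gluon template.
import numpy as np

POLICY_DELAY = 2          # TD3 updates the actor every POLICY_DELAY critic steps
MINIBATCH_SIZE = 4


def critic_q(states):
    """Placeholder critic: returns one Q-value per state in the minibatch."""
    return np.tanh(states.sum(axis=1))


def run_episode(num_training_steps=10):
    episode_avg_q_value = 0.0
    actor_updates = 0
    training_steps = 0

    for total_steps in range(num_training_steps):
        states = np.random.randn(MINIBATCH_SIZE, 3)

        # ... critic update happens on every training step ...

        if total_steps % POLICY_DELAY == 0:
            # Delayed actor update: the only point where the patch samples the
            # critic's Q-value estimate for the current minibatch.
            episode_avg_q_value += np.sum(critic_q(states)) / MINIBATCH_SIZE
            actor_updates += 1

        training_steps += 1

    # Normalise by the number of actor updates (the patch's change), not by
    # training_steps, since the value was only sampled on actor updates.
    episode_avg_q_value = 0 if actor_updates == 0 \
        else episode_avg_q_value / actor_updates
    return episode_avg_q_value, actor_updates, training_steps


if __name__ == '__main__':
    avg_q, n_actor, n_train = run_episode()
    print('avg Q per actor update: %.4f (%d actor updates / %d critic steps)'
          % (avg_q, n_actor, n_train))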