From a403480d145e2ca7df13cd90e8a4cdea2b3dc329 Mon Sep 17 00:00:00 2001
From: Nicola Gatto <nicola.gatto@rwth-aachen.de>
Date: Wed, 17 Jul 2019 01:15:55 +0200
Subject: [PATCH] Refactor network export
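
Merge the private helpers _save_parameters and _save_net into a single
_export_net method. The merged method resolves an optional target
directory (defaulting to the agent's output directory), appends an
'-ep<episode>' suffix when an episode number is given, and writes both
the parameter file and the exported symbol. Per-episode save messages
are downgraded from info to debug.

Rename save_best_network to export_best_network and make its path
argument optional, defaulting to '<output_directory>/best_network', so
the train() implementations of the DQN, DDPG and TD3 agents no longer
assemble the path themselves. The per-episode parameter files are now
prefixed with the agent name ('<agent_name>_qnet' or
'<agent_name>_actor') instead of the fixed 'dqn-agent-params' prefix,
and the trainer template passes the export path by keyword.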

---
 .../templates/gluon/reinforcement/Trainer.ftl |  4 +-
 .../gluon/reinforcement/agent/Agent.ftl       | 53 ++++++++++---------
 2 files changed, 31 insertions(+), 26 deletions(-)
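
Usage sketch for the reworked export API in the generated Python code;
agent, net (a gluon.HybridBlock) and session_dir below are placeholder
names, not part of this patch:

    # Export into the agent's output directory, without episode suffix.
    agent._export_net(net, 'current_qnet')
    # Export into an explicit directory (used when pickling a session).
    agent._export_net(net, 'qnet', filedir=session_dir)
    # Per-episode export; appends '-ep42' to the file names.
    agent._export_net(net, agent._agent_name + '_qnet', episode=42)
    # Export the best network; the path defaults to
    # '<output_directory>/best_network'.
    agent.export_best_network()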

diff --git a/src/main/resources/templates/gluon/reinforcement/Trainer.ftl b/src/main/resources/templates/gluon/reinforcement/Trainer.ftl
index e2a18893..99cca6ef 100644
--- a/src/main/resources/templates/gluon/reinforcement/Trainer.ftl
+++ b/src/main/resources/templates/gluon/reinforcement/Trainer.ftl
@@ -170,7 +170,7 @@ if __name__ == "__main__":
 
     if train_successful:
 <#if (config.rlAlgorithm == "dqn")>
-        agent.save_best_network(qnet_creator._model_dir_ + qnet_creator._model_prefix_ + '_0_newest', epoch=0)
+        agent.export_best_network(path=qnet_creator._model_dir_ + qnet_creator._model_prefix_ + '_0_newest', epoch=0)
 <#else>
-        agent.save_best_network(actor_creator._model_dir_ + actor_creator._model_prefix_ + '_0_newest', epoch=0)
+        agent.export_best_network(path=actor_creator._model_dir_ + actor_creator._model_prefix_ + '_0_newest', epoch=0)
 </#if>
\ No newline at end of file
diff --git a/src/main/resources/templates/gluon/reinforcement/agent/Agent.ftl b/src/main/resources/templates/gluon/reinforcement/agent/Agent.ftl
index 88129c81..58d98edc 100644
--- a/src/main/resources/templates/gluon/reinforcement/agent/Agent.ftl
+++ b/src/main/resources/templates/gluon/reinforcement/agent/Agent.ftl
@@ -127,7 +127,7 @@ class Agent(object):
         self._logger = None
         self._environment.close()
         self._environment = None
-        self._save_net(self._best_net, 'best_net', session_dir)
+        self._export_net(self._best_net, 'best_net', filedir=session_dir)
         self._best_net = None
 
     def _make_config_dict(self):
@@ -258,26 +258,26 @@ class Agent(object):
         return self._target_score is not None\
             and avg_reward > self._target_score
 
-    def _save_parameters(self, net, episode=None, filename='dqn-agent-params'):
+    def _export_net(self, net, filename, filedir=None, episode=None):
         assert self._output_directory
         assert isinstance(net, gluon.HybridBlock)
         make_directory_if_not_exist(self._output_directory)
+        filedir = self._output_directory if filedir is None else filedir
+        filename = os.path.join(filedir, filename)
 
         if(episode is not None):
-            self._logger.info(
+            self._logger.debug(
                 'Saving model parameters after episode %d' % episode)
             filename = filename + '-ep{}'.format(episode)
         else:
-            self._logger.info('Saving model parameters')
-        self._save_net(net, filename)
+            self._logger.debug('Saving model parameters')
 
-    def _save_net(self, net, filename, filedir=None):
-        filedir = self._output_directory if filedir is None else filedir
-        filename = os.path.join(filedir, filename)
         net.save_parameters(filename + '.params')
         net.export(filename, epoch=0)
 
-    def save_best_network(self, path, epoch=0):
+    def export_best_network(self, path=None, epoch=0):
+        path = os.path.join(self._output_directory, 'best_network')\
+            if path is None else path
         self._logger.info(
             'Saving best network with average reward of {}'.format(
                 self._best_avg_score))
@@ -373,15 +373,17 @@ class DdpgAgent(Agent):
 
     def _make_pickle_ready(self, session_dir):
         super(DdpgAgent, self)._make_pickle_ready(session_dir)
-        self._save_net(self._actor, 'current_actor')
+        self._export_net(self._actor, 'current_actor')
 
-        self._save_net(self._actor, 'actor', session_dir)
+        self._export_net(self._actor, 'actor', filedir=session_dir)
         self._actor = None
-        self._save_net(self._critic, 'critic', session_dir)
+        self._export_net(self._critic, 'critic', filedir=session_dir)
         self._critic = None
-        self._save_net(self._actor_target, 'actor_target', session_dir)
+        self._export_net(
+            self._actor_target, 'actor_target', filedir=session_dir)
         self._actor_target = None
-        self._save_net(self._critic_target, 'critic_target', session_dir)
+        self._export_net(
+            self._critic_target, 'critic_target', filedir=session_dir)
         self._critic_target = None
 
     @classmethod
@@ -449,7 +451,8 @@ class DdpgAgent(Agent):
         return action[0].asnumpy()
 
     def save_parameters(self, episode):
-        self._save_parameters(self._actor, episode=episode)
+        self._export_net(
+            self._actor, self._agent_name + '_actor', episode=episode)
 
     def train(self, episodes=None):
         self.save_config_file()
@@ -605,7 +608,7 @@ class DdpgAgent(Agent):
 
         self._evaluate()
         self.save_parameters(episode=self._current_episode)
-        self.save_best_network(os.path.join(self._output_directory, 'best'))
+        self.export_best_network()
         self._training_stats.save_stats(self._output_directory)
         self._logger.info('--------- Training finished ---------')
         return True
@@ -707,9 +710,10 @@ class TwinDelayedDdpgAgent(DdpgAgent):
 
     def _make_pickle_ready(self, session_dir):
         super(TwinDelayedDdpgAgent, self)._make_pickle_ready(session_dir)
-        self._save_net(self._critic2, 'critic2', session_dir)
+        self._export_net(self._critic2, 'critic2', filedir=session_dir)
         self._critic2 = None
-        self._save_net(self._critic2_target, 'critic2_target', session_dir)
+        self._export_net(
+            self._critic2_target, 'critic2_target', filedir=session_dir)
         self._critic2_target = None
 
     @classmethod
@@ -980,7 +984,7 @@ class TwinDelayedDdpgAgent(DdpgAgent):
 
         self._evaluate()
         self.save_parameters(episode=self._current_episode)
-        self.save_best_network(os.path.join(self._output_directory, 'best'))
+        self.export_best_network()
         self._training_stats.save_stats(self._output_directory)
         self._logger.info('--------- Training finished ---------')
         return True
@@ -1092,10 +1096,10 @@ class DqnAgent(Agent):
 
     def _make_pickle_ready(self, session_dir):
         super(DqnAgent, self)._make_pickle_ready(session_dir)
-        self._save_net(self._qnet, 'current_qnet')
-        self._save_net(self._qnet, 'qnet', session_dir)
+        self._export_net(self._qnet, 'current_qnet')
+        self._export_net(self._qnet, 'qnet', filedir=session_dir)
         self._qnet = None
-        self._save_net(self._target_qnet, 'target_net', session_dir)
+        self._export_net(self._target_qnet, 'target_net', filedir=session_dir)
         self._target_qnet = None
 
     def get_q_values(self, state, with_best=False):
@@ -1237,7 +1241,7 @@ class DqnAgent(Agent):
 
         self._evaluate()
         self.save_parameters(episode=self._current_episode)
-        self.save_best_network(os.path.join(self._output_directory, 'best'))
+        self.export_best_network()
         self._training_stats.save_stats(self._output_directory)
         self._logger.info('--------- Training finished ---------')
         return True
@@ -1253,7 +1257,8 @@ class DqnAgent(Agent):
         return config
 
     def save_parameters(self, episode):
-        self._save_parameters(self._qnet, episode=episode)
+        self._export_net(
+            self._qnet, self._agent_name + '_qnet', episode=episode)
 
     def _save_current_as_best_net(self):
         self._best_net = copy_net(
-- 
GitLab