From a403480d145e2ca7df13cd90e8a4cdea2b3dc329 Mon Sep 17 00:00:00 2001
From: Nicola Gatto <nicola.gatto@rwth-aachen.de>
Date: Wed, 17 Jul 2019 01:15:55 +0200
Subject: [PATCH] Refactor network export

---
 .../templates/gluon/reinforcement/Trainer.ftl |  4 +-
 .../gluon/reinforcement/agent/Agent.ftl       | 53 ++++++++++---------
 2 files changed, 31 insertions(+), 26 deletions(-)

diff --git a/src/main/resources/templates/gluon/reinforcement/Trainer.ftl b/src/main/resources/templates/gluon/reinforcement/Trainer.ftl
index e2a18893..99cca6ef 100644
--- a/src/main/resources/templates/gluon/reinforcement/Trainer.ftl
+++ b/src/main/resources/templates/gluon/reinforcement/Trainer.ftl
@@ -170,7 +170,7 @@ if __name__ == "__main__":
 
     if train_successful:
 <#if (config.rlAlgorithm == "dqn")>
-        agent.save_best_network(qnet_creator._model_dir_ + qnet_creator._model_prefix_ + '_0_newest', epoch=0)
+        agent.export_best_network(path=qnet_creator._model_dir_ + qnet_creator._model_prefix_ + '_0_newest', epoch=0)
 <#else>
-        agent.save_best_network(actor_creator._model_dir_ + actor_creator._model_prefix_ + '_0_newest', epoch=0)
+        agent.export_best_network(path=actor_creator._model_dir_ + actor_creator._model_prefix_ + '_0_newest', epoch=0)
 </#if>
\ No newline at end of file
diff --git a/src/main/resources/templates/gluon/reinforcement/agent/Agent.ftl b/src/main/resources/templates/gluon/reinforcement/agent/Agent.ftl
index 88129c81..58d98edc 100644
--- a/src/main/resources/templates/gluon/reinforcement/agent/Agent.ftl
+++ b/src/main/resources/templates/gluon/reinforcement/agent/Agent.ftl
@@ -127,7 +127,7 @@ class Agent(object):
         self._logger = None
         self._environment.close()
         self._environment = None
-        self._save_net(self._best_net, 'best_net', session_dir)
+        self._export_net(self._best_net, 'best_net', filedir=session_dir)
         self._best_net = None
 
     def _make_config_dict(self):
@@ -258,26 +258,26 @@ class Agent(object):
         return self._target_score is not None\
             and avg_reward > self._target_score
 
-    def _save_parameters(self, net, episode=None, filename='dqn-agent-params'):
+    def _export_net(self, net, filename, filedir=None, episode=None):
         assert self._output_directory
         assert isinstance(net, gluon.HybridBlock)
         make_directory_if_not_exist(self._output_directory)
+        filedir = self._output_directory if filedir is None else filedir
+        filename = os.path.join(filedir, filename)
 
         if(episode is not None):
-            self._logger.info(
+            self._logger.debug(
                 'Saving model parameters after episode %d' % episode)
             filename = filename + '-ep{}'.format(episode)
         else:
-            self._logger.info('Saving model parameters')
-        self._save_net(net, filename)
+            self._logger.debug('Saving model parameters')
 
-    def _save_net(self, net, filename, filedir=None):
-        filedir = self._output_directory if filedir is None else filedir
-        filename = os.path.join(filedir, filename)
         net.save_parameters(filename + '.params')
         net.export(filename, epoch=0)
 
-    def save_best_network(self, path, epoch=0):
+    def export_best_network(self, path=None, epoch=0):
+        path = os.path.join(self._output_directory, 'best_network')\
+            if path is None else path
         self._logger.info(
             'Saving best network with average reward of {}'.format(
                 self._best_avg_score))
@@ -373,15 +373,17 @@ class DdpgAgent(Agent):
 
     def _make_pickle_ready(self, session_dir):
         super(DdpgAgent, self)._make_pickle_ready(session_dir)
-        self._save_net(self._actor, 'current_actor')
+        self._export_net(self._actor, 'current_actor')
 
-        self._save_net(self._actor, 'actor', session_dir)
+        self._export_net(self._actor, 'actor', filedir=session_dir)
         self._actor = None
-        self._save_net(self._critic, 'critic', session_dir)
+        self._export_net(self._critic, 'critic', filedir=session_dir)
         self._critic = None
-        self._save_net(self._actor_target, 'actor_target', session_dir)
+        self._export_net(
+            self._actor_target, 'actor_target', filedir=session_dir)
         self._actor_target = None
-        self._save_net(self._critic_target, 'critic_target', session_dir)
+        self._export_net(
+            self._critic_target, 'critic_target', filedir=session_dir)
         self._critic_target = None
 
     @classmethod
@@ -449,7 +451,8 @@ class DdpgAgent(Agent):
         return action[0].asnumpy()
 
     def save_parameters(self, episode):
-        self._save_parameters(self._actor, episode=episode)
+        self._export_net(
+            self._actor, self._agent_name + '_actor', episode=episode)
 
     def train(self, episodes=None):
         self.save_config_file()
@@ -605,7 +608,7 @@ class DdpgAgent(Agent):
 
         self._evaluate()
         self.save_parameters(episode=self._current_episode)
-        self.save_best_network(os.path.join(self._output_directory, 'best'))
+        self.export_best_network()
         self._training_stats.save_stats(self._output_directory)
         self._logger.info('--------- Training finished ---------')
         return True
@@ -707,9 +710,10 @@ class TwinDelayedDdpgAgent(DdpgAgent):
 
     def _make_pickle_ready(self, session_dir):
         super(TwinDelayedDdpgAgent, self)._make_pickle_ready(session_dir)
-        self._save_net(self._critic2, 'critic2', session_dir)
+        self._export_net(self._critic2, 'critic2', filedir=session_dir)
         self._critic2 = None
-        self._save_net(self._critic2_target, 'critic2_target', session_dir)
+        self._export_net(
+            self._critic2_target, 'critic2_target', filedir=session_dir)
         self._critic2_target = None
 
     @classmethod
@@ -980,7 +984,7 @@ class TwinDelayedDdpgAgent(DdpgAgent):
 
         self._evaluate()
         self.save_parameters(episode=self._current_episode)
-        self.save_best_network(os.path.join(self._output_directory, 'best'))
+        self.export_best_network()
         self._training_stats.save_stats(self._output_directory)
         self._logger.info('--------- Training finished ---------')
         return True
@@ -1092,10 +1096,10 @@ class DqnAgent(Agent):
 
     def _make_pickle_ready(self, session_dir):
         super(DqnAgent, self)._make_pickle_ready(session_dir)
-        self._save_net(self._qnet, 'current_qnet')
-        self._save_net(self._qnet, 'qnet', session_dir)
+        self._export_net(self._qnet, 'current_qnet')
+        self._export_net(self._qnet, 'qnet', filedir=session_dir)
         self._qnet = None
-        self._save_net(self._target_qnet, 'target_net', session_dir)
+        self._export_net(self._target_qnet, 'target_net', filedir=session_dir)
         self._target_qnet = None
 
     def get_q_values(self, state, with_best=False):
@@ -1237,7 +1241,7 @@ class DqnAgent(Agent):
 
         self._evaluate()
         self.save_parameters(episode=self._current_episode)
-        self.save_best_network(os.path.join(self._output_directory, 'best'))
+        self.export_best_network()
         self._training_stats.save_stats(self._output_directory)
         self._logger.info('--------- Training finished ---------')
         return True
@@ -1253,7 +1257,8 @@ class DqnAgent(Agent):
         return config
 
     def save_parameters(self, episode):
-        self._save_parameters(self._qnet, episode=episode)
+        self._export_net(
+            self._qnet, self._agent_name + '_qnet', episode=episode)
 
     def _save_current_as_best_net(self):
         self._best_net = copy_net(
-- 
GitLab
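
Note (not part of the patch): the unified _export_net() writes both a
'<name>.params' file (via net.save_parameters) and a '<name>-symbol.json' /
'<name>-0000.params' pair (via net.export, which zero-pads the epoch to four
digits). Below is a minimal sketch of loading the exported best network back
into Gluon, assuming MXNet is available, that export_best_network() was
called with its default path, and that the network takes a single input
named 'data'; the input shape (1, 4) is purely illustrative:

    import mxnet as mx
    from mxnet import gluon, nd

    # SymbolBlock.imports() rebuilds a HybridBlock from the files written by
    # net.export(); the file names below follow the defaults assumed above.
    net = gluon.SymbolBlock.imports(
        'best_network-symbol.json',           # serialized symbol graph
        ['data'],                             # network input name(s)
        param_file='best_network-0000.params',
        ctx=mx.cpu())

    state = nd.random.uniform(shape=(1, 4))   # dummy state; shape is assumed
    q_values = net(state)
    print(q_values)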