Commit a403480d authored by Nicola Gatto
Browse files

Refactor network export

parent 9a5ec6a1
Pipeline #161472 failed with stages
...@@ -170,7 +170,7 @@ if __name__ == "__main__": ...@@ -170,7 +170,7 @@ if __name__ == "__main__":
if train_successful: if train_successful:
<#if (config.rlAlgorithm == "dqn")> <#if (config.rlAlgorithm == "dqn")>
agent.save_best_network(qnet_creator._model_dir_ + qnet_creator._model_prefix_ + '_0_newest', epoch=0) agent.export_best_network(path=qnet_creator._model_dir_ + qnet_creator._model_prefix_ + '_0_newest', epoch=0)
<#else> <#else>
agent.save_best_network(actor_creator._model_dir_ + actor_creator._model_prefix_ + '_0_newest', epoch=0) agent.export_best_network(path=actor_creator._model_dir_ + actor_creator._model_prefix_ + '_0_newest', epoch=0)
</#if> </#if>
\ No newline at end of file
...@@ -127,7 +127,7 @@ class Agent(object): ...@@ -127,7 +127,7 @@ class Agent(object):
self._logger = None self._logger = None
self._environment.close() self._environment.close()
self._environment = None self._environment = None
self._save_net(self._best_net, 'best_net', session_dir) self._export_net(self._best_net, 'best_net', filedir=session_dir)
self._best_net = None self._best_net = None
def _make_config_dict(self): def _make_config_dict(self):
...@@ -258,26 +258,26 @@ class Agent(object): ...@@ -258,26 +258,26 @@ class Agent(object):
return self._target_score is not None\ return self._target_score is not None\
and avg_reward > self._target_score and avg_reward > self._target_score
def _save_parameters(self, net, episode=None, filename='dqn-agent-params'): def _export_net(self, net, filename, filedir=None, episode=None):
assert self._output_directory assert self._output_directory
assert isinstance(net, gluon.HybridBlock) assert isinstance(net, gluon.HybridBlock)
make_directory_if_not_exist(self._output_directory) make_directory_if_not_exist(self._output_directory)
filedir = self._output_directory if filedir is None else filedir
filename = os.path.join(filedir, filename)
if(episode is not None): if(episode is not None):
self._logger.info( self._logger.debug(
'Saving model parameters after episode %d' % episode) 'Saving model parameters after episode %d' % episode)
filename = filename + '-ep{}'.format(episode) filename = filename + '-ep{}'.format(episode)
else: else:
self._logger.info('Saving model parameters') self._logger.debug('Saving model parameters')
self._save_net(net, filename)
def _save_net(self, net, filename, filedir=None):
filedir = self._output_directory if filedir is None else filedir
filename = os.path.join(filedir, filename)
net.save_parameters(filename + '.params') net.save_parameters(filename + '.params')
net.export(filename, epoch=0) net.export(filename, epoch=0)
def save_best_network(self, path, epoch=0): def export_best_network(self, path=None, epoch=0):
path = os.path.join(self._output_directory, 'best_network')\
if path is None else path
self._logger.info( self._logger.info(
'Saving best network with average reward of {}'.format( 'Saving best network with average reward of {}'.format(
self._best_avg_score)) self._best_avg_score))
...@@ -373,15 +373,17 @@ class DdpgAgent(Agent): ...@@ -373,15 +373,17 @@ class DdpgAgent(Agent):
def _make_pickle_ready(self, session_dir): def _make_pickle_ready(self, session_dir):
super(DdpgAgent, self)._make_pickle_ready(session_dir) super(DdpgAgent, self)._make_pickle_ready(session_dir)
self._save_net(self._actor, 'current_actor') self._export_net(self._actor, 'current_actor')
self._save_net(self._actor, 'actor', session_dir) self._export_net(self._actor, 'actor', filedir=session_dir)
self._actor = None self._actor = None
self._save_net(self._critic, 'critic', session_dir) self._export_net(self._critic, 'critic', filedir=session_dir)
self._critic = None self._critic = None
self._save_net(self._actor_target, 'actor_target', session_dir) self._export_net(
self._actor_target, 'actor_target', filedir=session_dir)
self._actor_target = None self._actor_target = None
self._save_net(self._critic_target, 'critic_target', session_dir) self._export_net(
self._critic_target, 'critic_target', filedir=session_dir)
self._critic_target = None self._critic_target = None
@classmethod @classmethod
...@@ -449,7 +451,8 @@ class DdpgAgent(Agent): ...@@ -449,7 +451,8 @@ class DdpgAgent(Agent):
return action[0].asnumpy() return action[0].asnumpy()
def save_parameters(self, episode): def save_parameters(self, episode):
self._save_parameters(self._actor, episode=episode) self._export_net(
self._actor, self._agent_name + '_actor', episode=episode)
def train(self, episodes=None): def train(self, episodes=None):
self.save_config_file() self.save_config_file()
...@@ -605,7 +608,7 @@ class DdpgAgent(Agent): ...@@ -605,7 +608,7 @@ class DdpgAgent(Agent):
self._evaluate() self._evaluate()
self.save_parameters(episode=self._current_episode) self.save_parameters(episode=self._current_episode)
self.save_best_network(os.path.join(self._output_directory, 'best')) self.export_best_network()
self._training_stats.save_stats(self._output_directory) self._training_stats.save_stats(self._output_directory)
self._logger.info('--------- Training finished ---------') self._logger.info('--------- Training finished ---------')
return True return True
...@@ -707,9 +710,10 @@ class TwinDelayedDdpgAgent(DdpgAgent): ...@@ -707,9 +710,10 @@ class TwinDelayedDdpgAgent(DdpgAgent):
def _make_pickle_ready(self, session_dir): def _make_pickle_ready(self, session_dir):
super(TwinDelayedDdpgAgent, self)._make_pickle_ready(session_dir) super(TwinDelayedDdpgAgent, self)._make_pickle_ready(session_dir)
self._save_net(self._critic2, 'critic2', session_dir) self._export_net(self._critic2, 'critic2', filedir=session_dir)
self._critic2 = None self._critic2 = None
self._save_net(self._critic2_target, 'critic2_target', session_dir) self._export_net(
self._critic2_target, 'critic2_target', filedir=session_dir)
self._critic2_target = None self._critic2_target = None
@classmethod @classmethod
...@@ -980,7 +984,7 @@ class TwinDelayedDdpgAgent(DdpgAgent): ...@@ -980,7 +984,7 @@ class TwinDelayedDdpgAgent(DdpgAgent):
self._evaluate() self._evaluate()
self.save_parameters(episode=self._current_episode) self.save_parameters(episode=self._current_episode)
self.save_best_network(os.path.join(self._output_directory, 'best')) self.export_best_network()
self._training_stats.save_stats(self._output_directory) self._training_stats.save_stats(self._output_directory)
self._logger.info('--------- Training finished ---------') self._logger.info('--------- Training finished ---------')
return True return True
...@@ -1092,10 +1096,10 @@ class DqnAgent(Agent): ...@@ -1092,10 +1096,10 @@ class DqnAgent(Agent):
def _make_pickle_ready(self, session_dir): def _make_pickle_ready(self, session_dir):
super(DqnAgent, self)._make_pickle_ready(session_dir) super(DqnAgent, self)._make_pickle_ready(session_dir)
self._save_net(self._qnet, 'current_qnet') self._export_net(self._qnet, 'current_qnet')
self._save_net(self._qnet, 'qnet', session_dir) self._export_net(self._qnet, 'qnet', filedir=session_dir)
self._qnet = None self._qnet = None
self._save_net(self._target_qnet, 'target_net', session_dir) self._export_net(self._target_qnet, 'target_net', filedir=session_dir)
self._target_qnet = None self._target_qnet = None
def get_q_values(self, state, with_best=False): def get_q_values(self, state, with_best=False):
...@@ -1237,7 +1241,7 @@ class DqnAgent(Agent): ...@@ -1237,7 +1241,7 @@ class DqnAgent(Agent):
self._evaluate() self._evaluate()
self.save_parameters(episode=self._current_episode) self.save_parameters(episode=self._current_episode)
self.save_best_network(os.path.join(self._output_directory, 'best')) self.export_best_network()
self._training_stats.save_stats(self._output_directory) self._training_stats.save_stats(self._output_directory)
self._logger.info('--------- Training finished ---------') self._logger.info('--------- Training finished ---------')
return True return True
...@@ -1253,7 +1257,8 @@ class DqnAgent(Agent): ...@@ -1253,7 +1257,8 @@ class DqnAgent(Agent):
return config return config
def save_parameters(self, episode): def save_parameters(self, episode):
self._save_parameters(self._qnet, episode=episode) self._export_net(
self._qnet, self._agent_name + '_qnet', episode=episode)
def _save_current_as_best_net(self): def _save_current_as_best_net(self):
self._best_net = copy_net( self._best_net = copy_net(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment