monticore / EmbeddedMontiArc / generators / CNNArch2Gluon

Commit a403480d, authored Jul 17, 2019 by Nicola Gatto

    Refactor network export

Parent: 9a5ec6a1
Pipeline #161472 failed
Changes: 2 files
src/main/resources/templates/gluon/reinforcement/Trainer.ftl
@@ -170,7 +170,7 @@ if __name__ == "__main__":
     if train_successful:
 <#if (config.rlAlgorithm == "dqn")>
-        agent.save_best_network(qnet_creator._model_dir_ + qnet_creator._model_prefix_ + '_0_newest', epoch=0)
+        agent.export_best_network(path=qnet_creator._model_dir_ + qnet_creator._model_prefix_ + '_0_newest', epoch=0)
 <#else>
-        agent.save_best_network(actor_creator._model_dir_ + actor_creator._model_prefix_ + '_0_newest', epoch=0)
+        agent.export_best_network(path=actor_creator._model_dir_ + actor_creator._model_prefix_ + '_0_newest', epoch=0)
 </#if>
\ No newline at end of file
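Note for readers of the template: Trainer.ftl is FreeMarker, so the <#if (config.rlAlgorithm == "dqn")> branch is resolved at generation time and only one of the two calls appears in the generated Python trainer. The call-site change is that the target path, previously the first positional argument of save_best_network, is now passed as the path= keyword of export_best_network. A minimal stand-in sketch of the new call site (AgentStub and the directory values are illustrative, not generator output):

class AgentStub:
    # Same signature as export_best_network in Agent.ftl below.
    def export_best_network(self, path=None, epoch=0):
        print('exporting best network to', path, 'at epoch', epoch)

agent = AgentStub()
model_dir, model_prefix = 'model/', 'model_0'  # stand-ins for the creator fields
agent.export_best_network(path=model_dir + model_prefix + '_0_newest', epoch=0)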
src/main/resources/templates/gluon/reinforcement/agent/Agent.ftl
@@ -127,7 +127,7 @@ class Agent(object):
         self._logger = None
         self._environment.close()
         self._environment = None
-        self._save_net(self._best_net, 'best_net', session_dir)
+        self._export_net(self._best_net, 'best_net', filedir=session_dir)
         self._best_net = None

     def _make_config_dict(self):
@@ -258,26 +258,26 @@ class Agent(object):
         return self._target_score is not None\
             and avg_reward > self._target_score

-    def _save_parameters(self, net, episode=None, filename='dqn-agent-params'):
+    def _export_net(self, net, filename, filedir=None, episode=None):
         assert self._output_directory
         assert isinstance(net, gluon.HybridBlock)
         make_directory_if_not_exist(self._output_directory)
+        filedir = self._output_directory if filedir is None else filedir
+        filename = os.path.join(filedir, filename)

         if(episode is not None):
-            self._logger.info(
+            self._logger.debug(
                 'Saving model parameters after episode %d' % episode)
             filename = filename + '-ep{}'.format(episode)
         else:
-            self._logger.info('Saving model parameters')
-        self._save_net(net, filename)
-
-    def _save_net(self, net, filename, filedir=None):
-        filedir = self._output_directory if filedir is None else filedir
-        filename = os.path.join(filedir, filename)
+            self._logger.debug('Saving model parameters')
+
         net.save_parameters(filename + '.params')
         net.export(filename, epoch=0)

-    def save_best_network(self, path, epoch=0):
+    def export_best_network(self, path=None, epoch=0):
+        path = os.path.join(self._output_directory, 'best_network')\
+            if path is None else path
         self._logger.info(
             'Saving best network with average reward of {}'.format(
                 self._best_avg_score))
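The net effect of this hunk: the former _save_parameters/_save_net pair collapses into a single _export_net, which writes both a Gluon parameter file and a standalone symbol/params pair, and save_best_network becomes export_best_network with an optional path. A minimal sketch of the files _export_net leaves on disk, using a throwaway HybridBlock (network shape and paths here are illustrative, not from the generator):

import os
from mxnet import gluon, nd

# Throwaway network standing in for a generated qnet.
net = gluon.nn.HybridSequential()
net.add(gluon.nn.Dense(4))
net.initialize()
net.hybridize()
net(nd.zeros((1, 8)))  # one forward pass so export() has a cached graph

os.makedirs('model', exist_ok=True)
filename = os.path.join('model', 'qnet-ep42')  # filedir + filename, as in _export_net

net.save_parameters(filename + '.params')  # -> model/qnet-ep42.params
net.export(filename, epoch=0)              # -> model/qnet-ep42-symbol.json
                                           #    and model/qnet-ep42-0000.params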
@@ -373,15 +373,17 @@ class DdpgAgent(Agent):
     def _make_pickle_ready(self, session_dir):
         super(DdpgAgent, self)._make_pickle_ready(session_dir)
-        self._save_net(self._actor, 'current_actor')
-        self._save_net(self._actor, 'actor', session_dir)
+        self._export_net(self._actor, 'current_actor')
+        self._export_net(self._actor, 'actor', filedir=session_dir)
         self._actor = None
-        self._save_net(self._critic, 'critic', session_dir)
+        self._export_net(self._critic, 'critic', filedir=session_dir)
         self._critic = None
-        self._save_net(self._actor_target, 'actor_target', session_dir)
+        self._export_net(
+            self._actor_target, 'actor_target', filedir=session_dir)
         self._actor_target = None
-        self._save_net(self._critic_target, 'critic_target', session_dir)
+        self._export_net(
+            self._critic_target, 'critic_target', filedir=session_dir)
         self._critic_target = None

     @classmethod
@@ -449,7 +451,8 @@ class DdpgAgent(Agent):
         return action[0].asnumpy()

     def save_parameters(self, episode):
-        self._save_parameters(self._actor, episode=episode)
+        self._export_net(
+            self._actor, self._agent_name + '_actor', episode=episode)

     def train(self, episodes=None):
         self.save_config_file()
@@ -605,7 +608,7 @@ class DdpgAgent(Agent):
             self._evaluate()

         self.save_parameters(episode=self._current_episode)
-        self.save_best_network(os.path.join(self._output_directory, 'best'))
+        self.export_best_network()
         self._training_stats.save_stats(self._output_directory)
         self._logger.info('--------- Training finished ---------')
         return True
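One behavioral detail worth flagging: the removed call wrote the best network under <output_directory>/best, while the new parameterless export_best_network() falls back to the default added in the earlier hunk, <output_directory>/best_network. A small runnable sketch of that default-path resolution (the directory value is a stand-in):

import os

output_directory = '/tmp/agent-output'  # stand-in for self._output_directory

def resolve_best_network_path(path=None):
    # Mirrors the default added to export_best_network().
    return os.path.join(output_directory, 'best_network') if path is None else path

print(resolve_best_network_path())             # /tmp/agent-output/best_network
print(resolve_best_network_path('/tmp/best'))  # /tmp/best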
@@ -707,9 +710,10 @@ class TwinDelayedDdpgAgent(DdpgAgent):
     def _make_pickle_ready(self, session_dir):
         super(TwinDelayedDdpgAgent, self)._make_pickle_ready(session_dir)
-        self._save_net(self._critic2, 'critic2', session_dir)
+        self._export_net(self._critic2, 'critic2', filedir=session_dir)
         self._critic2 = None
-        self._save_net(self._critic2_target, 'critic2_target', session_dir)
+        self._export_net(
+            self._critic2_target, 'critic2_target', filedir=session_dir)
         self._critic2_target = None

     @classmethod
@@ -980,7 +984,7 @@ class TwinDelayedDdpgAgent(DdpgAgent):
             self._evaluate()

         self.save_parameters(episode=self._current_episode)
-        self.save_best_network(os.path.join(self._output_directory, 'best'))
+        self.export_best_network()
         self._training_stats.save_stats(self._output_directory)
         self._logger.info('--------- Training finished ---------')
         return True
@@ -1092,10 +1096,10 @@ class DqnAgent(Agent):
     def _make_pickle_ready(self, session_dir):
         super(DqnAgent, self)._make_pickle_ready(session_dir)
-        self._save_net(self._qnet, 'current_qnet')
-        self._save_net(self._qnet, 'qnet', session_dir)
+        self._export_net(self._qnet, 'current_qnet')
+        self._export_net(self._qnet, 'qnet', filedir=session_dir)
         self._qnet = None
-        self._save_net(self._target_qnet, 'target_net', session_dir)
+        self._export_net(self._target_qnet, 'target_net', filedir=session_dir)
         self._target_qnet = None

     def get_q_values(self, state, with_best=False):
@@ -1237,7 +1241,7 @@ class DqnAgent(Agent):
             self._evaluate()

         self.save_parameters(episode=self._current_episode)
-        self.save_best_network(os.path.join(self._output_directory, 'best'))
+        self.export_best_network()
         self._training_stats.save_stats(self._output_directory)
         self._logger.info('--------- Training finished ---------')
         return True
@@ -1253,7 +1257,8 @@ class DqnAgent(Agent):
         return config

     def save_parameters(self, episode):
-        self._save_parameters(self._qnet, episode=episode)
+        self._export_net(
+            self._qnet, self._agent_name + '_qnet', episode=episode)

     def _save_current_as_best_net(self):
         self._best_net = copy_net(