diff --git a/src/main/resources/templates/gluon/reinforcement/agent/Agent.ftl b/src/main/resources/templates/gluon/reinforcement/agent/Agent.ftl index 33135b0f596b96943412b30807f84c2e7ab1eb79..fb3d3bfe4ff3bdd74d68102e44b00ef66c4aa7b7 100644 --- a/src/main/resources/templates/gluon/reinforcement/agent/Agent.ftl +++ b/src/main/resources/templates/gluon/reinforcement/agent/Agent.ftl @@ -273,8 +273,8 @@ class Agent(object): def _save_net(self, net, filename, filedir=None): filedir = self._output_directory if filedir is None else filedir - filename = os.path.join(filedir, filename + '.params') - net.save_parameters(filename) + filename = os.path.join(filedir, filename) + net.save_parameters(filename + '.params') net.export(filename, epoch=0) def save_best_network(self, path, epoch=0): @@ -373,6 +373,8 @@ class DdpgAgent(Agent): def _make_pickle_ready(self, session_dir): super(DdpgAgent, self)._make_pickle_ready(session_dir) + self._save_net(self._actor, 'current_actor') + self._save_net(self._actor, 'actor', session_dir) self._actor = None self._save_net(self._critic, 'critic', session_dir) @@ -738,6 +740,7 @@ class DqnAgent(Agent): def _make_pickle_ready(self, session_dir): super(DqnAgent, self)._make_pickle_ready(session_dir) + self._save_net(self._qnet, 'current_qnet') self._save_net(self._qnet, 'qnet', session_dir) self._qnet = None self._save_net(self._target_qnet, 'target_net', session_dir)