Commit 9251bd58 authored by Nicola Gatto's avatar Nicola Gatto
Browse files

Save current network when interrupting the training

parent 4031ded3
Pipeline #148706 failed with stages
in 1 minute and 54 seconds
......@@ -273,8 +273,8 @@ class Agent(object):
def _save_net(self, net, filename, filedir=None):
filedir = self._output_directory if filedir is None else filedir
filename = os.path.join(filedir, filename + '.params')
filename = os.path.join(filedir, filename)
net.save_parameters(filename + '.params')
net.export(filename, epoch=0)
def save_best_network(self, path, epoch=0):
......@@ -373,6 +373,8 @@ class DdpgAgent(Agent):
def _make_pickle_ready(self, session_dir):
super(DdpgAgent, self)._make_pickle_ready(session_dir)
self._save_net(self._actor, 'current_actor')
self._save_net(self._actor, 'actor', session_dir)
self._actor = None
self._save_net(self._critic, 'critic', session_dir)
......@@ -738,6 +740,7 @@ class DqnAgent(Agent):
def _make_pickle_ready(self, session_dir):
super(DqnAgent, self)._make_pickle_ready(session_dir)
self._save_net(self._qnet, 'current_qnet')
self._save_net(self._qnet, 'qnet', session_dir)
self._qnet = None
self._save_net(self._target_qnet, 'target_net', session_dir)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment