Skip to content
Snippets Groups Projects
Commit 9251bd58 authored by Nicola Gatto's avatar Nicola Gatto
Browse files

Save current network when interrupting the training

parent 4031ded3
Branches
No related tags found
1 merge request!17Shared code, updated for CNNArchLang, etc.
Pipeline #148706 failed
......@@ -273,8 +273,8 @@ class Agent(object):
def _save_net(self, net, filename, filedir=None):
filedir = self._output_directory if filedir is None else filedir
filename = os.path.join(filedir, filename + '.params')
net.save_parameters(filename)
filename = os.path.join(filedir, filename)
net.save_parameters(filename + '.params')
net.export(filename, epoch=0)
def save_best_network(self, path, epoch=0):
......@@ -373,6 +373,8 @@ class DdpgAgent(Agent):
def _make_pickle_ready(self, session_dir):
super(DdpgAgent, self)._make_pickle_ready(session_dir)
self._save_net(self._actor, 'current_actor')
self._save_net(self._actor, 'actor', session_dir)
self._actor = None
self._save_net(self._critic, 'critic', session_dir)
......@@ -738,6 +740,7 @@ class DqnAgent(Agent):
def _make_pickle_ready(self, session_dir):
super(DqnAgent, self)._make_pickle_ready(session_dir)
self._save_net(self._qnet, 'current_qnet')
self._save_net(self._qnet, 'qnet', session_dir)
self._qnet = None
self._save_net(self._target_qnet, 'target_net', session_dir)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment