@@ -477,10 +482,22 @@ class EpisodicMemory(EpisodicReplayMemoryInterface):
return net
def save_memory(self, path):
mem_arr = [("keys", self.key_memory)] + [("labels", self.label_memory)] + [("values_"+str(k),v) for (k,v) in enumerate(self.value_memory)]
mem_arr = [("keys", self.key_memory)] + [("values_"+str(k),v) for (k,v) in enumerate(self.value_memory)] + [("labels_"+str(k),v) for (k,v) in enumerate(self.label_memory)]
mem_dict = {entry[0]:entry[1] for entry in mem_arr}
nd.save(path, mem_dict)
def load_memory(self, path):
mem_dict = nd.load(path)
self.value_memory = []
self.label_memory = []
for key in sorted(mem_dict.keys()):
if key == "keys":
self.key_memory = mem_dict[key]
elif key.startswith("values_"):
self.value_memory.append(mem_dict[key])
elif key.startswith("labels_"):
self.label_memory.append(mem_dict[key])
<#listtc.architecture.networkInstructions as networkInstruction>