Commit 7b2dabc2 authored by Julian Johannes Steinsberger-Dührßen

added memory_replacement_strategy parameter ("no_replacement" / "replace_oldest") to episodic memory; fixed label_memory concat axis (dim=1 -> dim=0) in the replace-oldest branch

parent 22fa37ac
Pipeline #323805 failed in 28 seconds
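The "bug fixes" part of the message refers mainly to the concatenation axis for stored labels in the replace-oldest branch, which changes from dim=1 to dim=0 (see the last hunk below). A minimal sketch of why the axis matters, assuming a working MXNet install (the array names are illustrative, not from the commit):

from mxnet import nd

stored = nd.zeros((4, 3))   # 4 stored label vectors of width 3
new = nd.ones((2, 3))       # 2 incoming label vectors

# dim=0 appends the new samples as rows, preserving the (num_samples, width) layout:
print(nd.concat(stored[2:], new, dim=0).shape)  # (4, 3): oldest two rows evicted

# dim=1 (the old code) would instead widen each row, corrupting the label memory:
print(nd.concat(stored[2:], new, dim=1).shape)  # (2, 6)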
@@ -347,8 +347,8 @@ class LargeMemory(gluon.HybridBlock):
         else:
             self.query_network.add(gluon.nn.Dense(units=self.num_heads*size, activation=self.query_act, flatten=False))
         return self.query_network


 #EpisodicMemory layer
 class EpisodicMemory(EpisodicReplayMemoryInterface):
     def __init__(self,
@@ -358,6 +358,7 @@ class EpisodicMemory(EpisodicReplayMemoryInterface):
                  replay_gradient_steps,
                  store_prob,
                  max_stored_samples,
+                 memory_replacement_strategy,
                  use_replay,
                  query_net_dir,
                  query_net_prefix,
@@ -368,7 +369,8 @@ class EpisodicMemory(EpisodicReplayMemoryInterface):
         #Replay parameters
         self.store_prob = store_prob
         self.max_stored_samples = max_stored_samples
+        self.memory_replacement_strategy = memory_replacement_strategy
         self.query_net_dir = query_net_dir
         self.query_net_prefix = query_net_prefix
         self.query_net_num_inputs = query_net_num_inputs
@@ -383,64 +385,65 @@ class EpisodicMemory(EpisodicReplayMemoryInterface):
         return [args, []]

     def store_samples(self, data, y, query_network, store_prob, context):
-        num_pus = len(data)
-        sub_batch_sizes = [data[i][0][0].shape[0] for i in range(num_pus)]
-        num_inputs = len(data[0][0])
-        num_outputs = len(y)
-        mx_context = context[0]
-
-        if len(self.key_memory) == 0:
-            self.key_memory = nd.empty(0, ctx=mx.cpu())
-            self.value_memory = []
-            self.label_memory = []#nd.empty((num_outputs, 0), ctx=mx.cpu())
-
-        ind = [nd.sample_multinomial(store_prob, sub_batch_sizes[i]).as_in_context(mx_context) for i in range(num_pus)]
-
-        max_inds = [nd.max(ind[i]) for i in range(num_pus)]
-        if any(max_inds):
-            to_store_values = []
-            for i in range(num_inputs):
-                tmp_values = []
-                for j in range(0, num_pus):
-                    if max_inds[j]:
-                        if isinstance(tmp_values, list):
-                            tmp_values = nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j])
-                        else:
-                            tmp_values = nd.concat(tmp_values, nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j]), dim=0)
-                to_store_values.append(tmp_values)
-
-            to_store_labels = []
-            for i in range(num_outputs):
-                tmp_labels = []
-                for j in range(0, num_pus):
-                    if max_inds[j]:
-                        if isinstance(tmp_labels, list):
-                            tmp_labels = nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j])
-                        else:
-                            tmp_labels = nd.concat(tmp_labels, nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j]), dim=0)
-                to_store_labels.append(tmp_labels)
-
-            to_store_keys = query_network(*to_store_values[0:self.query_net_num_inputs])
-
-            if self.key_memory.shape[0] == 0:
-                self.key_memory = to_store_keys.as_in_context(mx.cpu())
-                for i in range(num_inputs):
-                    self.value_memory.append(to_store_values[i].as_in_context(mx.cpu()))
-                for i in range(num_outputs):
-                    self.label_memory.append(to_store_labels[i].as_in_context(mx.cpu()))
-            elif self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples:
-                num_to_store = to_store_keys.shape[0]
-                self.key_memory = nd.concat(self.key_memory[num_to_store:], to_store_keys.as_in_context(mx.cpu()), dim=0)
-                for i in range(num_inputs):
-                    self.value_memory[i] = nd.concat(self.value_memory[i][num_to_store:], to_store_values[i].as_in_context(mx.cpu()), dim=0)
-                for i in range(num_outputs):
-                    self.label_memory[i] = nd.concat(self.label_memory[i][num_to_store:], to_store_labels[i].as_in_context(mx.cpu()), dim=1)
-            else:
-                self.key_memory = nd.concat(self.key_memory, to_store_keys.as_in_context(mx.cpu()), dim=0)
-                for i in range(num_inputs):
-                    self.value_memory[i] = nd.concat(self.value_memory[i], to_store_values[i].as_in_context(mx.cpu()), dim=0)
-                for i in range(num_outputs):
-                    self.label_memory[i] = nd.concat(self.label_memory[i], to_store_labels[i].as_in_context(mx.cpu()), dim=0)
+        if not (self.memory_replacement_strategy == "no_replacement" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples):
+            num_pus = len(data)
+            sub_batch_sizes = [data[i][0][0].shape[0] for i in range(num_pus)]
+            num_inputs = len(data[0][0])
+            num_outputs = len(y)
+            mx_context = context[0]
+
+            if len(self.key_memory) == 0:
+                self.key_memory = nd.empty(0, ctx=mx.cpu())
+                self.value_memory = []
+                self.label_memory = []#nd.empty((num_outputs, 0), ctx=mx.cpu())
+
+            ind = [nd.sample_multinomial(store_prob, sub_batch_sizes[i]).as_in_context(mx_context) for i in range(num_pus)]
+
+            max_inds = [nd.max(ind[i]) for i in range(num_pus)]
+            if any(max_inds):
+                to_store_values = []
+                for i in range(num_inputs):
+                    tmp_values = []
+                    for j in range(0, num_pus):
+                        if max_inds[j]:
+                            if isinstance(tmp_values, list):
+                                tmp_values = nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j])
+                            else:
+                                tmp_values = nd.concat(tmp_values, nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j]), dim=0)
+                    to_store_values.append(tmp_values)
+
+                to_store_labels = []
+                for i in range(num_outputs):
+                    tmp_labels = []
+                    for j in range(0, num_pus):
+                        if max_inds[j]:
+                            if isinstance(tmp_labels, list):
+                                tmp_labels = nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j])
+                            else:
+                                tmp_labels = nd.concat(tmp_labels, nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j]), dim=0)
+                    to_store_labels.append(tmp_labels)
+
+                to_store_keys = query_network(*to_store_values[0:self.query_net_num_inputs])
+
+                if self.key_memory.shape[0] == 0:
+                    self.key_memory = to_store_keys.as_in_context(mx.cpu())
+                    for i in range(num_inputs):
+                        self.value_memory.append(to_store_values[i].as_in_context(mx.cpu()))
+                    for i in range(num_outputs):
+                        self.label_memory.append(to_store_labels[i].as_in_context(mx.cpu()))
+                elif self.memory_replacement_strategy == "replace_oldest" and self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples:
+                    num_to_store = to_store_keys.shape[0]
+                    self.key_memory = nd.concat(self.key_memory[num_to_store:], to_store_keys.as_in_context(mx.cpu()), dim=0)
+                    for i in range(num_inputs):
+                        self.value_memory[i] = nd.concat(self.value_memory[i][num_to_store:], to_store_values[i].as_in_context(mx.cpu()), dim=0)
+                    for i in range(num_outputs):
+                        self.label_memory[i] = nd.concat(self.label_memory[i][num_to_store:], to_store_labels[i].as_in_context(mx.cpu()), dim=0)
+                else:
+                    self.key_memory = nd.concat(self.key_memory, to_store_keys.as_in_context(mx.cpu()), dim=0)
+                    for i in range(num_inputs):
+                        self.value_memory[i] = nd.concat(self.value_memory[i], to_store_values[i].as_in_context(mx.cpu()), dim=0)
+                    for i in range(num_outputs):
+                        self.label_memory[i] = nd.concat(self.label_memory[i], to_store_labels[i].as_in_context(mx.cpu()), dim=0)

     def sample_memory(self, batch_size):
         num_stored_samples = self.key_memory.shape[0]
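Behaviourally, the new guard makes the whole store step a no-op once the memory is full under "no_replacement", while "replace_oldest" evicts as many of the oldest entries as it appends, i.e. a FIFO buffer at capacity. A condensed, runnable sketch of that logic for the key memory alone (the function and variable names here are illustrative, not part of the commit):

from mxnet import nd

def store_keys(key_memory, new_keys, max_stored_samples, strategy):
    # Illustrative reduction of EpisodicMemory.store_samples to the key memory.
    full = max_stored_samples != -1 and key_memory.shape[0] >= max_stored_samples
    if strategy == "no_replacement" and full:
        return key_memory                      # memory is frozen once full
    if key_memory.shape[0] == 0:
        return new_keys                        # first write replaces the empty placeholder
    if strategy == "replace_oldest" and full:
        # evict the oldest rows to make room: FIFO ring-buffer behaviour
        return nd.concat(key_memory[new_keys.shape[0]:], new_keys, dim=0)
    return nd.concat(key_memory, new_keys, dim=0)  # below capacity: plain append

mem = nd.empty(0)
for step in range(4):
    batch = nd.random.uniform(shape=(2, 3))
    mem = store_keys(mem, batch, max_stored_samples=4, strategy="replace_oldest")
    print(step, mem.shape)  # grows to (4, 3), then stays there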
@@ -5,6 +5,7 @@
 <#assign replayGradientSteps = element.replayGradientSteps?c>
 <#assign replayMemoryStoreProb = element.replayMemoryStoreProb?c>
 <#assign maxStoredSamples = element.maxStoredSamples?c>
+<#assign memoryReplacementStrategy = element.memoryReplacementStrategy>
 <#assign useReplay = element.useReplay?string("True", "False")>
 <#assign useLocalAdaption = element.useLocalAdaption?string("true", "false")>
 <#assign localAdaptionK = element.localAdaptionK?c>
@@ -15,7 +16,7 @@
 <#if mode == "ARCHITECTURE_DEFINITION">
         self.${element.name} = EpisodicMemory(replay_interval=${replayInterval}, replay_batch_size=${replayBatchSize}, replay_steps=${replaySteps},
                                               replay_gradient_steps=${replayGradientSteps}, store_prob=${replayMemoryStoreProb},
-                                              max_stored_samples=${maxStoredSamples}, use_replay=${useReplay},
+                                              max_stored_samples=${maxStoredSamples}, memory_replacement_strategy="${memoryReplacementStrategy}", use_replay=${useReplay},
                                               query_net_dir="${queryNetDir}/",
                                               query_net_prefix="${queryNetPrefix}",
                                               query_net_num_inputs=${queryNetNumInputs})
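With the new assign in place, the ARCHITECTURE_DEFINITION branch renders a constructor call like the following fragment inside the generated network class (every concrete value below is hypothetical, filled in for a layer hypothetically named memory1; the real values come from the EpisodicMemory element):

        # Hypothetical template output, not taken from the commit:
        self.memory1 = EpisodicMemory(replay_interval=100, replay_batch_size=64, replay_steps=1,
                                      replay_gradient_steps=1, store_prob=0.5,
                                      max_stored_samples=2000, memory_replacement_strategy="replace_oldest", use_replay=True,
                                      query_net_dir="model/query_net/",
                                      query_net_prefix="QueryNet_",
                                      query_net_num_inputs=1)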