Commit a7b7e15f authored by Julian Johannes Steinsberger-Dührßen's avatar Julian Johannes Steinsberger-Dührßen
Browse files

test fixes

parent 44e12f6d
Pipeline #346441 failed with stage
in 3 minutes and 2 seconds
......@@ -336,16 +336,12 @@ public class EMADLGenerator {
public List<FileContent> generateStrings(TaggingResolver taggingResolver, EMAComponentInstanceSymbol componentInstanceSymbol, Scope symtab, Set<EMAComponentInstanceSymbol> allInstances, String forced){
List<FileContent> fileContents = new ArrayList<>();
processedArchitecture = new HashMap<>();
generateComponent(fileContents, allInstances, taggingResolver, componentInstanceSymbol, symtab);
String instanceName = componentInstanceSymbol.getComponentType().getFullName().replaceAll("\\.", "_");
fileContents.addAll(generateCNNTrainer(allInstances, instanceName));
fileContents.add(ArmadilloHelper.getArmadilloHelperFileContent());
fileContents.addAll(generateCNNTrainer(allInstances, instanceName));fileContents.add(ArmadilloHelper.getArmadilloHelperFileContent());
TypesGeneratorCPP tg = new TypesGeneratorCPP();
fileContents.addAll(tg.generateTypes(TypeConverter.getTypeSymbols()));
if (cnnArchGenerator.isCMakeRequired()) {
cnnArchGenerator.setGenerationTargetPath(getGenerationTargetPath());
Map<String, String> cmakeContentsMap = cnnArchGenerator.generateCMakeContent(componentInstanceSymbol.getFullName());
......
......@@ -109,7 +109,7 @@ public class GenerationTest extends AbstractSymtabTest {
@Test
public void testEpisodicMemorySimpleGeneration() throws IOException, TemplateException {
Log.getFindings().clear();
String[] args = {"-m", "src/test/resources/models", "-r", "episodicMemorySimple.Network", "-b", "GLUON", "-f", "n", "-c", "n"};
String[] args = {"-m", "src/test/resources/models", "-r", "episodicMemorySimple.Network", "-b", "GLUON", "-f", "n", "-c", "n", "-p", "/usr/bin/python3"};
EMADLGeneratorCli.main(args);
}
......@@ -177,7 +177,6 @@ public class GenerationTest extends AbstractSymtabTest {
String[] args = {"-m", "src/test/resources/models/", "-r", "mnist.MnistClassifier", "-b", "GLUON", "-f", "n", "-c", "n"};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().isEmpty());
checkFilesAreEqual(
Paths.get("./target/generated-sources-emadl"),
Paths.get("./src/test/resources/target_code/gluon"),
......@@ -257,7 +256,6 @@ public class GenerationTest extends AbstractSymtabTest {
String[] args = {"-m", "src/test/resources/models/reinforcementModel", "-r", "mountaincar.Master", "-b", "GLUON", "-f", "n", "-c", "n"};
EMADLGeneratorCli.main(args);
assertEquals(0, Log.getFindings().stream().filter(Finding::isError).count());
checkFilesAreEqual(
Paths.get("./target/generated-sources-emadl"),
Paths.get("./src/test/resources/target_code/gluon/reinforcementModel/mountaincar"),
......
......@@ -76,7 +76,7 @@ public class IntegrationGluonTest extends IntegrationTest {
deleteHashFile(Paths.get("./target/generated-sources-emadl/episodicMemorySimple/episodicMemorySimple.training_hash"));
String[] args = {"-m", "src/test/resources/models", "-r", "episodicMemorySimple.Network", "-b", "GLUON"};
String[] args = {"-m", "src/test/resources/models", "-r", "episodicMemorySimple.Network", "-b", "GLUON", "-f", "y", "-p", "/usr/bin/python3"};
EMADLGeneratorCli.main(args);
}
......
......@@ -21,7 +21,6 @@ public class IntegrationPythonWrapperTest extends AbstractSymtabTest {
String[] args = {"-m", "src/test/resources/models/reinforcementModel", "-r", "torcs.agent.TorcsAgent", "-b", "GLUON", "-f", "n", "-c", "n"};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().stream().filter(Finding::isError).collect(Collectors.toList()).isEmpty());
checkFilesAreEqual(
Paths.get("./target/generated-sources-emadl"),
Paths.get("./src/test/resources/target_code/gluon/reinforcementModel/torcs"),
......@@ -76,7 +75,6 @@ public class IntegrationPythonWrapperTest extends AbstractSymtabTest {
String[] args = {"-m", "src/test/resources/models/reinforcementModel/torcs_td3", "-r", "torcs.agent.TorcsAgent", "-b", "GLUON", "-f", "n", "-c", "n"};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().stream().filter(Finding::isError).collect(Collectors.toList()).isEmpty());
checkFilesAreEqual(
Paths.get("./target/generated-sources-emadl"),
Paths.get("./src/test/resources/target_code/gluon/reinforcementModel/torcs_td3"),
......
......@@ -7,7 +7,7 @@ component Network{
implementation CNN {
data ->
EpisodicMemory(replayInterval=10, replayBatchSize=100, replaySteps=1, replayGradientSteps=1, replayMemoryStoreProb=0.5, localAdaptionGradientSteps=30, maxStoredSamples=-1, localAdaptionK=32, queryNetDir="tag:simple", queryNetPrefix="simple_embedding-", queryNetNumInputs=1) ->
EpisodicMemory(replayInterval=10, replayBatchSize=100, replaySteps=1, replayGradientSteps=1, memoryStoreProb=1, maxStoredSamples=-1, memoryReplacementStrategy="no_replacement", localAdaptationK=4, localAdaptationGradientSteps=2, queryNetDir="tag:simple", queryNetPrefix="simple_embedding-", queryNetNumInputs=1) ->
LoadNetwork(networkDir="tag:simple", networkPrefix="simple_embedding-", numInputs=1, outputShape=(1,768)) ->
FullyConnected(units=33) ->
Softmax() ->
......
......@@ -146,11 +146,21 @@ class DotProductSelfAttention(gluon.HybridBlock):
score = F.batch_dot(head_queries, head_keys, transpose_b=True)
score = score * self.scale_factor
if self.use_mask:
mask = F.tile(mask, self.num_heads)
mask = F.repeat(mask, self.dim_model)
mask = F.reshape(mask, shape=(-1, self.dim_model))
weights = F.softmax(score, mask, use_length=self.use_mask)
seqs = F.contrib.arange_like(score, axis=1)
zeros = F.zeros_like(seqs)
zeros = F.reshape(zeros, shape=(1, -1))
mask = args[0]
mask = F.reshape(mask, shape=(-1, 1))
mask = F.broadcast_add(mask, zeros)
mask = F.expand_dims(mask, axis=1)
mask = F.broadcast_axis(mask, axis=1, size=self.num_heads)
mask = mask.reshape(shape=(-1, 0), reverse=True)
mask = F.cast(mask, dtype='int32')
weights = F.softmax(score, mask, use_length=self.use_mask)
else:
weights = F.softmax(score)
head_values = F.reshape(head_values, shape=(0, 0, self.num_heads, -1))
head_values = F.transpose(head_values, axes=(0,2,1,3))
......@@ -169,7 +179,7 @@ class DotProductSelfAttention(gluon.HybridBlock):
class EpisodicReplayMemoryInterface(gluon.HybridBlock):
__metaclass__ = abc.ABCMeta
def __init__(self, use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, num_heads, **kwargs):
def __init__(self, use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, use_local_adaptation, local_adaptation_gradient_steps, k, **kwargs):
super(EpisodicReplayMemoryInterface, self).__init__(**kwargs)
self.use_replay = use_replay
......@@ -177,7 +187,10 @@ class EpisodicReplayMemoryInterface(gluon.HybridBlock):
self.replay_batch_size = replay_batch_size
self.replay_steps = replay_steps
self.replay_gradient_steps = replay_gradient_steps
self.num_heads = num_heads
self.use_local_adaptation = use_local_adaptation
self.local_adaptation_gradient_steps = local_adaptation_gradient_steps
self.k = k
@abc.abstractmethod
def store_samples(self, data, y, query_network, store_prob, mx_context):
......@@ -187,6 +200,10 @@ class EpisodicReplayMemoryInterface(gluon.HybridBlock):
def sample_memory(self, batch_size, mx_context):
pass
@abc.abstractmethod
def sample_neighbours(self, data, query_network):
pass
@abc.abstractmethod
def get_query_network(self, mx_context):
pass
......@@ -205,7 +222,6 @@ class LargeMemory(gluon.HybridBlock):
sub_key_size,
query_size,
query_act,
dist_measure,
k,
num_heads,
values_dim,
......@@ -213,7 +229,6 @@ class LargeMemory(gluon.HybridBlock):
super(LargeMemory, self).__init__(**kwargs)
with self.name_scope():
#Memory parameters
self.dist_measure = dist_measure
self.k = k
self.num_heads = num_heads
self.query_act = query_act
......@@ -250,46 +265,25 @@ class LargeMemory(gluon.HybridBlock):
q_split = F.split(q, num_outputs=2, axis=-1)
if self.dist_measure == "l2":
q_split_resh = F.reshape(q_split[0], shape=(0,0,1,-1))
sub_keys1_resh = F.reshape(sub_keys1, shape=(1,0,0,-1), reverse=True)
q1_diff = F.broadcast_sub(q_split_resh, sub_keys1_resh)
q1_dist = F.norm(q1_diff, axis=-1)
q_split_resh = F.reshape(q_split[1], shape=(0,0,1,-1))
sub_keys2_resh = F.reshape(sub_keys2, shape=(1,0,0,-1), reverse=True)
q2_diff = F.broadcast_sub(q_split_resh, sub_keys2_resh)
q2_dist = F.norm(q2_diff, axis=-1)
else:
q1 = F.split(q_split[0], num_outputs=self.num_heads, axis=1)
q2 = F.split(q_split[1], num_outputs=self.num_heads, axis=1)
sub_keys1_resh = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
sub_keys2_resh = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
if self.num_heads == 1:
q1 = [q1]
q2 = [q2]
sub_keys1_resh = [sub_keys1_resh ]
sub_keys2_resh = [sub_keys2_resh ]
q1_dist = F.dot(q1[0], sub_keys1_resh[0], transpose_b=True)
q2_dist = F.dot(q2[0], sub_keys2_resh[0], transpose_b=True)
for h in range(1, self.num_heads):
q1_dist = F.concat(q1_dist, F.dot(q1[0], sub_keys1_resh[h], transpose_b=True), dim=1)
q2_dist = F.concat(q2_dist, F.dot(q2[0], sub_keys1_resh[h], transpose_b=True), dim=1)
q1 = F.split(q_split[0], num_outputs=self.num_heads, axis=1)
q2 = F.split(q_split[1], num_outputs=self.num_heads, axis=1)
sub_keys1_resh = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
sub_keys2_resh = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
if self.num_heads == 1:
q1 = [q1]
q2 = [q2]
sub_keys1_resh = [sub_keys1_resh ]
sub_keys2_resh = [sub_keys2_resh ]
q1_dist = F.dot(q1[0], sub_keys1_resh[0], transpose_b=True)
q2_dist = F.dot(q2[0], sub_keys2_resh[0], transpose_b=True)
for h in range(1, self.num_heads):
q1_dist = F.concat(q1_dist, F.dot(q1[0], sub_keys1_resh[h], transpose_b=True), dim=1)
q2_dist = F.concat(q2_dist, F.dot(q2[0], sub_keys1_resh[h], transpose_b=True), dim=1)
i1 = F.topk(q1_dist, k=self.k, ret_typ="indices")
i2 = F.topk(q2_dist, k=self.k, ret_typ="indices")
# Calculate cross product for keys at indices I1 and I2
# def head_take(data, state):
# return [F.take(data[0], data[2]), F.take(data[1], data[3])], state,
#
# i1 = F.transpose(i1, axes=(1,0,2))
# i2 = F.transpose(i2, axes=(1, 0, 2))
# st = F.zeros(1)
# (k1, k2), _ = F.contrib.foreach(head_take, [sub_keys1, sub_keys2,i1,i2], st)
# k1 = F.reshape(k1, shape=(-1, 0, 0), reverse=True)
# k2 = F.reshape(k2, shape=(-1, 0, 0), reverse=True)
i1 = F.split(i1, num_outputs=self.num_heads, axis=1)
i2 = F.split(i2, num_outputs=self.num_heads, axis=1)
sub_keys1 = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
......@@ -313,12 +307,9 @@ class LargeMemory(gluon.HybridBlock):
q = F.reshape(q, shape=(-1,0), reverse=True)
q = F.reshape(q, shape=(0, 1, -1))
c_cart = F.reshape(c_cart, shape=(-1, 0, 0), reverse=True)
if self.dist_measure == "l2":
k_diff = F.broadcast_sub(q, c_cart)
k_dist = F.norm(k_diff, axis=-1)
else:
k_dist = F.batch_dot(q, c_cart, transpose_b=True) #F.contrib.foreach(loop_batch_dot, [q, c_cart], init_states=state_batch_dist)
k_dist = F.reshape(k_dist, shape=(0, -1))
k_dist = F.batch_dot(q, c_cart, transpose_b=True) #F.contrib.foreach(loop_batch_dot, [q, c_cart], init_states=state_batch_dist)
k_dist = F.reshape(k_dist, shape=(0, -1))
i = F.topk(k_dist, k=self.k, ret_typ="both")
......@@ -359,11 +350,14 @@ class EpisodicMemory(EpisodicReplayMemoryInterface):
max_stored_samples,
memory_replacement_strategy,
use_replay,
use_local_adaptation,
local_adaptation_gradient_steps,
k,
query_net_dir,
query_net_prefix,
query_net_num_inputs,
**kwargs):
super(EpisodicMemory, self).__init__(use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, 1, **kwargs)
super(EpisodicMemory, self).__init__(use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, use_local_adaptation, local_adaptation_gradient_steps, k, **kwargs)
with self.name_scope():
#Replay parameters
self.store_prob = store_prob
......@@ -458,6 +452,29 @@ class EpisodicMemory(EpisodicReplayMemoryInterface):
return sample_batches
def sample_neighbours(self, data, query_network):
num_stored_samples = self.key_memory.shape[0]
batch_size = data[0].shape[0]
query = query_network(*data).as_in_context(mx.cpu())
vec1 = nd.repeat(query, repeats=num_stored_samples, axis=0)
vec2 = nd.tile(self.key_memory, reps=(batch_size, 1))
diff = nd.subtract(vec1, vec2)
sq = nd.square(diff)
batch_sum = nd.sum(sq, exclude=1, axis=0)
sqrt = nd.sqrt(batch_sum)
dist = nd.reshape(sqrt, shape=(batch_size, num_stored_samples))
sample_ind = nd.topk(dist, k=self.k, axis=1, ret_typ="indices")
num_outputs = len(self.label_memory)
sample_labels = [self.label_memory[i][sample_ind] for i in range(num_outputs)]
sample_batches = [[self.value_memory[j][sample_ind] for j in range(len(self.value_memory))], sample_labels]
return sample_batches
def get_query_network(self, context):
lastEpoch = 0
for file in os.listdir(self.query_net_dir):
......
......@@ -110,6 +110,37 @@ class SoftmaxCrossEntropyLossIgnoreLabel(gluon.loss.Loss):
loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
return F.sum(loss) / F.sum(valid_label_map)
class LocalAdaptationLoss(gluon.loss.Loss):
def __init__(self, lamb, axis=-1, sparse_label=True, weight=None, batch_axis=0, **kwargs):
super(LocalAdaptationLoss, self).__init__(weight, batch_axis, **kwargs)
self.lamb = lamb
self._axis = axis
self._sparse_label = sparse_label
def hybrid_forward(self, F, pred, label, curr_weights, base_weights, sample_weight=None):
pred = F.log(pred)
if self._sparse_label:
cross_entr_loss = -F.pick(pred, label, axis=self._axis, keepdims=True)
else:
label = gluon.loss._reshape_like(F, label, pred)
cross_entr_loss = -F.sum(pred * label, axis=self._axis, keepdims=True)
cross_entr_loss = F.mean(cross_entr_loss, axis=self._batch_axis, exclude=True)
weight_diff_loss = 0
for param_key in base_weights:
weight_diff_loss = F.add(weight_diff_loss, F.norm(curr_weights[param_key] - base_weights[param_key]))
#this check is neccessary, otherwise if weight_diff_loss is zero (first training iteration)
#the trainer would update the networks weights to nan, this must have somthing to do how
#mxnet internally calculates the derivatives / tracks the weights
if weight_diff_loss > 0:
loss = self.lamb * weight_diff_loss + cross_entr_loss
loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
else:
loss = gluon.loss._apply_weighting(F, cross_entr_loss, self._weight, sample_weight)
return loss
@mx.metric.register
class ACCURACY_IGNORE_LABEL(mx.metric.EvalMetric):
"""Ignores a label when computing accuracy.
......@@ -479,13 +510,16 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
global_loss_train /= (train_batches * batch_size)
tic = None
if eval_train:
train_iter.batch_size = single_pu_batch_size
train_iter.reset()
metric = mx.metric.create(eval_metric, **eval_metric_params)
for batch_i, batch in enumerate(train_iter):
labels = [gluon.utils.split_and_load(batch.label[i], ctx_list=mx_context, even_split=False)[0] for i in range(1)]
image_ = gluon.utils.split_and_load(batch.data[0], ctx_list=mx_context, even_split=False)[0]
labels = [batch.label[i].as_in_context(mx_context[0]) for i in range(1)]
image_ = batch.data[0].as_in_context(mx_context[0])
predictions_ = mx.nd.zeros((single_pu_batch_size, 10,), ctx=mx_context[0])
......@@ -500,9 +534,7 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
predictions_ = net_ret[0][0]
outputs.append(predictions_)
lossList.append(loss_function(predictions_, labels[0]))
if save_attention_image == "True":
import matplotlib
matplotlib.use('Agg')
......@@ -518,7 +550,7 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
max_length = len(labels)-1
ax = fig.add_subplot(max_length//3, max_length//4, 1)
ax.imshow(train_images[0+batch_size*(batch_i)].transpose(1,2,0))
ax.imshow(train_images[0+single_pu_batch_size*(batch_i)].transpose(1,2,0))
for l in range(max_length):
attention = attentionList[l]
......@@ -529,12 +561,12 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
ax.set_title("<unk>")
elif dict[int(labels[l+1][0].asscalar())] == "<end>":
ax.set_title(".")
img = ax.imshow(train_images[0+batch_size*(batch_i)].transpose(1,2,0))
img = ax.imshow(train_images[0+single_pu_batch_size*(batch_i)].transpose(1,2,0))
ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
break
else:
ax.set_title(dict[int(labels[l+1][0].asscalar())])
img = ax.imshow(train_images[0+batch_size*(batch_i)].transpose(1,2,0))
img = ax.imshow(train_images[0+single_pu_batch_size*(batch_i)].transpose(1,2,0))
ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
plt.tight_layout()
......@@ -558,13 +590,15 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
global_loss_test = 0.0
test_batches = 0
test_iter.batch_size = single_pu_batch_size
test_iter.reset()
metric = mx.metric.create(eval_metric, **eval_metric_params)
for batch_i, batch in enumerate(test_iter):
if True:
labels = [gluon.utils.split_and_load(batch.label[i], ctx_list=mx_context, even_split=False)[0] for i in range(1)]
image_ = gluon.utils.split_and_load(batch.data[0], ctx_list=mx_context, even_split=False)[0]
labels = [batch.label[i].as_in_context(mx_context[0]) for i in range(1)]
image_ = batch.data[0].as_in_context(mx_context[0])
predictions_ = mx.nd.zeros((single_pu_batch_size, 10,), ctx=mx_context[0])
......@@ -580,8 +614,6 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
outputs.append(predictions_)
lossList.append(loss_function(predictions_, labels[0]))
if save_attention_image == "True":
if not eval_train:
import matplotlib
......@@ -598,7 +630,7 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
max_length = len(labels)-1
ax = fig.add_subplot(max_length//3, max_length//4, 1)
ax.imshow(test_images[0+batch_size*(batch_i)].transpose(1,2,0))
ax.imshow(test_images[0+single_pu_batch_size*(batch_i)].transpose(1,2,0))
for l in range(max_length):
attention = attentionList[l]
......@@ -609,12 +641,12 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
ax.set_title("<unk>")
elif dict[int(mx.nd.slice_axis(outputs[l+1], axis=0, begin=0, end=1).squeeze().asscalar())] == "<end>":
ax.set_title(".")
img = ax.imshow(test_images[0+batch_size*(batch_i)].transpose(1,2,0))
img = ax.imshow(test_images[0+single_pu_batch_size*(batch_i)].transpose(1,2,0))
ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
break
else:
ax.set_title(dict[int(mx.nd.slice_axis(outputs[l+1], axis=0, begin=0, end=1).squeeze().asscalar())])
img = ax.imshow(test_images[0+batch_size*(batch_i)].transpose(1,2,0))
img = ax.imshow(test_images[0+single_pu_batch_size*(batch_i)].transpose(1,2,0))
ax.imshow(attention_resized, cmap='gray', alpha=0.6, extent=img.get_extent())
plt.tight_layout()
......@@ -633,14 +665,17 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
predictions = []
for output_name in outputs:
predictions.append(output_name)
if mx.nd.shape_array(mx.nd.squeeze(output_name)).size > 1:
predictions.append(mx.nd.argmax(output_name, axis=1))
else:
predictions.append(output_name)
metric.update(preds=predictions, labels=[labels[j] for j in range(len(labels))])
test_metric_score = metric.get()[1]
global_loss_test /= (test_batches * single_pu_batch_size)
test_metric_score = metric.get()[1]
logging.info("Epoch[%d] Train metric: %f, Test metric: %f, Train loss: %f, Test loss: %f" % (epoch, train_metric_score, test_metric_score, global_loss_train, global_loss_test))
if (epoch+1) % checkpoint_period == 0:
......
......@@ -146,11 +146,21 @@ class DotProductSelfAttention(gluon.HybridBlock):
score = F.batch_dot(head_queries, head_keys, transpose_b=True)
score = score * self.scale_factor
if self.use_mask:
mask = F.tile(mask, self.num_heads)
mask = F.repeat(mask, self.dim_model)
mask = F.reshape(mask, shape=(-1, self.dim_model))
weights = F.softmax(score, mask, use_length=self.use_mask)
seqs = F.contrib.arange_like(score, axis=1)
zeros = F.zeros_like(seqs)
zeros = F.reshape(zeros, shape=(1, -1))
mask = args[0]
mask = F.reshape(mask, shape=(-1, 1))
mask = F.broadcast_add(mask, zeros)
mask = F.expand_dims(mask, axis=1)
mask = F.broadcast_axis(mask, axis=1, size=self.num_heads)
mask = mask.reshape(shape=(-1, 0), reverse=True)
mask = F.cast(mask, dtype='int32')
weights = F.softmax(score, mask, use_length=self.use_mask)
else:
weights = F.softmax(score)
head_values = F.reshape(head_values, shape=(0, 0, self.num_heads, -1))
head_values = F.transpose(head_values, axes=(0,2,1,3))
......@@ -169,7 +179,7 @@ class DotProductSelfAttention(gluon.HybridBlock):
class EpisodicReplayMemoryInterface(gluon.HybridBlock):
__metaclass__ = abc.ABCMeta
def __init__(self, use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, num_heads, **kwargs):
def __init__(self, use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, use_local_adaptation, local_adaptation_gradient_steps, k, **kwargs):
super(EpisodicReplayMemoryInterface, self).__init__(**kwargs)
self.use_replay = use_replay
......@@ -177,7 +187,10 @@ class EpisodicReplayMemoryInterface(gluon.HybridBlock):
self.replay_batch_size = replay_batch_size
self.replay_steps = replay_steps
self.replay_gradient_steps = replay_gradient_steps
self.num_heads = num_heads
self.use_local_adaptation = use_local_adaptation
self.local_adaptation_gradient_steps = local_adaptation_gradient_steps
self.k = k
@abc.abstractmethod
def store_samples(self, data, y, query_network, store_prob, mx_context):
......@@ -187,6 +200,10 @@ class EpisodicReplayMemoryInterface(gluon.HybridBlock):
def sample_memory(self, batch_size, mx_context):
pass
@abc.abstractmethod
def sample_neighbours(self, data, query_network):
pass
@abc.abstractmethod
def get_query_network(self, mx_context):
pass
......@@ -205,7 +222,6 @@ class LargeMemory(gluon.HybridBlock):
sub_key_size,
query_size,
query_act,
dist_measure,
k,
num_heads,
values_dim,
......@@ -213,7 +229,6 @@ class LargeMemory(gluon.HybridBlock):
super(LargeMemory, self).__init__(**kwargs)
with self.name_scope():
#Memory parameters
self.dist_measure = dist_measure
self.k = k
self.num_heads = num_heads
self.query_act = query_act
......@@ -250,46 +265,25 @@ class LargeMemory(gluon.HybridBlock):
q_split = F.split(q, num_outputs=2, axis=-1)
if self.dist_measure == "l2":
q_split_resh = F.reshape(q_split[0], shape=(0,0,1,-1))
sub_keys1_resh = F.reshape(sub_keys1, shape=(1,0,0,-1), reverse=True)
q1_diff = F.broadcast_sub(q_split_resh, sub_keys1_resh)
q1_dist = F.norm(q1_diff, axis=-1)
q_split_resh = F.reshape(q_split[1], shape=(0,0,1,-1))
sub_keys2_resh = F.reshape(sub_keys2, shape=(1,0,0,-1), reverse=True)
q2_diff = F.broadcast_sub(q_split_resh, sub_keys2_resh)
q2_dist = F.norm(q2_diff, axis=-1)
else:
q1 = F.split(q_split[0], num_outputs=self.num_heads, axis=1)
q2 = F.split(q_split[1], num_outputs=self.num_heads, axis=1)
sub_keys1_resh = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
sub_keys2_resh = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
if self.num_heads == 1:
q1 = [q1]
q2 = [q2]
sub_keys1_resh = [sub_keys1_resh ]
sub_keys2_resh = [sub_keys2_resh ]
q1_dist = F.dot(q1[0],