Commit a3a048d8 authored by Julian Johannes Steinsberger-Dührßen

cleanup and fixes

parent d157159f
Pipeline #348412 failed with stage in 54 seconds
......@@ -115,4 +115,4 @@ public class CNNArch2Gluon extends CNNArchGenerator {
}
return fileContentMap;
}
}
}
\ No newline at end of file
......@@ -45,6 +45,7 @@ public class CNNArch2GluonLayerSupportChecker extends LayerSupportChecker {
supportedLayerList.add(AllPredefinedLayers.EPISODIC_MEMORY_NAME);
supportedLayerList.add(AllPredefinedLayers.DOT_PRODUCT_SELF_ATTENTION_NAME);
supportedLayerList.add(AllPredefinedLayers.LOAD_NETWORK_NAME);
supportedLayerList.add(AllPredefinedLayers.LAYERNORM_NAME);
}
}
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnnarch.gluongenerator;
import de.monticore.lang.monticar.cnntrain._ast.*;
import de.monticore.lang.monticar.cnnarch.generator.TrainParamSupportChecker;
public class CNNArch2GluonTrainParamSupportChecker extends TrainParamSupportChecker {
public void visit(ASTAdamWOptimizer node){}
}
......@@ -147,11 +147,21 @@ class DotProductSelfAttention(gluon.HybridBlock):
score = F.batch_dot(head_queries, head_keys, transpose_b=True)
score = score * self.scale_factor
if self.use_mask:
mask = F.tile(mask, self.num_heads)
mask = F.repeat(mask, self.dim_model)
mask = F.reshape(mask, shape=(-1, self.dim_model))
weights = F.softmax(score, mask, use_length=self.use_mask)
seqs = F.contrib.arange_like(score, axis=1)
zeros = F.zeros_like(seqs)
zeros = F.reshape(zeros, shape=(1, -1))
mask = args[0]
mask = F.reshape(mask, shape=(-1, 1))
mask = F.broadcast_add(mask, zeros)
mask = F.expand_dims(mask, axis=1)
mask = F.broadcast_axis(mask, axis=1, size=self.num_heads)
mask = mask.reshape(shape=(-1, 0), reverse=True)
mask = F.cast(mask, dtype='int32')
weights = F.softmax(score, mask, use_length=self.use_mask)
else:
weights = F.softmax(score)
head_values = F.reshape(head_values, shape=(0, 0, self.num_heads, -1))
head_values = F.transpose(head_values, axes=(0,2,1,3))
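Editor's note: in the rewritten mask path above, the incoming mask appears to be expanded into per-row valid-length counts (cast to int32) and handed to MXNet's length-based softmax. A tiny, stand-alone illustration of that primitive with toy scores, independent of the attention layer:

from mxnet import nd

scores = nd.random.uniform(shape=(2, 4))       # (rows, keys)
lengths = nd.array([2, 3], dtype='int32')      # number of valid keys per row
weights = nd.softmax(scores, lengths, use_length=True)
print(weights)                                 # positions beyond the given length get zero weight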
......@@ -170,7 +180,7 @@ class DotProductSelfAttention(gluon.HybridBlock):
class EpisodicReplayMemoryInterface(gluon.HybridBlock):
__metaclass__ = abc.ABCMeta
def __init__(self, use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, num_heads, **kwargs):
def __init__(self, use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, use_local_adaptation, local_adaptation_gradient_steps, k, **kwargs):
super(EpisodicReplayMemoryInterface, self).__init__(**kwargs)
self.use_replay = use_replay
......@@ -178,7 +188,10 @@ class EpisodicReplayMemoryInterface(gluon.HybridBlock):
self.replay_batch_size = replay_batch_size
self.replay_steps = replay_steps
self.replay_gradient_steps = replay_gradient_steps
self.num_heads = num_heads
self.use_local_adaptation = use_local_adaptation
self.local_adaptation_gradient_steps = local_adaptation_gradient_steps
self.k = k
@abc.abstractmethod
def store_samples(self, data, y, query_network, store_prob, mx_context):
......@@ -188,6 +201,10 @@ class EpisodicReplayMemoryInterface(gluon.HybridBlock):
def sample_memory(self, batch_size, mx_context):
pass
@abc.abstractmethod
def sample_neighbours(self, data, query_network):
pass
@abc.abstractmethod
def get_query_network(self, mx_context):
pass
......@@ -206,7 +223,6 @@ class LargeMemory(gluon.HybridBlock):
sub_key_size,
query_size,
query_act,
dist_measure,
k,
num_heads,
values_dim,
......@@ -214,7 +230,6 @@ class LargeMemory(gluon.HybridBlock):
super(LargeMemory, self).__init__(**kwargs)
with self.name_scope():
#Memory parameters
self.dist_measure = dist_measure
self.k = k
self.num_heads = num_heads
self.query_act = query_act
......@@ -251,46 +266,25 @@ class LargeMemory(gluon.HybridBlock):
q_split = F.split(q, num_outputs=2, axis=-1)
if self.dist_measure == "l2":
q_split_resh = F.reshape(q_split[0], shape=(0,0,1,-1))
sub_keys1_resh = F.reshape(sub_keys1, shape=(1,0,0,-1), reverse=True)
q1_diff = F.broadcast_sub(q_split_resh, sub_keys1_resh)
q1_dist = F.norm(q1_diff, axis=-1)
q_split_resh = F.reshape(q_split[1], shape=(0,0,1,-1))
sub_keys2_resh = F.reshape(sub_keys2, shape=(1,0,0,-1), reverse=True)
q2_diff = F.broadcast_sub(q_split_resh, sub_keys2_resh)
q2_dist = F.norm(q2_diff, axis=-1)
else:
q1 = F.split(q_split[0], num_outputs=self.num_heads, axis=1)
q2 = F.split(q_split[1], num_outputs=self.num_heads, axis=1)
sub_keys1_resh = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
sub_keys2_resh = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
if self.num_heads == 1:
q1 = [q1]
q2 = [q2]
sub_keys1_resh = [sub_keys1_resh ]
sub_keys2_resh = [sub_keys2_resh ]
q1_dist = F.dot(q1[0], sub_keys1_resh[0], transpose_b=True)
q2_dist = F.dot(q2[0], sub_keys2_resh[0], transpose_b=True)
for h in range(1, self.num_heads):
q1_dist = F.concat(q1_dist, F.dot(q1[0], sub_keys1_resh[h], transpose_b=True), dim=1)
q2_dist = F.concat(q2_dist, F.dot(q2[0], sub_keys1_resh[h], transpose_b=True), dim=1)
q1 = F.split(q_split[0], num_outputs=self.num_heads, axis=1)
q2 = F.split(q_split[1], num_outputs=self.num_heads, axis=1)
sub_keys1_resh = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
sub_keys2_resh = F.split(sub_keys2, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
if self.num_heads == 1:
q1 = [q1]
q2 = [q2]
sub_keys1_resh = [sub_keys1_resh ]
sub_keys2_resh = [sub_keys2_resh ]
q1_dist = F.dot(q1[0], sub_keys1_resh[0], transpose_b=True)
q2_dist = F.dot(q2[0], sub_keys2_resh[0], transpose_b=True)
for h in range(1, self.num_heads):
q1_dist = F.concat(q1_dist, F.dot(q1[0], sub_keys1_resh[h], transpose_b=True), dim=1)
q2_dist = F.concat(q2_dist, F.dot(q2[0], sub_keys1_resh[h], transpose_b=True), dim=1)
i1 = F.topk(q1_dist, k=self.k, ret_typ="indices")
i2 = F.topk(q2_dist, k=self.k, ret_typ="indices")
# Calculate cross product for keys at indices I1 and I2
# def head_take(data, state):
# return [F.take(data[0], data[2]), F.take(data[1], data[3])], state,
#
# i1 = F.transpose(i1, axes=(1,0,2))
# i2 = F.transpose(i2, axes=(1, 0, 2))
# st = F.zeros(1)
# (k1, k2), _ = F.contrib.foreach(head_take, [sub_keys1, sub_keys2,i1,i2], st)
# k1 = F.reshape(k1, shape=(-1, 0, 0), reverse=True)
# k2 = F.reshape(k2, shape=(-1, 0, 0), reverse=True)
i1 = F.split(i1, num_outputs=self.num_heads, axis=1)
i2 = F.split(i2, num_outputs=self.num_heads, axis=1)
sub_keys1 = F.split(sub_keys1, num_outputs=self.num_heads, axis=0, squeeze_axis=True)
......@@ -314,12 +308,9 @@ class LargeMemory(gluon.HybridBlock):
q = F.reshape(q, shape=(-1,0), reverse=True)
q = F.reshape(q, shape=(0, 1, -1))
c_cart = F.reshape(c_cart, shape=(-1, 0, 0), reverse=True)
if self.dist_measure == "l2":
k_diff = F.broadcast_sub(q, c_cart)
k_dist = F.norm(k_diff, axis=-1)
else:
k_dist = F.batch_dot(q, c_cart, transpose_b=True) #F.contrib.foreach(loop_batch_dot, [q, c_cart], init_states=state_batch_dist)
k_dist = F.reshape(k_dist, shape=(0, -1))
k_dist = F.batch_dot(q, c_cart, transpose_b=True) #F.contrib.foreach(loop_batch_dot, [q, c_cart], init_states=state_batch_dist)
k_dist = F.reshape(k_dist, shape=(0, -1))
i = F.topk(k_dist, k=self.k, ret_typ="both")
......@@ -360,11 +351,14 @@ class EpisodicMemory(EpisodicReplayMemoryInterface):
max_stored_samples,
memory_replacement_strategy,
use_replay,
use_local_adaptation,
local_adaptation_gradient_steps,
k,
query_net_dir,
query_net_prefix,
query_net_num_inputs,
**kwargs):
super(EpisodicMemory, self).__init__(use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, 1, **kwargs)
super(EpisodicMemory, self).__init__(use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, use_local_adaptation, local_adaptation_gradient_steps, k, **kwargs)
with self.name_scope():
#Replay parameters
self.store_prob = store_prob
......@@ -459,6 +453,29 @@ class EpisodicMemory(EpisodicReplayMemoryInterface):
return sample_batches
def sample_neighbours(self, data, query_network):
num_stored_samples = self.key_memory.shape[0]
batch_size = data[0].shape[0]
query = query_network(*data).as_in_context(mx.cpu())
vec1 = nd.repeat(query, repeats=num_stored_samples, axis=0)
vec2 = nd.tile(self.key_memory, reps=(batch_size, 1))
diff = nd.subtract(vec1, vec2)
sq = nd.square(diff)
batch_sum = nd.sum(sq, exclude=1, axis=0)
sqrt = nd.sqrt(batch_sum)
dist = nd.reshape(sqrt, shape=(batch_size, num_stored_samples))
sample_ind = nd.topk(dist, k=self.k, axis=1, ret_typ="indices")
num_outputs = len(self.label_memory)
sample_labels = [self.label_memory[i][sample_ind] for i in range(num_outputs)]
sample_batches = [[self.value_memory[j][sample_ind] for j in range(len(self.value_memory))], sample_labels]
return sample_batches
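Editor's note: the new sample_neighbours method builds all pairwise Euclidean distances between the query embeddings and the stored keys via repeat/tile and then takes the top-k per query. A small stand-alone sketch of that distance computation with toy shapes; is_ascend=True is used here so the smallest distances, i.e. the nearest keys, are selected:

from mxnet import nd

B, N, D, k = 2, 5, 3, 2
queries = nd.random.uniform(shape=(B, D))
keys = nd.random.uniform(shape=(N, D))

vec1 = nd.repeat(queries, repeats=N, axis=0)        # (B*N, D): each query repeated N times
vec2 = nd.tile(keys, reps=(B, 1))                   # (B*N, D): the key matrix tiled once per query
dist = nd.sqrt(nd.sum(nd.square(vec1 - vec2), axis=1)).reshape((B, N))
nearest = nd.topk(dist, k=k, axis=1, ret_typ="indices", is_ascend=True)
print(nearest)                                      # indices of the k closest stored keys per query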
def get_query_network(self, context):
lastEpoch = 0
for file in os.listdir(self.query_net_dir):
......
......@@ -51,7 +51,7 @@ public:
std::vector<std::map<std::string, NDArray>> replay_memory;
//parameters for local adapt
std::vector<bool> use_local_adaption = {};
std::vector<bool> use_local_adaptation = {};
std::vector<mx_uint> replay_k = {};
std::vector<mx_uint> gradient_steps = {};
std::vector<mx_uint> query_num_inputs = {};
......@@ -92,7 +92,7 @@ public:
std::vector<std::vector<float>> network_input = {${tc.join(tc.getStreamInputNames(networkInstruction.body, false), ", ", "in_", "")}};
for(mx_uint i=1; i < num_subnets; i++){
if(use_local_adaption[i-1]){
if(use_local_adaptation[i-1]){
local_adapt(i, replay_query_handles[i-1], replay_memory[i-1], network_input, network_input_keys, network_input_shapes, network_input_sizes, loss_input_keys, gradient_steps[i-1], replay_k[i-1]);
}
}
......@@ -138,7 +138,7 @@ public:
}
<#if networkInstruction.body.episodicSubNetworks?has_content>
//perform local adaption, train network on examples, only use updated on one inference (locally), don't save them
//perform local adaptation: train the network on the retrieved examples, use the updated weights only for one inference (locally), don't save them
void local_adapt(int net_start_ind,
Executor * query_handle,
std::map<std::string, NDArray> &memory,
......
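Editor's note: the comment above summarizes the local-adaptation scheme: fine-tune a copy of the weights on a few retrieved neighbours, use the adapted weights for a single prediction, then discard them. A rough, self-contained Python sketch of that idea; the memory.sample_neighbours call and all names are hypothetical and not the generated code:

from mxnet import autograd, gluon, nd

def locally_adapted_predict(net, memory, query_net, x, loss_fn, steps=2, lr=0.01, k=4):
    base_params = {name: p.data().copy() for name, p in net.collect_params().items()}
    neigh_data, neigh_labels = memory.sample_neighbours([x], query_net)   # hypothetical memory API
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': lr})
    for _ in range(steps):                        # a few gradient steps on the neighbours only
        with autograd.record():
            loss = loss_fn(net(neigh_data[0]), neigh_labels[0])
        loss.backward()
        trainer.step(k)
    prediction = net(x)                           # single inference with the adapted weights
    for name, p in net.collect_params().items():  # restore the original weights afterwards
        p.set_data(base_params[name])
    return prediction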
......@@ -111,6 +111,37 @@ class SoftmaxCrossEntropyLossIgnoreLabel(gluon.loss.Loss):
loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
return F.sum(loss) / F.sum(valid_label_map)
class LocalAdaptationLoss(gluon.loss.Loss):
def __init__(self, lamb, axis=-1, sparse_label=True, weight=None, batch_axis=0, **kwargs):
super(LocalAdaptationLoss, self).__init__(weight, batch_axis, **kwargs)
self.lamb = lamb
self._axis = axis
self._sparse_label = sparse_label
def hybrid_forward(self, F, pred, label, curr_weights, base_weights, sample_weight=None):
pred = F.log(pred)
if self._sparse_label:
cross_entr_loss = -F.pick(pred, label, axis=self._axis, keepdims=True)
else:
label = gluon.loss._reshape_like(F, label, pred)
cross_entr_loss = -F.sum(pred * label, axis=self._axis, keepdims=True)
cross_entr_loss = F.mean(cross_entr_loss, axis=self._batch_axis, exclude=True)
weight_diff_loss = 0
for param_key in base_weights:
weight_diff_loss = F.add(weight_diff_loss, F.norm(curr_weights[param_key] - base_weights[param_key]))
#this check is necessary, otherwise if weight_diff_loss is zero (first training iteration)
#the trainer would update the network's weights to nan; this must have something to do with how
#mxnet internally calculates the derivatives / tracks the weights
if weight_diff_loss > 0:
loss = self.lamb * weight_diff_loss + cross_entr_loss
loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
else:
loss = gluon.loss._apply_weighting(F, cross_entr_loss, self._weight, sample_weight)
return loss
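Editor's note: LocalAdaptationLoss adds a penalty lamb * sum_p ||w_p - w_p_base|| on the distance between the adapted and the stored base weights to the usual cross-entropy. A minimal call sketch with toy tensors and a made-up parameter key:

from mxnet import nd

loss_fn = LocalAdaptationLoss(lamb=0.001)
pred = nd.softmax(nd.random.uniform(shape=(4, 10)))        # class probabilities for a batch of 4
label = nd.array([1, 3, 0, 7])
base_weights = {'fc0_weight': nd.zeros((10, 16))}          # weights stored before adaptation
curr_weights = {'fc0_weight': nd.ones((10, 16)) * 0.1}     # weights after a local gradient step
loss = loss_fn(pred, label, curr_weights, base_weights)    # per-sample CE plus weight penalty
print(loss.mean().asscalar())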
@mx.metric.register
class ACCURACY_IGNORE_LABEL(mx.metric.EvalMetric):
"""Ignores a label when computing accuracy.
......@@ -422,6 +453,8 @@ class ${tc.fileNameWithoutEnding}:
</#if>
</#list>
</#list>
# Episodic memory local adaptation
local_adaptation_loss_function = LocalAdaptationLoss(lamb=0.001)
</#if>
tic = None
......@@ -501,14 +534,49 @@ class ${tc.fileNameWithoutEnding}:
global_loss_train /= (train_batches * batch_size)
tic = None
<#assign containsUnrollNetwork = false>
<#assign anyEpisodicLocalAdaptation = false>
<#list tc.architecture.networkInstructions as networkInstruction>
<#if networkInstruction.isUnroll()>
<#assign containsUnrollNetwork = true>
</#if>
<#if networkInstruction.body.anyEpisodicLocalAdaptation>
<#assign anyEpisodicLocalAdaptation = true>
</#if>
</#list>
<#if episodicReplayVisited?? && anyEpisodicLocalAdaptation>
params = {}
for key in self._networks:
paramDict = self._networks[key].collect_params()
params[key] = {}
for param in paramDict:
params[key][param] = paramDict[param].data(ctx=mx_context[0]).copy()
</#if>
if eval_train:
train_iter.batch_size = single_pu_batch_size
train_iter.reset()
metric = mx.metric.create(eval_metric, **eval_metric_params)
for batch_i, batch in enumerate(train_iter):
<#if episodicReplayVisited?? && anyEpisodicLocalAdaptation && !containsUnrollNetwork>
<#include "pythonExecuteTest.ftl">
predictions = []
for output_name in outputs:
if mx.nd.shape_array(mx.nd.squeeze(output_name)).size > 1:
predictions.append(mx.nd.argmax(output_name, axis=1))
else:
predictions.append(output_name)
metric.update(preds=predictions, labels=[labels[j][local_adaptation_batch_i] for j in range(len(labels))])
<#list tc.architecture.networkInstructions as networkInstruction>
self._networks[${networkInstruction?index}].collect_params().load_dict(params[${networkInstruction?index}], ctx=mx_context[0])
</#list>
<#else>
<#include "pythonExecuteTest.ftl">
<#include "saveAttentionImageTrain.ftl">
predictions = []
......@@ -519,6 +587,7 @@ class ${tc.fileNameWithoutEnding}:
predictions.append(output_name)
metric.update(preds=predictions, labels=[labels[j] for j in range(len(labels))])
</#if>
train_metric_score = metric.get()[1]
else:
......@@ -526,13 +595,37 @@ class ${tc.fileNameWithoutEnding}:
global_loss_test = 0.0
test_batches = 0
test_iter.batch_size = single_pu_batch_size
test_iter.reset()
metric = mx.metric.create(eval_metric, **eval_metric_params)
for batch_i, batch in enumerate(test_iter):
if True: <#-- Fix indentation -->
<#if episodicReplayVisited?? && anyEpisodicLocalAdaptation && !containsUnrollNetwork>
<#include "pythonExecuteTest.ftl">
loss = 0
for element in lossList:
loss = loss + element
global_loss_test += loss.sum().asscalar()
test_batches += 1
predictions = []
for output_name in outputs:
if mx.nd.shape_array(mx.nd.squeeze(output_name)).size > 1:
predictions.append(mx.nd.argmax(output_name, axis=1))
else:
predictions.append(output_name)
metric.update(preds=predictions, labels=[labels[j][local_adaptation_batch_i] for j in range(len(labels))])
<#list tc.architecture.networkInstructions as networkInstruction>
self._networks[${networkInstruction?index}].collect_params().load_dict(params[${networkInstruction?index}], ctx=mx_context[0])
</#list>
global_loss_test /= (test_batches)
<#else>
<#include "pythonExecuteTest.ftl">
<#include "saveAttentionImageTest.ftl">
......@@ -546,13 +639,17 @@ class ${tc.fileNameWithoutEnding}:
predictions = []
for output_name in outputs:
predictions.append(output_name)
if mx.nd.shape_array(mx.nd.squeeze(output_name)).size > 1:
predictions.append(mx.nd.argmax(output_name, axis=1))
else:
predictions.append(output_name)
metric.update(preds=predictions, labels=[labels[j] for j in range(len(labels))])
test_metric_score = metric.get()[1]
global_loss_test /= (test_batches * single_pu_batch_size)
</#if>
test_metric_score = metric.get()[1]
logging.info("Epoch[%d] Train metric: %f, Test metric: %f, Train loss: %f, Test loss: %f" % (epoch, train_metric_score, test_metric_score, global_loss_train, global_loss_test))
......
......@@ -9,15 +9,12 @@
<#assign useProjBias = element.useProjBias?string("True", "False")>
<#assign useMask = element.useMask?string("True", "False")>
<#if mode == "ARCHITECTURE_DEFINITION">
<#assign dimModel = 1>
<#list element.element.outputTypes[0].dimensions as dim>
<#assign dimModel = dimModel * dim>
</#list>
<#assign dimModel = element.element.outputTypes[0].dimensions?reverse[1]>
self.${element.name} = DotProductSelfAttention(scale_factor=${scaleFactor}, num_heads=${numHeads}, dim_model=${dimModel}, dim_keys=${dimKeys}, dim_values=${dimValues}, use_proj_bias=${useProjBias}, use_mask=${useMask})
<#include "OutputShape.ftl">
<#elseif mode == "FORWARD_FUNCTION">
<#if (element.inputs[3])??>
${element.name} = self.${element.name}(${inputQueries}, ${inputKeys}, ${inputValues}, element.inputs[3])
${element.name} = self.${element.name}(${inputQueries}, ${inputKeys}, ${inputValues}, ${element.inputs[3]})
<#else>
${element.name} = self.${element.name}(${inputQueries}, ${inputKeys}, ${inputValues})
</#if>
......
......@@ -3,33 +3,35 @@
<#assign replayBatchSize = element.replayBatchSize?c>
<#assign replaySteps = element.replaySteps?c>
<#assign replayGradientSteps = element.replayGradientSteps?c>
<#assign replayMemoryStoreProb = element.replayMemoryStoreProb?c>
<#assign memoryStoreProb = element.memoryStoreProb?c>
<#assign maxStoredSamples = element.maxStoredSamples?c>
<#assign memoryReplacementStrategy = element.memoryReplacementStrategy>
<#assign useReplay = element.useReplay?string("True", "False")>
<#assign useLocalAdaption = element.useLocalAdaption?string("true", "false")>
<#assign localAdaptionK = element.localAdaptionK?c>
<#assign localAdaptionGradientSteps = element.localAdaptionGradientSteps?c>
<#assign useLocalAdaptationPy = element.useLocalAdaptation?string("True", "False")>
<#assign useLocalAdaptationCpp = element.useLocalAdaptation?string("true", "false")>
<#assign localAdaptationK = element.localAdaptationK?c>
<#assign localAdaptationGradientSteps = element.localAdaptationGradientSteps?c>
<#assign queryNetDir = element.queryNetDir>
<#assign queryNetPrefix = element.queryNetPrefix>
<#assign queryNetNumInputs = element.queryNetNumInputs>
<#if mode == "ARCHITECTURE_DEFINITION">
self.${element.name} = EpisodicMemory(replay_interval=${replayInterval}, replay_batch_size=${replayBatchSize}, replay_steps=${replaySteps},
replay_gradient_steps=${replayGradientSteps}, store_prob=${replayMemoryStoreProb},
replay_gradient_steps=${replayGradientSteps}, store_prob=${memoryStoreProb},
max_stored_samples=${maxStoredSamples}, memory_replacement_strategy="${memoryReplacementStrategy}", use_replay=${useReplay},
use_local_adaptation=${useLocalAdaptationPy}, local_adaptation_gradient_steps=${localAdaptationGradientSteps}, k=${localAdaptationK},
query_net_dir="${queryNetDir}/",
query_net_prefix="${queryNetPrefix}",
query_net_num_inputs=${queryNetNumInputs})
<#elseif mode == "FORWARD_FUNCTION">
<#if element.inputs?size == 1>
<#if useReplay == "True" || useLocalAdaption == "true">
<#if useReplay == "True" || useLocalAdaptationCpp == "true">
${element.name}full_, ind_${element.name} = self.${element.name}(*args)
<#else>
${element.name}full_, ind_${element.name} = self.${element.name}(${tc.join(element.inputs, ",")})
</#if>
${element.name} = ${element.name}full_[0]
<#else>
<#if useReplay == "True" || useLocalAdaption == "true">
<#if useReplay == "True" || useLocalAdaptationCpp == "true">
${element.name}full_, ind_${element.name} = self.${element.name}(*args)
<#else>
${element.name}full_, ind_${element.name} = self.${element.name}(${tc.join(element.inputs, ",")})
......@@ -37,9 +39,9 @@
${element.name} = ${element.name}full_
</#if>
<#elseif mode == "PREDICTION_PARAMETER">
use_local_adaption.push_back(${useLocalAdaption});
replay_k.push_back(${localAdaptionK});
gradient_steps.push_back(${localAdaptionGradientSteps});
use_local_adaptation.push_back(${useLocalAdaptationCpp});
replay_k.push_back(${localAdaptationK});
gradient_steps.push_back(${localAdaptationGradientSteps});
query_num_inputs.push_back(${queryNetNumInputs});
</#if>
\ No newline at end of file
......@@ -3,14 +3,12 @@
<#assign subKeySize = element.subKeySize?c>
<#assign querySize = "[" + tc.join(element.querySize, ",") + "]">
<#assign queryAct = element.queryAct>
<#assign storeDistMeasure = element.storeDistMeasure>
<#assign k = element.k?c>
<#assign numHeads = element.numHeads?c>
<#assign valuesDim = element.valuesDim?c>
<#if mode == "ARCHITECTURE_DEFINITION">
self.${element.name} = LargeMemory(sub_key_size=${subKeySize}, query_size=${querySize}, query_act="${queryAct}",
dist_measure="${storeDistMeasure}", k=${k}, num_heads=${numHeads},
values_dim=${valuesDim})
k=${k}, num_heads=${numHeads}, values_dim=${valuesDim})
<#elseif mode == "FORWARD_FUNCTION">
${element.name} = self.${element.name}(${input})
</#if>
\ No newline at end of file
<#-- (c) https://github.com/MontiCore/monticore -->
<#assign input = element.inputs[0]>
<#if mode == "ARCHITECTURE_DEFINITION">
self.${element.name} = gluon.nn.LayerNorm()
<#include "OutputShape.ftl">
<#elseif mode == "FORWARD_FUNCTION">
${element.name} = self.${element.name}(${input})
</#if>
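Editor's note: the new template wires the LAYERNORM layer to gluon.nn.LayerNorm. Roughly what the generated Python ends up doing, with an illustrative member name rather than real generator output:

from mxnet import gluon, nd

layernorm1_ = gluon.nn.LayerNorm()      # default: normalize over the last axis
layernorm1_.initialize()
x = nd.random.uniform(shape=(2, 3, 8))
y = layernorm1_(x)                      # same shape; zero mean / unit variance per position on the last axis
print(y.shape)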
......@@ -33,7 +33,7 @@
self.${element.name}out_shape = self.${element.name}(*zeroInputs).shape
if self.${element.name}out_shape != (1,${outputShape}):
outputSize=1
for x in (${outputShape}):
for x in (${outputShape},):
outputSize = outputSize * x
self.${element.name}fc_ = gluon.nn.Dense(units=outputSize, use_bias=False, flatten=False)
......@@ -42,7 +42,5 @@
if self.${element.name}out_shape != (1,${outputShape}):
${element.name} = self.${element.name}fc_(${element.name})
${element.name} = F.reshape(${element.name}, shape=(-1,${outputShape}))
<#elseif mode == "PREDICTION_PARAMETER">
query_num_inputs.push_back(${numInputs});
</#if>
\ No newline at end of file
labels = [gluon.utils.split_and_load(batch.label[i], ctx_list=mx_context, even_split=False)[0] for i in range(${tc.architectureOutputs?size?c})]
labels = [batch.label[i].as_in_context(mx_context[0]) for i in range(${tc.architectureOutputs?size?c})]
<#list tc.architectureInputs as input_name>
${input_name} = gluon.utils.split_and_load(batch.data[${input_name?index}], ctx_list=mx_context, even_split=False)[0]
${input_name} = batch.data[${input_name?index}].as_in_context(mx_context[0])
</#list>
<#if tc.architectureOutputSymbols?size gt 1>
......@@ -104,20 +104,85 @@
</#if>
</#if>
</#list>
<#else>
net_ret = self._networks[${networkInstruction?index}](${tc.join(tc.getStreamInputNames(networkInstruction.body, true), ", ")})
<#elseif networkInstruction.body.anyEpisodicLocalAdaptation>
for layer_i, layer in enumerate(episodic_layers[${networkInstruction?index}]):
if layer.use_local_adaptation:
local_adaptation_output = self._networks[${networkInstruction?index}].episodicsubnet0_(${tc.join(tc.getStreamInputNames(networkInstruction.body, true), ", ")})[0]
for i in range(1, layer_i):
local_adaptation_output = self._networks[${networkInstruction?index}].episodic_sub_nets[i](*local_adaptation_output)[0]
local_adaptation_batch = layer.sample_neighbours(local_adaptation_output, episodic_query_networks[${networkInstruction?index}][layer_i])
local_adaptation_data = {}
local_adaptation_labels = {}
local_adaptation_data[layer_i] = [[local_adaptation_batch[0][i][j].as_in_context(mx_context[0]) for i in range(len(local_adaptation_batch[0]))] for j in range(single_pu_batch_size)]
local_adaptation_labels[layer_i] = [[local_adaptation_batch[1][i][j].as_in_context(mx_context[0]) for i in range(len(local_adaptation_batch[1]))] for j in range(single_pu_batch_size)]
for local_adaptation_batch_i in range(single_pu_batch_size):
self._networks[${networkInstruction?index}].collect_params().load_dict(params[${networkInstruction?index}], ctx=mx_context[0])
if len(self._networks[${networkInstruction?index}].collect_params().values()) != 0:
if optimizer == "adamw":
local_adaptation_trainer = mx.gluon.Trainer(self._networks[${networkInstruction?index}].collect_params(), AdamW.AdamW(**optimizer_params))
else:
local_adaptation_trainer = mx.gluon.Trainer(self._networks[${networkInstruction?index}].collect_params(), optimizer, optimizer_params)
for layer_i, layer in enumerate(episodic_layers[${networkInstruction?index}]):
if layer.use_local_adaptation:
for gradient_step in range(layer.local_adaptation_gradient_steps):
with autograd.record():
local_adaptation_output = self._networks[${networkInstruction?index}].episodic_sub_nets[layer_i](*(local_adaptation_data[layer_i][local_adaptation_batch_i]))[0]
for i in range(layer_i+1, len(episodic_layers[${networkInstruction?index}])):
local_adaptation_output = self._networks[${networkInstruction?index}].episodic_sub_nets[i](*local_adaptation_output)[0]
curr_param_dict = self._networks[${networkInstruction?index}].collect_params()
curr_params = {}
for param in curr_param_dict:
curr_params[param] = curr_param_dict[param].data()
local_adaptation_loss_list = []
<#list tc.getStreamOutputNames(networkInstruction.body, true) as outputName>
${outputName} = net_ret[0][${outputName?index}]
<#if tc.getNameWithoutIndex(outputName) == tc.outputName>
local_adaptation_loss_list.append(local_adaptation_loss_function(local_adaptation_output[${outputName?index}], local_adaptation_labels[layer_i][local_adaptation_batch_i][${outputName?index}], curr_params, params[${networkInstruction?index}]))
</#if>
</#list>
loss = 0
for element in local_adaptation_loss_list:
loss = loss + element
loss.backward()
if clip_global_grad_norm:
grads = []
for network in self._networks.values():
grads.extend([param.grad(mx_context) for param in network.collect_params().values()])
gluon.utils.clip_global_norm(grads, clip_global_grad_norm)
local_adaptation_trainer.step(layer.k)
outputs = []
lossList = []
net_ret = self._networks[${networkInstruction?index}](${tc.join(tc.getStreamInputNames(networkInstruction.body, true), ".take(nd.array([local_adaptation_batch_i], ctx=mx_context[0])), ")}.take(nd.array([local_adaptation_batch_i], ctx=mx_context[0])))
<#list tc.getStreamOutputNames(networkInstruction.body, true) as outputName>
${outputName} = net_ret[0][${outputName?index}]
<#if tc.getNameWithoutIndex(outputName) == tc.outputName>
outputs.append(${outputName})
lossList.append(loss_function(${outputName}, labels[${tc.getIndex(outputName, true)}]))