Commit 22fa37ac authored by Julian Johannes Steinsberger-Dührßen

bug fixes, added tests for EpisodicMemory and LoadNetwork, increased version number

parent ed639463
Pipeline #323329 failed with stage in 15 seconds
......@@ -9,7 +9,7 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-gluon-generator</artifactId>
<version>0.2.11-SNAPSHOT</version>
<version>0.2.12-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
......@@ -17,9 +17,9 @@
<!-- .. SE-Libraries .................................................. -->
<CNNArch.version>0.3.5-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.3.11-SNAPSHOT</CNNTrain.version>
<CNNArch2X.version>0.0.6-SNAPSHOT</CNNArch2X.version>
<CNNArch.version>0.3.7-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.3.12-SNAPSHOT</CNNTrain.version>
<CNNArch2X.version>0.0.7-SNAPSHOT</CNNArch2X.version>
<embedded-montiarc-math-opt-generator>0.1.6</embedded-montiarc-math-opt-generator>
<EMADL2PythonWrapper.version>0.0.2-SNAPSHOT</EMADL2PythonWrapper.version>
......
......@@ -135,15 +135,17 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
fileContentMap.put("CNNLAOptimizer_" + getInstanceName() + ".h", cnnTrainLAOptimizerTemplateContent);
//Generate the AdamW optimizer implementation if it is used for training
String optimizerName = configuration.getOptimizer().getName();
Optional<OptimizerSymbol> criticOptimizer = configuration.getCriticOptimizer();
String criticOptimizerName = "";
if (criticOptimizer.isPresent()){
criticOptimizerName = criticOptimizer.get().getName();
}
if (optimizerName.equals("adamw") || criticOptimizerName.equals("adamw")){
String adamWContent = templateConfiguration.processTemplate(ftlContext, "Optimizer/AdamW.ftl");
fileContentMap.put("AdamW.py", adamWContent);
if(configuration.getOptimizer() != null) {
String optimizerName = configuration.getOptimizer().getName();
Optional<OptimizerSymbol> criticOptimizer = configuration.getCriticOptimizer();
String criticOptimizerName = "";
if (criticOptimizer.isPresent()) {
criticOptimizerName = criticOptimizer.get().getName();
}
if (optimizerName.equals("adamw") || criticOptimizerName.equals("adamw")) {
String adamWContent = templateConfiguration.processTemplate(ftlContext, "Optimizer/AdamW.ftl");
fileContentMap.put("AdamW.py", adamWContent);
}
}
if (configData.isSupervisedLearning()) {
......
......@@ -4,6 +4,7 @@ import logging
import os
import shutil
import warnings
import inspect
<#list tc.architecture.networkInstructions as networkInstruction>
from CNNNet_${tc.fullArchitectureName} import Net_${networkInstruction?index}
......@@ -28,6 +29,10 @@ class ${tc.fileNameWithoutEnding}:
for i, network in self.networks.items():
lastEpoch = 0
param_file = None
if hasattr(network, 'episodic_sub_nets'):
num_episodic_sub_nets = len(network.episodic_sub_nets)
lastMemEpoch = [0]*num_episodic_sub_nets
mem_files = [None]*num_episodic_sub_nets
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest-0000.params")
......@@ -38,22 +43,77 @@ class ${tc.fileNameWithoutEnding}:
except OSError:
pass
if hasattr(network, 'episodic_sub_nets'):
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-0000.params")
except OSError:
pass
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(0) + "-symbol.json")
except OSError:
pass
for j in range(len(network.episodic_sub_nets)):
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-0000.params")
except OSError:
pass
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_sub_net_' + str(j+1) + "-symbol.json")
except OSError:
pass
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-0000.params")
except OSError:
pass
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_episodic_query_net_' + str(j+1) + "-symbol.json")
except OSError:
pass
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-0000.params")
except OSError:
pass
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + '_newest_loss' + "-symbol.json")
except OSError:
pass
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest_episodic_memory_sub_net_" + str(j + 1) + "-0000")
except OSError:
pass
if os.path.isdir(self._model_dir_):
for file in os.listdir(self._model_dir_):
if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file:
epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","")
epochStr = file.replace(".params", "").replace(self._model_prefix_ + "_" + str(i) + "-", "")
epoch = int(epochStr)
if epoch > lastEpoch:
if epoch >= lastEpoch:
lastEpoch = epoch
param_file = file
elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file:
relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-")
memSubNet = int(relMemPathInfo[0])
memEpochStr = relMemPathInfo[1]
memEpoch = int(memEpochStr)
if memEpoch >= lastMemEpoch[memSubNet-1]:
lastMemEpoch[memSubNet-1] = memEpoch
mem_files[memSubNet-1] = file
if param_file is None:
earliestLastEpoch = 0
else:
logging.info("Loading checkpoint: " + param_file)
network.load_parameters(self._model_dir_ + param_file)
if hasattr(network, 'episodic_sub_nets'):
for j, sub_net in enumerate(network.episodic_sub_nets):
if mem_files[j] is not None:
logging.info("Loading Replay Memory: " + mem_files[j])
mem_layer = [param for param in inspect.getmembers(sub_net, lambda x: not(inspect.isroutine(x))) if param[0].startswith("memory")][0][1]
mem_layer.load_memory(self._model_dir_ + mem_files[j])
if earliestLastEpoch == None or lastEpoch < earliestLastEpoch:
earliestLastEpoch = lastEpoch
if earliestLastEpoch == None or lastEpoch + 1 < earliestLastEpoch:
earliestLastEpoch = lastEpoch + 1
return earliestLastEpoch
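For readers of the generated template: the loop above locates the replay-memory layer of each episodic sub-network by reflection, assuming that the non-routine members whose names start with "memory" expose load_memory/save_memory. A minimal standalone sketch of that lookup (the helper name is illustrative, not part of the template):

import inspect

def find_memory_layer(sub_net):
    # Return the first non-routine attribute of sub_net whose name starts with
    # "memory"; this mirrors the inspect.getmembers lookup used above.
    members = inspect.getmembers(sub_net, lambda x: not inspect.isroutine(x))
    memory_members = [m for m in members if m[0].startswith("memory")]
    return memory_members[0][1] if memory_members else None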
......@@ -64,6 +124,11 @@ class ${tc.fileNameWithoutEnding}:
for i, network in self.networks.items():
# param_file = self._model_prefix_ + "_" + str(i) + "_newest-0000.params"
param_file = None
if hasattr(network, 'episodic_sub_nets'):
num_episodic_sub_nets = len(network.episodic_sub_nets)
lastMemEpoch = [0] * num_episodic_sub_nets
mem_files = [None] * num_episodic_sub_nets
if os.path.isdir(self._weights_dir_):
lastEpoch = 0
......@@ -72,17 +137,35 @@ class ${tc.fileNameWithoutEnding}:
if ".params" in file and self._model_prefix_ + "_" + str(i) in file and not "loss" in file:
epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","")
epoch = int(epochStr)
if epoch > lastEpoch:
if epoch >= lastEpoch:
lastEpoch = epoch
param_file = file
elif hasattr(network, 'episodic_sub_nets') and self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_" in file:
relMemPathInfo = file.replace(self._model_prefix_ + "_" + str(i) + "_episodic_memory_sub_net_", "").split("-")
memSubNet = int(relMemPathInfo[0])
memEpochStr = relMemPathInfo[1]
memEpoch = int(memEpochStr)
if memEpoch >= lastMemEpoch[memSubNet-1]:
lastMemEpoch[memSubNet-1] = memEpoch
mem_files[memSubNet-1] = file
logging.info("Loading pretrained weights: " + self._weights_dir_ + param_file)
network.load_parameters(self._weights_dir_ + param_file, allow_missing=True, ignore_extra=True)
if hasattr(network, 'episodic_sub_nets'):
assert all(lastEpoch == mem_epoch for mem_epoch in lastMemEpoch)
for j, sub_net in enumerate(network.episodic_sub_nets):
if mem_files[j] is not None:
logging.info("Loading pretrained Replay Memory: " + mem_files[j])
mem_layer = \
[param for param in inspect.getmembers(sub_net, lambda x: not (inspect.isroutine(x))) if
param[0].startswith("memory")][0][1]
mem_layer.load_memory(self._weights_dir_ + mem_files[j])
else:
logging.info("No pretrained weights available at: " + self._weights_dir_ + param_file)
def construct(self, context, data_mean=None, data_std=None):
<#list tc.architecture.networkInstructions as networkInstruction>
self.networks[${networkInstruction?index}] = Net_${networkInstruction?index}(data_mean=data_mean, data_std=data_std, mx_context=context)
self.networks[${networkInstruction?index}] = Net_${networkInstruction?index}(data_mean=data_mean, data_std=data_std, mx_context=context, prefix="")
with warnings.catch_warnings():
warnings.simplefilter("ignore")
self.networks[${networkInstruction?index}].collect_params().initialize(self.weight_initializer, force_reinit=False, ctx=context)
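The added prefix="" presumably keeps Gluon from prepending an auto-incremented network prefix to the parameter names, so the names under which checkpoints are saved and reloaded stay stable across instantiations. A minimal sketch of that naming behaviour, assuming MXNet Gluon 1.x (names shown are examples):

import mxnet as mx
from mxnet import gluon

# With the default prefix, parameters are named "hybridsequential0_dense0_weight",
# "hybridsequential1_dense0_weight", ... depending on how many blocks were created;
# with prefix="" the names depend only on the child layers.
net = gluon.nn.HybridSequential(prefix="")
with net.name_scope():
    net.add(gluon.nn.Dense(10))
print(list(net.collect_params().keys()))  # e.g. ['dense0_weight', 'dense0_bias']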
......
......@@ -24,7 +24,11 @@ public:
explicit CNNLAOptimizer_${config.instanceName}(){
<#if (config.configuration.optimizer)??>
<#if config.optimizerName == "adamw">
optimizerHandle = OptimizerRegistry::Find("adam");
<#else>
optimizerHandle = OptimizerRegistry::Find("${config.optimizerName}");
</#if>
<#list config.optimizerParams?keys as param>
<#if param == "learning_rate">
optimizerHandle->SetParam("lr", ${config.optimizerParams[param]});
......@@ -40,10 +44,10 @@ public:
optimizerHandle->SetParam("${param}", ${config.optimizerParams[param]});
</#if>
</#list>
<#if (learningRateDecay)??>
<#if (learningRateDecay)?? && (stepSize)??>
<#if !(minLearningRate)??>
<#assign minLearningRate = "1e-08">
</#if>
</#if>
std::unique_ptr<LRScheduler> lrScheduler(new FactorScheduler(${stepSize}, ${learningRateDecay}, ${minLearningRate}));
optimizerHandle->SetLRScheduler(std::move(lrScheduler));
</#if>
......
......@@ -90,7 +90,7 @@ public:
query_param_path = file_prefix + "_episodic_query_net_" + std::to_string(i) + "-0000.params";
loadComponent(query_json_path, query_param_path, query_symbol_list, query_param_map_list);
memory_path = file_prefix + "_replay_memory_" + std::to_string(i);
memory_path = file_prefix + "_episodic_memory_sub_net_" + std::to_string(i) + "-0000";
checkFile(memory_path);
std::map<std::string, NDArray> mem_map = NDArray::LoadToMap(memory_path);
......
......@@ -172,7 +172,7 @@ class EpisodicReplayMemoryInterface(gluon.HybridBlock):
def __init__(self, use_replay, replay_interval, replay_batch_size, replay_steps, replay_gradient_steps, num_heads, **kwargs):
super(EpisodicReplayMemoryInterface, self).__init__(**kwargs)
self.use_replay = use_replay
self.replay_interval = replay_interval
self.replay_batch_size = replay_batch_size
......@@ -191,10 +191,14 @@ class EpisodicReplayMemoryInterface(gluon.HybridBlock):
@abc.abstractmethod
def get_query_network(self, mx_context):
pass
@abc.abstractmethod
def save_memory(self, path):
pass
@abc.abstractmethod
def load_memory(self, path):
pass
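The two new abstract methods define the persistence contract that every replay-memory layer has to implement. A purely illustrative subclass sketch (constructor arguments are placeholders, and the interface's other abstract methods are omitted for brevity):

from mxnet import nd

class MinimalMemory(EpisodicReplayMemoryInterface):
    def __init__(self, **kwargs):
        # use_replay, replay_interval, replay_batch_size, replay_steps,
        # replay_gradient_steps, num_heads - placeholder values only
        super(MinimalMemory, self).__init__(False, 1, 1, 1, 1, 1, **kwargs)
        self.key_memory = nd.empty(0)

    def save_memory(self, path):
        nd.save(path, {"keys": self.key_memory})

    def load_memory(self, path):
        self.key_memory = nd.load(path)["keys"]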
#Memory layer
class LargeMemory(gluon.HybridBlock):
......@@ -386,9 +390,9 @@ class EpisodicMemory(EpisodicReplayMemoryInterface):
mx_context = context[0]
if len(self.key_memory) == 0:
self.key_memory = nd.empty(0, ctx=mx_context)
self.value_memory = nd.empty(0, ctx=mx_context)
self.label_memory = nd.empty((num_outputs, 0), ctx=mx_context)
self.key_memory = nd.empty(0, ctx=mx.cpu())
self.value_memory = []
self.label_memory = []#nd.empty((num_outputs, 0), ctx=mx.cpu())
ind = [nd.sample_multinomial(store_prob, sub_batch_sizes[i]).as_in_context(mx_context) for i in range(num_pus)]
......@@ -399,15 +403,13 @@ class EpisodicMemory(EpisodicReplayMemoryInterface):
tmp_values = []
for j in range(0, num_pus):
if max_inds[j]:
#tmp = data[j][0][i].as_in_context(mx_context)
if isinstance(tmp_values, list):
tmp_values = nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j])
else:
tmp_values = nd.concat(tmp_values, nd.contrib.boolean_mask(data[j][0][i].as_in_context(mx_context), ind[j]), dim=0)
to_store_values.append(tmp_values)
to_store_labels = nd.empty((num_outputs, len(to_store_values[0])), ctx=mx_context)
to_store_labels = []
for i in range(num_outputs):
tmp_labels = []
for j in range(0, num_pus):
......@@ -416,26 +418,29 @@ class EpisodicMemory(EpisodicReplayMemoryInterface):
tmp_labels = nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j])
else:
tmp_labels = nd.concat(tmp_labels, nd.contrib.boolean_mask(y[i][j].as_in_context(mx_context), ind[j]), dim=0)
to_store_labels[i] = tmp_labels
to_store_labels.append(tmp_labels)
to_store_keys = query_network(*to_store_values[0:self.query_net_num_inputs])
if self.key_memory.shape[0] == 0:
self.key_memory = to_store_keys
self.value_memory = to_store_values
self.label_memory = to_store_labels
self.key_memory = to_store_keys.as_in_context(mx.cpu())
for i in range(num_inputs):
self.value_memory.append(to_store_values[i].as_in_context(mx.cpu()))
for i in range(num_outputs):
self.label_memory.append(to_store_labels[i].as_in_context(mx.cpu()))
elif self.max_stored_samples != -1 and self.key_memory.shape[0] >= self.max_stored_samples:
num_to_store = to_store_keys.shape[0]
self.key_memory = nd.concat(self.key_memory[num_to_store:], to_store_keys, dim=0)
for i in range(len(to_store_values)):
self.value_memory[i] = nd.concat(self.value_memory[i][num_to_store:], to_store_values[i], dim=0)
self.label_memory = nd.slice_axis(self.label_memory, axis=1, begin=num_to_store, end=-1)
self.label_memory = nd.concat(self.label_memory, to_store_labels, dim=1)
self.key_memory = nd.concat(self.key_memory[num_to_store:], to_store_keys.as_in_context(mx.cpu()), dim=0)
for i in range(num_inputs):
self.value_memory[i] = nd.concat(self.value_memory[i][num_to_store:], to_store_values[i].as_in_context(mx.cpu()), dim=0)
for i in range(num_outputs):
self.label_memory[i] = nd.concat(self.label_memory[i][num_to_store:], to_store_labels[i].as_in_context(mx.cpu()), dim=1)
else:
self.key_memory = nd.concat(self.key_memory, to_store_keys, dim=0)
for i in range(len(to_store_values)):
self.value_memory[i] = nd.concat(self.value_memory[i], to_store_values[i], dim=0)
self.label_memory = nd.concat(self.label_memory, to_store_labels, dim=1)
self.key_memory = nd.concat(self.key_memory, to_store_keys.as_in_context(mx.cpu()), dim=0)
for i in range(num_inputs):
self.value_memory[i] = nd.concat(self.value_memory[i], to_store_values[i].as_in_context(mx.cpu()), dim=0)
for i in range(num_outputs):
self.label_memory[i] = nd.concat(self.label_memory[i], to_store_labels[i].as_in_context(mx.cpu()), dim=0)
def sample_memory(self, batch_size):
num_stored_samples = self.key_memory.shape[0]
......@@ -477,10 +482,22 @@ class EpisodicMemory(EpisodicReplayMemoryInterface):
return net
def save_memory(self, path):
mem_arr = [("keys", self.key_memory)] + [("labels", self.label_memory)] + [("values_"+str(k),v) for (k,v) in enumerate(self.value_memory)]
mem_arr = [("keys", self.key_memory)] + [("values_"+str(k),v) for (k,v) in enumerate(self.value_memory)] + [("labels_"+str(k),v) for (k,v) in enumerate(self.label_memory)]
mem_dict = {entry[0]:entry[1] for entry in mem_arr}
nd.save(path, mem_dict)
def load_memory(self, path):
mem_dict = nd.load(path)
self.value_memory = []
self.label_memory = []
for key in sorted(mem_dict.keys()):
if key == "keys":
self.key_memory = mem_dict[key]
elif key.startswith("values_"):
self.value_memory.append(mem_dict[key])
elif key.startswith("labels_"):
self.label_memory.append(mem_dict[key])
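save_memory and load_memory round-trip the memory through a flat NDArray dict whose keys encode the role and index of each buffer ("keys", "values_<i>", "labels_<i>"), which is also the layout the C++ LoadNetwork code now reads. A small self-contained sketch of that on-disk layout (array contents and file name are dummies):

from mxnet import nd

mem_dict = {"keys": nd.zeros((4, 8)),
            "values_0": nd.zeros((4, 16)),
            "labels_0": nd.zeros((4, 1))}
nd.save("episodic_memory_example-0000", mem_dict)    # what save_memory writes
restored = nd.load("episodic_memory_example-0000")   # what load_memory reads back
print(sorted(restored.keys()))  # ['keys', 'labels_0', 'values_0']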
<#list tc.architecture.networkInstructions as networkInstruction>
#Stream ${networkInstruction?index}
......@@ -510,7 +527,7 @@ ${tc.include(networkInstruction.body, elements?index, "FORWARD_FUNCTION")}
retNames = [${tc.join(tc.getSubnetOutputNames(elements), ", ")}]
ret = []
for elem in retNames:
if len(elem) >= 2:
if isinstance(elem, list) and len(elem) >= 2:
for elem2 in elem:
ret.append(elem2)
else:
......
......@@ -52,7 +52,6 @@ public:
//parameters for local adaptation
std::vector<bool> use_local_adaption = {};
std::vector<std::string> dist_measure = {};
std::vector<mx_uint> replay_k = {};
std::vector<mx_uint> gradient_steps = {};
std::vector<mx_uint> query_num_inputs = {};
......@@ -94,7 +93,7 @@ public:
for(mx_uint i=1; i < num_subnets; i++){
if(use_local_adaption[i-1]){
local_adapt(i, replay_query_handles[i-1], replay_memory[i-1], network_input, network_input_keys, network_input_shapes, network_input_sizes, loss_input_keys, gradient_steps[i-1], dist_measure[i-1], replay_k[i-1]);
local_adapt(i, replay_query_handles[i-1], replay_memory[i-1], network_input, network_input_keys, network_input_shapes, network_input_sizes, loss_input_keys, gradient_steps[i-1], replay_k[i-1]);
}
}
</#if>
......@@ -131,7 +130,8 @@ public:
curr_output_shape = output[${output_name?index}].GetShape();
curr_output_size = 1;
for (mx_uint i : curr_output_shape) curr_output_size *= i;
assert(curr_output_size == out_${output_name}.size());
//Workaround for a bug in how the output arrays are initialized when there are multiple outputs
assert((curr_output_size == out_${output_name}.size()) || (curr_output_size == out_${output_name}[0]));
output[${output_name?index}].SyncCopyToCPU(&out_${output_name});
</#list>
......@@ -139,7 +139,7 @@ public:
<#if networkInstruction.body.episodicSubNetworks?has_content>
//Perform local adaptation: train the network on examples from the replay memory and use the updated weights for one inference only (locally), without saving them
void local_adapt(int net_start_ind,
void local_adapt(int net_start_ind,
Executor * query_handle,
std::map<std::string, NDArray> &memory,
const std::vector<std::vector<float>> &in_data_,
......@@ -148,7 +148,6 @@ public:
std::vector<mx_uint> &in_sizes,
const std::vector<std::string> &loss_keys,
mx_uint gradient_steps,
std::string dist_measure,
mx_uint k){
std::vector<NDArray> prev_output;
......@@ -176,9 +175,13 @@ public:
NDArray::WaitAll();
}
}
for(mx_uint i=0; i < query_num_inputs[net_start_ind-1]; i++){
prev_output[i].CopyTo(&(query_handle->arg_dict()["data" + std::to_string(i)]));
if(query_num_inputs[net_start_ind-1] == 1){
prev_output[0].CopyTo(&(query_handle->arg_dict()["data"]));
}else{
for(mx_uint i=0; i < query_num_inputs[net_start_ind-1]; i++){
prev_output[i].CopyTo(&(query_handle->arg_dict()["data" + std::to_string(i)]));
}
}
NDArray::WaitAll();
......@@ -186,12 +189,7 @@ public:
CheckMXNetError("Query net forward, local_adapt, replay layer " + std::to_string(net_start_ind-1));
NDArray query_output = query_handle->outputs[0];
std::vector<std::vector<NDArray>> samples = pick_samples(query_output, memory, k, dist_measure);
Operator slice("slice_axis");
slice.SetParam("axis", 1);
slice.SetInput("data", samples[1][0]);
NDArray labels;
std::vector<std::vector<NDArray>> samples = pick_samples(query_output, memory, k, num_sub_net_outputs[net_start_ind-1], output_shapes.size());
for(mx_uint i=0; i < gradient_steps; i++){
for(mx_uint j=0; j < k; j++){
......@@ -202,13 +200,20 @@ public:
samples[0][t].Slice(j,j+1).CopyTo(&(network_handles[net_start_ind]->arg_dict()["data" + std::to_string(t)]));
}
}
slice.SetParam("begin", j);
slice.SetParam("end", j+1);
slice.Invoke(labels);
std::vector<NDArray> labels;
for(mx_uint t=0; t < samples[1].size(); t++){
Operator slice("slice_axis");
slice.SetParam("axis", 0);
slice.SetInput("data", samples[1][t]);
slice.SetParam("begin", j);
slice.SetParam("end", j+1);
labels.push_back(slice.Invoke()[0]);
}
network_handles[net_start_ind]->Forward(true);
CheckMXNetError("Network forward, local_adapt, handle ind. " + std::to_string(net_start_ind));
for(int k=net_start_ind+1; k < network_handles.size(); k++){
prev_output = network_handles[k-1]->outputs;
if(num_sub_net_outputs[k-1] == 1){
......@@ -229,7 +234,7 @@ public:
for(size_t k=0; k < num_outputs; k++){
std::vector<NDArray> network_output = network_handles.back()->outputs;
network_output[k].CopyTo(&(loss_handles[k]->arg_dict()[loss_keys[0]]));
labels.Slice(k,k+1).CopyTo(&(loss_handles[k]->arg_dict()[loss_keys[1]]));
labels[k].CopyTo(&(loss_handles[k]->arg_dict()[loss_keys[1]]));
NDArray::WaitAll();
loss_handles[k]->Forward(true);
......@@ -245,7 +250,7 @@ public:
for(int k=network_handles.size()-1; k >= net_start_ind; k--){
network_handles[k]->Backward(last_grads);
CheckMXNetError("Network backward, local_adapt, handle ind. " + std::to_string(k));
last_grads = {};
if(num_sub_net_outputs[k-1] == 1){
last_grads.push_back(network_handles[k]->grad_dict()["data"]);
......@@ -257,9 +262,9 @@ public:
}
for(size_t k=net_start_ind; k < network_arg_names.size(); ++k) {
for(size_t k=0; k < network_arg_names[k].size(); k++){
if (network_arg_names[k][k].find("data") != 0) continue;
optimizerHandle->Update(k, network_handles[k]->arg_arrays[k], network_handles[k]->grad_arrays[k]);
for(size_t t=0; t < network_arg_names[k].size(); t++){
//if (network_arg_names[k][t].find("data") != 0) continue;
optimizerHandle->Update(t, network_handles[k]->arg_arrays[t], network_handles[k]->grad_arrays[t]);
}
}
NDArray::WaitAll();
......@@ -267,52 +272,44 @@ public:
}
}
NDArray innerProdDist(NDArray &vec1, NDArray &vec2){
Operator dot("dot");
dot.SetParam("transpose_b", true);
NDArray ret;
dot.SetInput("lhs", vec1);
dot.SetInput("rhs", vec2);
dot.Invoke(ret);
return ret;
}
NDArray l2Norm(NDArray &vec1, NDArray &vec2){
Operator br_diff("broadcast_sub");
Operator l2_norm("norm");
l2_norm.SetParam("axis", 1);
Operator elem_square("square");
Operator batch_sum("sum");
batch_sum.SetParam("exclude", 1);
batch_sum.SetParam("axis", 0);
Operator elem_sqrt("sqrt");
NDArray diff;
br_diff.SetInput("lhs", vec1);
br_diff.SetInput("rhs", vec2);
br_diff.Invoke(diff);
NDArray ret;
l2_norm.SetInput("data", diff);
l2_norm.Invoke(ret);
NDArray sq;
elem_square.SetInput("data", diff);
elem_square.Invoke(sq);
return ret;
NDArray sum;
batch_sum.SetInput("data", sq);
batch_sum.Invoke(sum);
NDArray sqrt;
elem_sqrt.SetInput("data", sum);
elem_sqrt.Invoke(sqrt);
return sqrt;
}
std::vector<std::vector<NDArray>> pick_samples(NDArray query_output, std::map<std::string, NDArray> memory, mx_uint k, std::string dist_measure){
std::vector<std::vector<NDArray>> pick_samples(NDArray query_output, std::map<std::string, NDArray> memory, mx_uint k, mx_uint num_values, mx_uint num_labels){
Operator top_k("topk");
top_k.SetParam("k", k);
top_k.SetParam("ret_typ", "indices");
top_k.SetParam("dtype", "float32");
NDArray dist;
if(dist_measure == "l2"){
dist = l2Norm(query_output, memory["keys"]);
}else if(dist_measure == "inner_prod"){
dist = innerProdDist(query_output, memory["keys"]);
}else{
throw std::invalid_argument("Provided value for dist_measure is not supported.");
}
dist = l2Norm(query_output, memory["keys"]);
NDArray indices;
top_k.SetInput("data", dist);
top_k.Invoke(indices);
......@@ -325,22 +322,22 @@ public:
Probably a bug in mxnet-cpp.
*/
std::vector<NDArray> vals;
for(int i=0; i < memory.size()-2; i++){
for(int i=0; i < num_values; i++){
Operator take_values("take");
take_values.SetInput("a", memory["values_" + std::to_string(i)]);
take_values.SetInput("indices", indices);
vals.push_back(take_values.Invoke()[0]);
}
ret.push_back(vals);
std::vector<NDArray> labs;
Operator take_labels("take");
take_labels.SetParam("axis", 1);
take_labels.SetInput("a", memory["labels"]);
take_labels.SetInput("indices", indices);
labs.push_back(take_labels.Invoke()[0]);
for(int i=0; i < num_labels; i++){
Operator take_labels("take");
take_labels.SetInput("a", memory["labels_" + std::to_string(i)]);
take_labels.SetInput("indices", indices);
labs.push_back(take_labels.Invoke()[0]);
}
ret.push_back(labs);
return ret;
}
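Because the operator plumbing above hides the actual computation, here is a short NumPy sketch (illustrative only) of what l2Norm and pick_samples do together: compute the row-wise L2 distance between the query output and every stored key, select k entries by that distance (argsort stands in for the topk operator), and gather the corresponding values and labels:

import numpy as np

def pick_samples_sketch(query_output, memory, k):
    # memory: dict of NumPy arrays with keys "keys", "values_<i>", "labels_<i>"
    diff = memory["keys"] - query_output                                  # broadcast_sub
    dist = np.sqrt(np.square(diff).sum(axis=tuple(range(1, diff.ndim))))  # square, sum, sqrt
    picked = np.argsort(dist)[:k]                                         # top-k selection
    values = [v[picked] for name, v in sorted(memory.items()) if name.startswith("values_")]
    labels = [v[picked] for name, v in sorted(memory.items()) if name.startswith("labels_")]
    return [values, labels]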
</#if>
......@@ -418,21 +415,23 @@ ${tc.include(networkInstruction.body, "PREDICTION_PARAMETER")}
replay_query_param_maps = model_loader.GetQueryParamMaps();
replay_memory = model_loader.GetReplayMemory();
std::vector<mx_uint> label_memory_shape = replay_memory[0]["labels"].GetShape();
std::vector<mx_uint> lab_shape;
if(label_memory_shape.size() == 2){
lab_shape = {1, 1};
}else{
lab_shape.push_back(1);
for(mx_uint i=2; i<label_memory_shape.size(); i++){
lab_shape.push_back(label_memory_shape[i]);
for(mx_uint i = 0; i < num_outputs; i++){
std::vector<mx_uint> label_memory_shape = replay_memory[0]["labels_" + std::to_string(i)].GetShape();
if(label_memory_shape.size() == 1){
lab_shape = {1};
}else{
lab_shape.push_back(1);
for(mx_uint j=1; j<label_memory_shape.size(); j++){
lab_shape.push_back(label_memory_shape[j]);
}
}
}
std::vector<std::vector<mx_uint>> label_shapes;
for(mx_uint i=0; i < num_outputs; i++){
label_shapes.push_back(lab_shape);