C++ part for multiple inputs to replay sub-nets; bug fixes

parent edb591ad
Pipeline #287905 failed in 1 minute and 20 seconds
@@ -367,7 +367,7 @@ class ReplayMemory(ReplayMemoryInterface):
     def hybrid_forward(self, F, *args):
         #propagate the input as the rest is only used for replay
-        return [args[0], []]
+        return [args, []]

     def store_samples(self, data, y, query_network, store_prob, mx_context):
         x = data[0]
@@ -459,7 +459,7 @@ ${tc.include(networkInstruction.body, elements?index, "ARCHITECTURE_DEFINITION")
     def hybrid_forward(self, F, ${tc.join(tc.getStreamInputNames(networkInstruction.body, false), ", ")}):
 ${tc.include(networkInstruction.body, elements?index, "FORWARD_FUNCTION")}
-        return [[${tc.getName(elements[elements?size-1])}]]
+        return [[${tc.join(tc.getSubnetOutputNames(elements), ", ")}]]
 <#else>
         super(ReplaySubNet_${elements?index}, self).__init__(**kwargs)
         with self.name_scope():
@@ -469,7 +469,7 @@ ${tc.include(networkInstruction.body, elements?index, "ARCHITECTURE_DEFINITION")
     def hybrid_forward(self, F, *args):
 ${tc.include(networkInstruction.body, elements?index, "FORWARD_FUNCTION")}
-        return [[${tc.join(tc.getSubnetOutputNames(elements), ", ")}], [${tc.join(tc.getSubnetInputNames(elements), ", ")}, ind_${tc.join(tc.getSubnetInputNames(elements), ", ")}]]
+        return [[${tc.join(tc.getSubnetOutputNames(elements), ", ")}], [${tc.getSubnetInputNames(elements)[0]}full_, ind_${tc.join(tc.getSubnetInputNames(elements), ", ")}]]
 </#if>
 </#list>
@@ -502,7 +502,7 @@ ${tc.include(networkInstruction.body, "ARCHITECTURE_DEFINITION")}
 <#if elements?index == 0>
         replaysubnet${elements?index}_ = self.replaysubnet${elements?index}_(${tc.join(tc.getStreamInputNames(networkInstruction.body, false), ", ")})
 <#else>
-        replaysubnet${elements?index}_ = self.replay_sub_nets[${elements?index-1}](replaysubnet${elements?index - 1}_[0][0])
+        replaysubnet${elements?index}_ = self.replay_sub_nets[${elements?index-1}](*replaysubnet${elements?index - 1}_[0])
 </#if>
 </#list>
         return [replaysubnet${networkInstruction.body.replaySubNetworks?size - 1}_[0], [<#list networkInstruction.body.replaySubNetworks as elements><#if elements?index != 0>replaysubnet${elements?index}_[1], </#if></#list>]]
@@ -33,10 +33,6 @@ public:
 <#if networkInstruction.body.replaySubNetworks?has_content>
     //replay_memory
     //replay_sub_nets
-    const std::vector<std::string> sub_network_input_key = { //currently the same for all replay subnetworks
-        "data"
-    };
     //loss
     const std::vector<std::string> loss_input_keys = {
         "data0",
@@ -45,6 +41,7 @@ public:
     mx_uint num_outputs = ${tc.getStreamOutputNames(networkInstruction.body, false)?size};
     const std::vector<std::vector<mx_uint>> output_shapes = {<#list tc.getStreamOutputDimensions(networkInstruction.body) as dimensions>{1, ${tc.join(tc.cutDimensions(dimensions), ", ")}}<#sep>, </#list>};
+    std::vector<mx_uint> num_sub_net_outputs = {<#list networkInstruction.body.replaySubNetworks as elements> ${tc.getSubnetOutputNames(elements)?size}<#sep>, </#list>};
     std::vector<Executor *> loss_handles;
     //replay query nets
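With the hard-coded `sub_network_input_key` gone, the generated predictor instead needs to know how many tensors each replay sub-network hands to its successor; that is what the new `num_sub_net_outputs` vector encodes. For a hypothetical architecture whose two replay sub-nets expose one and two outputs respectively, the template line above would expand to something like the following (the element values are illustrative only, not taken from this commit):

```cpp
#include <vector>
#include <mxnet/c_api.h>  // for mx_uint

// Hypothetical expansion for two replay sub-nets with one and
// two outputs, respectively:
std::vector<mx_uint> num_sub_net_outputs = {1, 2};
```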
@@ -116,13 +113,14 @@ public:
 <#if networkInstruction.body.replaySubNetworks?has_content>
     for(int i=1; i < network_handles.size(); i++){
-        NDArray prev_output = network_handles[i-1]->outputs[0];
-        prev_output.CopyTo(&(network_handles[i]->arg_dict()[sub_network_input_key[0]]));
-        NDArray::WaitAll();
+        std::vector<NDArray> prev_output = network_handles[i-1]->outputs;
+        if(num_sub_net_outputs[i-1] == 1){
+            prev_output[0].CopyTo(&(network_handles[i]->arg_dict()["data"]));
+        }else{
+            for(int j = 0; j<num_sub_net_outputs[i-1]; j++){
+                prev_output[j].CopyTo(&(network_handles[i]->arg_dict()["data" + std::to_string(j)]));
+            }
+        }
+        NDArray::WaitAll();
         network_handles[i]->Forward(false);
         CheckMXNetError("Forward, predict, handle ind. " + std::to_string(i));
     }
 </#if>
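(The added `if` in this hunk originally compared with a single `=`; it must be `==`, as in the analogous hunks below.) The pattern introduced here — route every output of executor `i-1` into the `data`/`dataN` arguments of executor `i` — recurs throughout this file. Factored into a standalone helper, it might look like the following sketch; the function name `feed_next_subnet` is illustrative and not part of the generated code:

```cpp
#include <string>
#include <vector>
#include "mxnet-cpp/MxNetCpp.h"

using namespace mxnet::cpp;

// Copy all outputs of `prev` into the input arguments of `next`.
// A single-output sub-net feeds the plain "data" argument; a
// multi-output sub-net feeds "data0", "data1", ...
void feed_next_subnet(Executor *prev, Executor *next, mx_uint num_outputs) {
    const std::vector<NDArray> &prev_output = prev->outputs;
    if (num_outputs == 1) {
        prev_output[0].CopyTo(&(next->arg_dict()["data"]));
    } else {
        for (mx_uint j = 0; j < num_outputs; j++) {
            prev_output[j].CopyTo(&(next->arg_dict()["data" + std::to_string(j)]));
        }
    }
    // CopyTo is asynchronous; synchronize before the next forward pass.
    NDArray::WaitAll();
}
```

Note that `arg_dict()` builds its map on the fly, but the `NDArray` values share handles with the executor's argument arrays, so copying into them updates the executor's actual inputs — the same assumption the generated code relies on.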
@@ -153,7 +151,7 @@ public:
                  std::string dist_measure,
                  mx_uint k){
-    NDArray prev_output;
+    std::vector<NDArray> prev_output;
     for(size_t i=0; i < in_keys.size(); i++){
         NDArray input_temp(in_shapes[i], ctx, false, dtype);
@@ -165,15 +163,21 @@ public:
     for(size_t i=0; i < net_start_ind; i++){
         network_handles[i]->Forward(false);
         CheckMXNetError("Network forward, local_adapt, handle ind. " + std::to_string(i));
-        prev_output = network_handles[i]->outputs[0];
+        prev_output = network_handles[i]->outputs;
         if(i+1 < net_start_ind){
-            prev_output.CopyTo(&(network_handles[i+1]->arg_dict()[sub_network_input_key[0]]));
+            if(num_sub_net_outputs[i] == 1){
+                prev_output[0].CopyTo(&(network_handles[i+1]->arg_dict()["data"]));
+            }else{
+                for(size_t j=0; j<num_sub_net_outputs[i]; j++){
+                    prev_output[j].CopyTo(&(network_handles[i+1]->arg_dict()["data" + std::to_string(j)]));
+                }
+            }
             NDArray::WaitAll();
         }
     }
-    prev_output.CopyTo(&(query_handle->arg_dict()[replay_query_input_keys[0]]));
+    prev_output[0].CopyTo(&(query_handle->arg_dict()[replay_query_input_keys[0]]));
     NDArray::WaitAll();
     query_handle->Forward(false);
@@ -184,21 +188,34 @@ public:
     Operator slice("slice_axis");
     slice.SetParam("axis", 1);
-    slice.SetInput("data", samples[1]);
+    slice.SetInput("data", samples[1][0]);
     NDArray labels;
     for(mx_uint i=0; i < gradient_steps; i++){
         for(mx_uint j=0; j < k; j++){
-            samples[0].Slice(j,j+1).CopyTo(&(network_handles[net_start_ind]->arg_dict()[sub_network_input_key[0]]));
+            if(samples[0].size() == 1){
+                samples[0][0].Slice(j,j+1).CopyTo(&(network_handles[net_start_ind]->arg_dict()["data"]));
+            }else{
+                for(mx_uint t=0; t<samples[0].size(); t++){
+                    samples[0][t].Slice(j,j+1).CopyTo(&(network_handles[net_start_ind]->arg_dict()["data" + std::to_string(t)]));
+                }
+            }
             slice.SetParam("begin", j);
             slice.SetParam("end", j+1);
             slice.Invoke(labels);
             network_handles[net_start_ind]->Forward(true);
             CheckMXNetError("Network forward, local_adapt, handle ind. " + std::to_string(net_start_ind));
             for(int k=net_start_ind+1; k < network_handles.size(); k++){
-                prev_output = network_handles[k-1]->outputs[0];
-                prev_output.CopyTo(&(network_handles[k]->arg_dict()[sub_network_input_key[0]]));
+                prev_output = network_handles[k-1]->outputs;
+                if(num_sub_net_outputs[k-1] == 1){
+                    prev_output[0].CopyTo(&(network_handles[k]->arg_dict()["data"]));
+                }else{
+                    for(mx_uint t=0; t<num_sub_net_outputs[k-1]; t++){
+                        prev_output[t].CopyTo(&(network_handles[k]->arg_dict()["data" + std::to_string(t)]));
+                    }
+                }
                 NDArray::WaitAll();
                 network_handles[k]->Forward(true);
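The labels for replayed sample `j` are cut out of the stored label batch with the imperative `slice_axis` operator set up at the top of this hunk. In isolation, that step might look like the sketch below (the function name `labels_for_sample` is ours; axis 1 indexes the replayed samples, as in the generated code):

```cpp
#include "mxnet-cpp/MxNetCpp.h"

using namespace mxnet::cpp;

// Extract the labels of replayed sample j from a batch of stored labels.
NDArray labels_for_sample(const NDArray &stored_labels, mx_uint j) {
    Operator slice("slice_axis");
    slice.SetParam("axis", 1);       // axis 1 indexes the samples
    slice.SetParam("begin", j);
    slice.SetParam("end", j + 1);
    slice.SetInput("data", stored_labels);
    NDArray labels;
    slice.Invoke(labels);
    NDArray::WaitAll();              // materialize the slice before use
    return labels;
}
```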
@@ -226,7 +243,15 @@ public:
             for(int k=network_handles.size()-1; k >= net_start_ind; k--){
                 network_handles[k]->Backward(last_grads);
                 CheckMXNetError("Network backward, local_adapt, handle ind. " + std::to_string(k));
-                last_grads = {network_handles[k]->grad_dict()[sub_network_input_key[0]]};
+                last_grads = {};
+                if(num_sub_net_outputs[k-1] == 1){
+                    last_grads.push_back(network_handles[k]->grad_dict()["data"]);
+                }else{
+                    for(mx_uint t=0; t<num_sub_net_outputs[k-1]; t++){
+                        last_grads.push_back(network_handles[k]->grad_dict()["data" + std::to_string(t)]);
+                    }
+                }
             }
             for(size_t k=net_start_ind; k < network_arg_names.size(); ++k) {
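During backpropagation the chain runs in reverse: the input gradients that sub-network `k` computed for its `data`/`dataN` arguments become the head gradients for sub-network `k-1`'s `Backward()` call. A sketch of that collection step under the same executor layout (`collect_input_grads` is an illustrative name, and `num_inputs` mirrors `num_sub_net_outputs[k-1]` above):

```cpp
#include <string>
#include <vector>
#include "mxnet-cpp/MxNetCpp.h"

using namespace mxnet::cpp;

// Collect the input gradients of `handle` so they can be passed as
// head gradients to the preceding sub-network's Backward() call.
std::vector<NDArray> collect_input_grads(Executor *handle, mx_uint num_inputs) {
    std::vector<NDArray> grads;
    if (num_inputs == 1) {
        grads.push_back(handle->grad_dict()["data"]);
    } else {
        for (mx_uint t = 0; t < num_inputs; t++) {
            grads.push_back(handle->grad_dict()["data" + std::to_string(t)]);
        }
    }
    return grads;
}
```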
@@ -303,6 +328,7 @@ public:
             take_values.SetInput("a", memory["values_" + std::to_string(i)]);
             take_values.SetInput("indices", indices);
             vals.push_back(take_values.Invoke()[0]);
+        }
         ret.push_back(vals);
         std::vector<NDArray> labs;
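Sampling from the replay memory gathers the stored value tensors row-wise with the `take` operator, as the hunk above shows for `values_i`. A minimal standalone version of that gather, with an illustrative helper name:

```cpp
#include <vector>
#include "mxnet-cpp/MxNetCpp.h"

using namespace mxnet::cpp;

// Gather the memory rows selected by `indices` with the imperative
// "take" operator (inputs are named "a" and "indices", as above).
NDArray gather_memory_rows(const NDArray &values, const NDArray &indices) {
    Operator take_values("take");
    take_values.SetInput("a", values);
    take_values.SetInput("indices", indices);
    std::vector<NDArray> out = take_values.Invoke();
    NDArray::WaitAll();  // ensure the gathered rows are computed
    return out[0];
}
```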
@@ -21,12 +21,11 @@
             query_net_prefix="${queryNetPrefix}")
 <#elseif mode == "FORWARD_FUNCTION">
 <#if useReplay == "True" || useLocalAdaption == "true">
-        ${element.name}full_, ind_${element.name} = self.${element.name}(args)
-        ${element.name} = ${element.name}full_
+        ${element.name}full_, ind_${element.name} = self.${element.name}(*args)
 <#else>
         ${element.name}full_, ind_${element.name} = self.${element.name}(${input})
-        ${element.name} = ${element.name}full_
 </#if>
+        ${element.name} = ${element.name}full_[0]
 <#elseif mode == "PREDICTION_PARAMETER">
         use_local_adaption.push_back(${useLocalAdaption});
         dist_measure.push_back("${replayMemoryStoreDistMeasure}");