Commit 039f6e64 authored by Sebastian N.

Cleaned up layers, simplified dimensions

parent 95a67ac4
@@ -28,6 +28,11 @@ public class CNNArch2GluonLayerSupportChecker extends LayerSupportChecker {
         supportedLayerList.add(AllPredefinedLayers.GRU_NAME);
         supportedLayerList.add(AllPredefinedLayers.EMBEDDING_NAME);
         supportedLayerList.add(AllPredefinedLayers.ARG_MAX_NAME);
+        supportedLayerList.add(AllPredefinedLayers.REPEAT_NAME);
+        supportedLayerList.add(AllPredefinedLayers.DOT_NAME);
+        supportedLayerList.add(AllPredefinedLayers.EXPAND_DIMS_NAME);
+        supportedLayerList.add(AllPredefinedLayers.SQUEEZE_NAME);
+        supportedLayerList.add(AllPredefinedLayers.SWAPAXES_NAME);
     }
 }
@@ -26,6 +26,7 @@ import de.monticore.lang.monticar.cnnarch.generator.CNNArchTemplateController;
 import de.monticore.lang.monticar.cnnarch._symboltable.*;
 import de.monticore.lang.monticar.cnnarch.generator.TemplateConfiguration;
 import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers;
+import de.se_rwth.commons.logging.Log;

 import java.io.Writer;
 import java.util.*;
@@ -61,7 +62,9 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
             }
         }
         else if (element.getType() == VariableSymbol.Type.LAYER) {
-            include(TEMPLATE_ELEMENTS_DIR_PATH, element.getLayerVariableDeclaration().getLayer().getName(), writer, netDefinitionMode);
+            if (element.getMember() != VariableSymbol.Member.OUTPUT) {
+                include(TEMPLATE_ELEMENTS_DIR_PATH, element.getLayerVariableDeclaration().getLayer().getName(), writer, netDefinitionMode);
+            }
         }
     }
     else {
@@ -137,25 +140,13 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
         include(architectureElement, getWriter(), netDefinitionMode);
     }

-    public Set<String> getStreamInputNames(SerialCompositeElementSymbol stream, boolean addStateIndex) {
-        if(addStateIndex) {
-            Set<String> names = getStreamInputs(stream, addStateIndex).keySet();
-            Set<String> newNames = new LinkedHashSet<>();
-
-            for (String name : names) {
-                // if LSTM state, transform name into list of hidden state and cell state
-                if (name.endsWith("_state_")) {
-                    name = "[" + name + "[0], " + name + "[1]]";
-                }
-
-                newNames.add(name);
-            }
-
-            return newNames;
-        }
-
-        return getStreamInputs(stream, addStateIndex).keySet();
+    public Set<String> getStreamInputNames(SerialCompositeElementSymbol stream) {
+        return getStreamInputs(stream).keySet();
     }

     // used for unroll
-    public List<String> getStreamInputNames(SerialCompositeElementSymbol stream, SerialCompositeElementSymbol currentStream, boolean addStateIndex) {
-        List<String> inputNames = new LinkedList<>(getStreamInputNames(stream, addStateIndex));
+    public List<String> getStreamInputNames(SerialCompositeElementSymbol stream, SerialCompositeElementSymbol currentStream) {
+        List<String> inputNames = new LinkedList<>(getStreamInputNames(stream));

         Map<String, String> pairs = getUnrollPairs(stream, currentStream);

         for (int i = 0; i != inputNames.size(); ++i) {
@@ -167,28 +158,8 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
         return inputNames;
     }

-    public Collection<List<String>> getStreamInputDimensions(SerialCompositeElementSymbol stream, boolean useStateDim) {
-        if(useStateDim) {
-            return getStreamInputs(stream, false).values();
-        }else{
-            Set<String> names = getStreamInputs(stream, true).keySet();
-            List<List<String>> dims = new ArrayList<List<String>>(getStreamInputs(stream, false).values());
-
-            List<List<String>> result = new ArrayList<List<String>>();
-            int index = 0;
-            for (String name : names) {
-                if (name.endsWith("_state_") || name.endsWith("_state_[0]")) {
-                    ArrayList dim = new ArrayList<String>();
-                    dim.add("-1");
-                    dim.add(name.replace("_state_", "_output_.begin_state(batch_size=1, ctx=context)"));
-                    result.add(dim);
-                }else{
-                    result.add(dims.get(index));
-                }
-                index++;
-            }
-            return result;
-        }
+    public Collection<List<String>> getStreamInputDimensions(SerialCompositeElementSymbol stream) {
+        return getStreamInputs(stream).values();
     }

     public Set<String> getStreamOutputNames(SerialCompositeElementSymbol stream) {
@@ -200,7 +171,7 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
             }
         }

-        outputNames.addAll(getStreamLayerVariableMembers(stream, "1", true, false, false).keySet());
+        outputNames.addAll(getStreamLayerVariableMembers(stream, true).keySet());

         return outputNames;
     }
@@ -220,25 +191,11 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
     }

     // Used to initialize all layer variable members which are passed through the networks
-    public Map<String, List<List<String>>> getLayerVariableMembers(String batchSize, boolean includeStates) {
-        Map<String, List<List<String>>> members = new LinkedHashMap<>();
-        int index = 0;
+    public Map<String, List<String>> getLayerVariableMembers() {
+        Map<String, List<String>> members = new LinkedHashMap<>();

         for (SerialCompositeElementSymbol stream : getArchitecture().getStreams()) {
-            List<List<String>> value = new ArrayList<>();
-            Map<String, List<String>> member = getStreamLayerVariableMembers(stream, batchSize, true, includeStates, false);
-            for (List<String> entry: member.values()){
-                value.add(entry);
-                ArrayList<String> streamIndex = new ArrayList<String>();
-                streamIndex.add(Integer.toString(index));
-                value.add(streamIndex);
-            }
-            for(String name: member.keySet()){
-                if(!members.containsKey(name)) {
-                    members.put(name, value);
-                }
-            }
-            index++;
+            members.putAll(getStreamLayerVariableMembers(stream, true));
         }

         return members;
@@ -272,7 +229,7 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
         return pairs;
     }

-    private Map<String, List<String>> getStreamInputs(SerialCompositeElementSymbol stream, boolean addStateIndex) {
+    private Map<String, List<String>> getStreamInputs(SerialCompositeElementSymbol stream) {
         Map<String, List<String>> inputs = new LinkedHashMap<>();

         for (ArchitectureElementSymbol element : stream.getFirstAtomicElements()) {
@@ -284,19 +241,16 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
                     dimensions.add(intDimension.toString());
                 }

-                // Add batch size dimension
-                dimensions.add(0, "1");
-
                 inputs.put(getName(element), dimensions);
             }
         }

-        inputs.putAll(getStreamLayerVariableMembers(stream, "1", false, false, addStateIndex));
+        inputs.putAll(getStreamLayerVariableMembers(stream, false));

         return inputs;
     }

-    private Map<String, List<String>> getStreamLayerVariableMembers(SerialCompositeElementSymbol stream, String batchSize, boolean includeOutput, boolean includeStates, boolean addStateIndex) {
+    private Map<String, List<String>> getStreamLayerVariableMembers(SerialCompositeElementSymbol stream, boolean includeOutput) {
         Map<String, List<String>> members = new LinkedHashMap<>();

         List<ArchitectureElementSymbol> elements = stream.getSpannedScope().resolveLocally(ArchitectureElementSymbol.KIND);
@@ -304,19 +258,20 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
         if (element instanceof VariableSymbol) {
             VariableSymbol variable = (VariableSymbol) element;

-            if (variable.getType() == VariableSymbol.Type.LAYER && (variable.getMember() == VariableSymbol.Member.NONE || includeStates)) {
+            if (variable.getType() == VariableSymbol.Type.LAYER && (variable.getMember() == VariableSymbol.Member.NONE)) {
                 LayerVariableDeclarationSymbol layerVariableDeclaration = variable.getLayerVariableDeclaration();

                 if (layerVariableDeclaration.getLayer().getDeclaration().isPredefined()) {
                     PredefinedLayerDeclaration predefinedLayerDeclaration =
                             (PredefinedLayerDeclaration) layerVariableDeclaration.getLayer().getDeclaration();

-                    if (predefinedLayerDeclaration.isValidMember(VariableSymbol.Member.STATE)) {
-                        String name;
-                        if(addStateIndex && predefinedLayerDeclaration.getName().equals(AllPredefinedLayers.GRU_NAME)){
-                            name = variable.getName() + "_state_[0]";
-                        }else{
-                            name = variable.getName() + "_state_";
+                    int arrayLength = predefinedLayerDeclaration.getArrayLength(VariableSymbol.Member.STATE);
+
+                    for (int i = 0; i < arrayLength; ++i) {
+                        String name = variable.getName() + "_state_";
+
+                        if (arrayLength > 1) {
+                            name += i + "_";
                         }

                         List<Integer> intDimensions = predefinedLayerDeclaration.computeOutputTypes(
@@ -331,17 +286,19 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
                             dimensions.add(intDimension.toString());
                         }

-                        // Add batch size dimension at index 1, since RNN states in Gluon have the format
-                        // (layers, batch_size, units)
-                        dimensions.add(1, batchSize);
-
                         members.put(name, dimensions);
                     }

                     if (includeOutput) {
-                        if (predefinedLayerDeclaration.isValidMember(VariableSymbol.Member.OUTPUT)) {
+                        arrayLength = predefinedLayerDeclaration.getArrayLength(VariableSymbol.Member.OUTPUT);
+
+                        for (int i = 0; i < arrayLength; ++i) {
                             String name = variable.getName() + "_output_";

+                            if (arrayLength > 1) {
+                                name += i + "_";
+                            }
+
                             List<Integer> intDimensions = predefinedLayerDeclaration.computeOutputTypes(
                                     layerVariableDeclaration.getLayer().getInputTypes(),
                                     layerVariableDeclaration.getLayer(),
@@ -354,9 +311,6 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
                                 dimensions.add(intDimension.toString());
                             }

-                            // Add batch size dimension at index 0, since we use NTC format for RNN output in Gluon
-                            dimensions.add(0, batchSize);
-
                             members.put(name, dimensions);
                         }
                     }
@@ -367,6 +321,15 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
         return members;
     }

+    // Cuts trailing dimensions of size 1, keeping at least one dimension
+    public List<String> cutDimensions(List<String> dimensions) {
+        while (dimensions.size() > 1 && dimensions.get(dimensions.size() - 1).equals("1")) {
+            dimensions.remove(dimensions.size() - 1);
+        }
+
+        return dimensions;
+    }
+
     public int getBeamSearchWidth(UnrollInstructionSymbol unroll){
         return unroll.getIntValue(AllPredefinedLayers.WIDTH_NAME).get();
     }
...
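Note: the new cutDimensions helper exists so the creator template (next hunk) can build minimal dummy inputs; it strips trailing 1-dimensions but never empties the list. A minimal Python sketch of the same logic, for illustration only (the generator implements it in Java, as above):

    def cut_dimensions(dimensions):
        # Drop trailing "1" entries, but always keep at least one dimension.
        dims = list(dimensions)  # copy; the Java version mutates its argument in place
        while len(dims) > 1 and dims[-1] == "1":
            dims.pop()
        return dims

    assert cut_dimensions(["64", "1", "1"]) == ["64"]  # (64, 1, 1) -> (64,)
    assert cut_dimensions(["1", "1"]) == ["1"]         # never cut to an empty shape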
@@ -57,8 +57,7 @@ class ${tc.fileNameWithoutEnding}:
             self.networks[${networkInstruction?index}] = Net_${networkInstruction?index}(data_mean=data_mean, data_std=data_std)
             self.networks[${networkInstruction?index}].collect_params().initialize(self.weight_initializer, ctx=context)
             self.networks[${networkInstruction?index}].hybridize()
-            self.networks[${networkInstruction?index}](<#list tc.getStreamInputDimensions(networkInstruction.body, false) as dimensions>
-<#if dimensions[0] == "-1">self.networks[${networkInstruction?index}].${dimensions[1]}<#else>mx.nd.zeros((${tc.join(dimensions, ",")},), ctx=context)</#if> <#sep>, </#list>)
+            self.networks[${networkInstruction?index}](<#list tc.getStreamInputDimensions(networkInstruction.body) as dimensions>mx.nd.zeros((1, ${tc.join(tc.cutDimensions(dimensions), ",")},), ctx=context)<#sep>, </#list>)
 </#if>
 </#list>
...
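For reference, the rendered Python now performs one dummy forward pass with plain zero tensors, with the batch dimension fixed to 1, in place of the removed -1/begin_state special case. A hedged sketch of what the template might emit for a network with a single input of dimensions (64,) — Net_0 follows the generated naming; the initializer here is illustrative:

    import mxnet as mx

    context = mx.cpu()
    net = Net_0(data_mean=None, data_std=None)  # generated network class (illustrative args)
    net.collect_params().initialize(mx.init.Normal(), ctx=context)
    net.hybridize()
    # One zero-filled forward pass triggers deferred shape inference and
    # lets hybridize() cache the symbolic graph:
    net(mx.nd.zeros((1, 64,), ctx=context))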
@@ -2,88 +2,6 @@ import mxnet as mx
 import numpy as np
 from mxnet import gluon

-class OneHot(gluon.HybridBlock):
-    def __init__(self, size, **kwargs):
-        super(OneHot, self).__init__(**kwargs)
-        with self.name_scope():
-            self.size = size
-
-    def hybrid_forward(self, F, x):
-        return F.one_hot(indices=F.argmax(data=x, axis=1), depth=self.size)
-
-class Softmax(gluon.HybridBlock):
-    def __init__(self, **kwargs):
-        super(Softmax, self).__init__(**kwargs)
-
-    def hybrid_forward(self, F, x):
-        return F.softmax(x)
-
-class Split(gluon.HybridBlock):
-    def __init__(self, num_outputs, axis=1, **kwargs):
-        super(Split, self).__init__(**kwargs)
-        with self.name_scope():
-            self.axis = axis
-            self.num_outputs = num_outputs
-
-    def hybrid_forward(self, F, x):
-        return F.split(data=x, axis=self.axis, num_outputs=self.num_outputs)
-
-class Concatenate(gluon.HybridBlock):
-    def __init__(self, dim=1, **kwargs):
-        super(Concatenate, self).__init__(**kwargs)
-        with self.name_scope():
-            self.dim = dim
-
-    def hybrid_forward(self, F, *x):
-        return F.concat(*x, dim=self.dim)
-
-class Repeat(gluon.HybridBlock):
-    def __init__(self, repeats, axis=1, **kwargs):
-        super(Repeat, self).__init__(**kwargs)
-        with self.name_scope():
-            self.axis = axis
-            self.repeats = repeats
-
-    def hybrid_forward(self, F, x):
-        return F.repeat(data=x, axis=self.axis, repeats=self.repeats)
-
-class Dot(gluon.HybridBlock):
-    def __init__(self, **kwargs):
-        super(Dot, self).__init__(**kwargs)
-
-    def hybrid_forward(self, F, *x):
-        return F.batch_dot(*x)
-
-class ExpandDims(gluon.HybridBlock):
-    def __init__(self, dim=1, **kwargs):
-        super(ExpandDims, self).__init__(**kwargs)
-        with self.name_scope():
-            self.dim = dim
-
-    def hybrid_forward(self, F, x):
-        return F.expand_dims(data=x, axis=self.dim)
-
-class SwapAxes(gluon.HybridBlock):
-    def __init__(self, dim1, dim2, **kwargs):
-        super(SwapAxes, self).__init__(**kwargs)
-        with self.name_scope():
-            self.dim1 = dim1
-            self.dim2 = dim2
-
-    def hybrid_forward(self, F, x):
-        return F.swapaxes(data=x, dim1=self.dim1, dim2=self.dim2)
-
-class ReduceSum(gluon.HybridBlock):
-    def __init__(self, axis=1, **kwargs):
-        super(ReduceSum, self).__init__(**kwargs)
-        with self.name_scope():
-            self.axis = axis
-
-    def hybrid_forward(self, F, x):
-        return F.sum(data=x, axis=self.axis)
-
 class ZScoreNormalization(gluon.HybridBlock):
     def __init__(self, data_mean, data_std, **kwargs):
@@ -122,6 +40,42 @@ class NoNormalization(gluon.HybridBlock):
         return x

+class CustomRNN(gluon.HybridBlock):
+    def __init__(self, hidden_size, num_layers, bidirectional, **kwargs):
+        super(CustomRNN, self).__init__(**kwargs)
+        with self.name_scope():
+            self.rnn = gluon.rnn.RNN(hidden_size=hidden_size, num_layers=num_layers,
+                                     bidirectional=bidirectional, activation='tanh', layout='NTC')
+
+    def hybrid_forward(self, F, data, state0):
+        output, [state0] = self.rnn(data, [F.swapaxes(state0, 0, 1)])
+        return output, F.swapaxes(state0, 0, 1)
+
+class CustomLSTM(gluon.HybridBlock):
+    def __init__(self, hidden_size, num_layers, bidirectional, **kwargs):
+        super(CustomLSTM, self).__init__(**kwargs)
+        with self.name_scope():
+            self.lstm = gluon.rnn.LSTM(hidden_size=hidden_size, num_layers=num_layers,
+                                       bidirectional=bidirectional, layout='NTC')
+
+    def hybrid_forward(self, F, data, state0, state1):
+        output, [state0, state1] = self.lstm(data, [F.swapaxes(state0, 0, 1), F.swapaxes(state1, 0, 1)])
+        return output, F.swapaxes(state0, 0, 1), F.swapaxes(state1, 0, 1)
+
+class CustomGRU(gluon.HybridBlock):
+    def __init__(self, hidden_size, num_layers, bidirectional, **kwargs):
+        super(CustomGRU, self).__init__(**kwargs)
+        with self.name_scope():
+            self.gru = gluon.rnn.GRU(hidden_size=hidden_size, num_layers=num_layers,
+                                     bidirectional=bidirectional, layout='NTC')
+
+    def hybrid_forward(self, F, data, state0):
+        output, [state0] = self.gru(data, [F.swapaxes(state0, 0, 1)])
+        return output, F.swapaxes(state0, 0, 1)
+
 <#list tc.architecture.networkInstructions as networkInstruction>
 <#if networkInstruction.body.isTrainable()>
 class Net_${networkInstruction?index}(gluon.HybridBlock):
@@ -131,7 +85,7 @@ class Net_${networkInstruction?index}(gluon.HybridBlock):
         with self.name_scope():
 ${tc.include(networkInstruction.body, "ARCHITECTURE_DEFINITION")}

-    def hybrid_forward(self, F, ${tc.join(tc.getStreamInputNames(networkInstruction.body, false), ", ")}):
+    def hybrid_forward(self, F, ${tc.join(tc.getStreamInputNames(networkInstruction.body), ", ")}):
 ${tc.include(networkInstruction.body, "FORWARD_FUNCTION")}
         return ${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")}
...
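The CustomRNN/CustomLSTM/CustomGRU wrappers added above keep the generated code's state layout (batch, layers, units), while gluon.rnn blocks expect (layers, batch, units); each forward pass therefore swaps the state axes on the way in and back on the way out. For an LSTM, the two returned states correspond to the _state_0_ and _state_1_ members produced by getStreamLayerVariableMembers. A quick usage sketch with illustrative shapes:

    import mxnet as mx

    lstm = CustomLSTM(hidden_size=64, num_layers=1, bidirectional=False)
    lstm.initialize()

    data = mx.nd.zeros((8, 10, 32))   # NTC: (batch, time, channels)
    state0 = mx.nd.zeros((8, 1, 64))  # hidden state, (batch, layers, units)
    state1 = mx.nd.zeros((8, 1, 64))  # cell state, (batch, layers, units)

    output, h, c = lstm(data, state0, state1)
    print(output.shape)       # (8, 10, 64)
    print(h.shape, c.shape)   # (8, 1, 64) (8, 1, 64)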
@@ -16,13 +16,13 @@ public:
     const std::string json_file = "model/${tc.componentName}/model_${networkInstruction?index}_newest-symbol.json";
     const std::string param_file = "model/${tc.componentName}/model_${networkInstruction?index}_newest-0000.params";
     const std::vector<std::string> input_keys = {
-<#if tc.getStreamInputNames(networkInstruction.body, false)?size == 1>
+<#if tc.getStreamInputNames(networkInstruction.body)?size == 1>
         "data"
 <#else>
-        <#list tc.getStreamInputNames(networkInstruction.body, false) as variable>"data${variable?index}"<#sep>, </#list>
+        <#list tc.getStreamInputNames(networkInstruction.body) as variable>"data${variable?index}"<#sep>, </#list>
 </#if>
     };
-    const std::vector<std::vector<mx_uint>> input_shapes = {<#list tc.getStreamInputDimensions(networkInstruction.body, true) as dimensions>{${tc.join(dimensions, ", ")}}<#sep>, </#list>};
+    const std::vector<std::vector<mx_uint>> input_shapes = {<#list tc.getStreamInputDimensions(networkInstruction.body) as dimensions>{${tc.join(dimensions, ", ")}}<#sep>, </#list>};
     const bool use_gpu = false;

     PredictorHandle handle;
@@ -35,9 +35,9 @@ public:
         if(handle) MXPredFree(handle);
     }

-    void predict(${tc.join(tc.getStreamInputNames(networkInstruction.body, false), ", ", "const std::vector<float> &in_", "")},
+    void predict(${tc.join(tc.getStreamInputNames(networkInstruction.body), ", ", "const std::vector<float> &in_", "")},
                  ${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ", "std::vector<float> &out_", "")}){
-<#list tc.getStreamInputNames(networkInstruction.body, false) as variable>
+<#list tc.getStreamInputNames(networkInstruction.body) as variable>
         MXPredSetInput(handle, input_keys[${variable?index}].c_str(), in_${variable}.data(), static_cast<mx_uint>(in_${variable}.size()));
 </#list>
...
@@ -37,7 +37,7 @@ if __name__ == "__main__":
         normalize=${config.normalize?string("True","False")},
 </#if>
 <#if (config.evalMetric)??>
-        eval_metric='${config.evalMetric.metric}',
+        eval_metric='${config.evalMetric.name}',
         eval_metric_params={
 <#if (config.evalMetric.exclude)??>
             'exclude': [<#list config.evalMetric.exclude as value>${value}<#sep>, </#list>],
...
-<#assign dim = element.dim?c>
-<#if mode == "ARCHITECTURE_DEFINITION">
-        self.${element.name} = Concatenate(dim=${dim})
-<#include "OutputShape.ftl">
-<#elseif mode == "FORWARD_FUNCTION">
-        ${element.name} = self.${element.name}(${tc.join(element.inputs, ", ")})
+<#assign axis = (element.axis + 1)?c>
+<#if mode == "FORWARD_FUNCTION">
+        ${element.name} = F.concat(${tc.join(element.inputs, ", ")}, dim=${axis})
 </#if>
\ No newline at end of file
-<#if mode == "ARCHITECTURE_DEFINITION">
-        self.${element.name} = Dot()
-<#include "OutputShape.ftl">
-<#elseif mode == "FORWARD_FUNCTION">
-        ${element.name} = self.${element.name}(${tc.join(element.inputs, ", ")})
+<#if mode == "FORWARD_FUNCTION">
+        ${element.name} = F.batch_dot(${tc.join(element.inputs, ", ")})
 </#if>
\ No newline at end of file
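Both templates drop their ARCHITECTURE_DEFINITION branch: instead of instantiating the now-deleted Concatenate/Dot blocks in __init__, the generated forward function calls the MXNet operators directly, with the axis shifted by one for the implicit batch dimension (the (element.axis + 1) assignment). A small standalone sketch of the operators the new templates emit, with illustrative shapes:

    import mxnet as mx
    from mxnet import nd

    a = nd.ones((1, 2, 3))
    b = nd.ones((1, 4, 3))
    # Concatenate.ftl now renders F.concat(...); element.axis == 0 becomes dim=1
    # because the CNNArch axis sits behind the batch axis.
    print(nd.concat(a, b, dim=1).shape)   # (1, 6, 3)

    x = nd.ones((1, 2, 3))
    y = nd.ones((1, 3, 4))
    # Dot.ftl now renders F.batch_dot(...): batched matrix multiplication.
    print(nd.batch_dot(x, y).shape)       # (1, 2, 4)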
-<#assign dim = element.dim?c>
-<#if mode == "ARCHITECTURE_DEFINITION">
+<#assign axis = (element.axis + 1)?c>
+<#if mode == "FORWARD_FUNCTION">