Commit edb10eba authored by Sebastian N.'s avatar Sebastian N.
Browse files

Added ArgMax layer without implementation, updated for ResolvableSymbol

parent 374d469f
Pipeline #180029 failed with stages
in 2 minutes and 51 seconds
......@@ -28,6 +28,7 @@ public class CNNArch2GluonLayerSupportChecker extends LayerSupportChecker {
supportedLayerList.add(AllPredefinedLayers.LSTM_NAME);
supportedLayerList.add(AllPredefinedLayers.GRU_NAME);
supportedLayerList.add(AllPredefinedLayers.EMBEDDING_NAME);
supportedLayerList.add(AllPredefinedLayers.ARG_MAX_NAME);
}
}
......@@ -66,7 +66,7 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
}
}
else {
include(element.getResolvedThis().get(), partOfUnroll, unrollIndex, writer, netDefinitionMode);
include((ArchitectureElementSymbol) element.getResolvedThis().get(), partOfUnroll, unrollIndex, writer, netDefinitionMode);
}
setCurrentElement(previousElement);
......@@ -80,7 +80,7 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
include(TEMPLATE_ELEMENTS_DIR_PATH, "Const", writer, netDefinitionMode);
}
else {
include(constant.getResolvedThis().get(), partOfUnroll, unrollIndex, writer, netDefinitionMode);
include((ArchitectureElementSymbol) constant.getResolvedThis().get(), partOfUnroll, unrollIndex, writer, netDefinitionMode);
}
setCurrentElement(previousElement);
......@@ -97,7 +97,7 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
include(TEMPLATE_ELEMENTS_DIR_PATH, templateName, writer, netDefinitionMode);
}
else {
include(layer.getResolvedThis().get(), partOfUnroll, unrollIndex, writer, netDefinitionMode);
include((ArchitectureElementSymbol) layer.getResolvedThis().get(), partOfUnroll, unrollIndex, writer, netDefinitionMode);
}
setCurrentElement(previousElement);
......
<#assign input = element.inputs[0]>
<#if mode == "FORWARD_FUNCTION">
${element.name} = ${input}
</#if>
\ No newline at end of file
architecture RNNencdec(max_length=50, vocabulary_size=30000, hidden_size=1000){
def input N(0:29999)^{50} source
def output Q(0:29999)^{1} target[50]
def input Z(0:29999)^{50} source
def output Z(0:29999)^{1} target[50]
layer LSTM(units=hidden_size) encoder;
......@@ -20,6 +20,7 @@ architecture RNNencdec(max_length=50, vocabulary_size=30000, hidden_size=1000){
decoder ->
FullyConnected(units=vocabulary_size) ->
Softmax() ->
ArgMax() ->
target[t]
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment