Skip to content
Snippets Groups Projects
Commit edb10eba authored by Sebastian Nickels's avatar Sebastian Nickels
Browse files

Added ArgMax layer without implementation, updated for ResolvableSymbol

parent 374d469f
Branches
No related tags found
1 merge request!23Added Unroll-related features and layers
Pipeline #180029 failed
......@@ -28,6 +28,7 @@ public class CNNArch2GluonLayerSupportChecker extends LayerSupportChecker {
supportedLayerList.add(AllPredefinedLayers.LSTM_NAME);
supportedLayerList.add(AllPredefinedLayers.GRU_NAME);
supportedLayerList.add(AllPredefinedLayers.EMBEDDING_NAME);
supportedLayerList.add(AllPredefinedLayers.ARG_MAX_NAME);
}
}
......@@ -66,7 +66,7 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
}
}
else {
include(element.getResolvedThis().get(), partOfUnroll, unrollIndex, writer, netDefinitionMode);
include((ArchitectureElementSymbol) element.getResolvedThis().get(), partOfUnroll, unrollIndex, writer, netDefinitionMode);
}
setCurrentElement(previousElement);
......@@ -80,7 +80,7 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
include(TEMPLATE_ELEMENTS_DIR_PATH, "Const", writer, netDefinitionMode);
}
else {
include(constant.getResolvedThis().get(), partOfUnroll, unrollIndex, writer, netDefinitionMode);
include((ArchitectureElementSymbol) constant.getResolvedThis().get(), partOfUnroll, unrollIndex, writer, netDefinitionMode);
}
setCurrentElement(previousElement);
......@@ -97,7 +97,7 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
include(TEMPLATE_ELEMENTS_DIR_PATH, templateName, writer, netDefinitionMode);
}
else {
include(layer.getResolvedThis().get(), partOfUnroll, unrollIndex, writer, netDefinitionMode);
include((ArchitectureElementSymbol) layer.getResolvedThis().get(), partOfUnroll, unrollIndex, writer, netDefinitionMode);
}
setCurrentElement(previousElement);
......
<#assign input = element.inputs[0]>
<#if mode == "FORWARD_FUNCTION">
${element.name} = ${input}
</#if>
\ No newline at end of file
architecture RNNencdec(max_length=50, vocabulary_size=30000, hidden_size=1000){
def input N(0:29999)^{50} source
def output Q(0:29999)^{1} target[50]
def input Z(0:29999)^{50} source
def output Z(0:29999)^{1} target[50]
layer LSTM(units=hidden_size) encoder;
......@@ -20,6 +20,7 @@ architecture RNNencdec(max_length=50, vocabulary_size=30000, hidden_size=1000){
decoder ->
FullyConnected(units=vocabulary_size) ->
Softmax() ->
ArgMax() ->
target[t]
};
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment