Commit c5ca1033 authored by Christian Fuß's avatar Christian Fuß
Browse files

added OneHot Layer

parent 6bf25ee8
Pipeline #125342 failed with stages
...@@ -8,14 +8,14 @@ ...@@ -8,14 +8,14 @@
<groupId>de.monticore.lang.monticar</groupId> <groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-mxnet-generator</artifactId> <artifactId>cnnarch-mxnet-generator</artifactId>
<version>0.2.14-SNAPSHOT</version> <version>0.2.15-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= --> <!-- == PROJECT DEPENDENCIES ============================================= -->
<properties> <properties>
<!-- .. SE-Libraries .................................................. --> <!-- .. SE-Libraries .................................................. -->
<CNNArch.version>0.3.0-SNAPSHOT</CNNArch.version> <CNNArch.version>0.3.1-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.2.6</CNNTrain.version> <CNNTrain.version>0.2.6</CNNTrain.version>
<embedded-montiarc-math-opt-generator>0.1.4</embedded-montiarc-math-opt-generator> <embedded-montiarc-math-opt-generator>0.1.4</embedded-montiarc-math-opt-generator>
......
...@@ -164,6 +164,16 @@ public class ArchitectureElementData { ...@@ -164,6 +164,16 @@ public class ArchitectureElementData {
.getStringValue(AllPredefinedLayers.POOL_TYPE_NAME).get(); .getStringValue(AllPredefinedLayers.POOL_TYPE_NAME).get();
} }
public int getOneHotIndex(){
return ((LayerSymbol) getElement())
.getIntValue(AllPredefinedLayers.ONE_HOT_INDEX_NAME).get();
}
public int getOneHotSize(){
return ((LayerSymbol) getElement())
.getIntValue(AllPredefinedLayers.ONE_HOT_SIZE).get();
}
@Nullable @Nullable
public List<Integer> getPadding(){ public List<Integer> getPadding(){
return getPadding((LayerSymbol) getElement()); return getPadding((LayerSymbol) getElement());
......
<#assign input = element.inputs[0]>
<#assign mode = definition_mode.toString()>
<#assign one_hot_index = element.oneHotIndex?c>
<#assign one_hot_size = element.oneHotSize?c>
<#if mode == "ARCHITECTURE_DEFINITION">
indexArray = mx.nd.array([${one_hot_index}])
indexVar = mx.sym.Variable('indexVar')
one_hot_vector = mx.symbol.one_hot(indices=indexVar,depth=${one_hot_size})
self.${element.name} = one_hot_vector.eval(indexVar=indexArray)
<#include "OutputShape.ftl">
</#if>
<#if mode == "FORWARD_FUNCTION">
${element.name} = self.${element.name}(${input})
</#if>
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment