Commit 74d278fe authored by Christian Fuß's avatar Christian Fuß
Browse files

added OneHotLayer

parent 4b8bf9e2
Pipeline #150125 failed with stages
......@@ -22,6 +22,7 @@ public class CNNArch2GluonLayerSupportChecker extends LayerSupportChecker {
supportedLayerList.add(AllPredefinedLayers.ADD_NAME);
supportedLayerList.add(AllPredefinedLayers.CONCATENATE_NAME);
supportedLayerList.add(AllPredefinedLayers.FLATTEN_NAME);
supportedLayerList.add(AllPredefinedLayers.ONE_HOT_NAME);
}
}
......@@ -71,7 +71,7 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
if (layer.isAtomic()){
ArchitectureElementSymbol nextElement = layer.getOutputElement().get();
if (!isSoftmaxOutput(nextElement) && !isLogisticRegressionOutput(nextElement)){
if (!isSoftmaxOutput(nextElement) && !isLogisticRegressionOutput(nextElement) && !isOneHotOutput(nextElement)){
String templateName = layer.getDeclaration().getName();
include(TEMPLATE_ELEMENTS_DIR_PATH, templateName, writer, netDefinitionMode);
}
......
......@@ -2,6 +2,16 @@ import mxnet as mx
import numpy as np
from mxnet import gluon
class OneHot(gluon.HybridBlock):
def __init__(self, size, **kwargs):
super(OneHot, self).__init__(**kwargs)
with self.name_scope():
self.size = size
def hybrid_forward(self, F, x):
return F.one_hot(indices=F.argmax(data=x, axis=1), depth=self.size)
class Softmax(gluon.HybridBlock):
def __init__(self, **kwargs):
super(Softmax, self).__init__(**kwargs)
......
<#assign input = element.inputs[0]>
<#assign mode = definition_mode.toString()>
<#assign size = element.size>
<#if mode == "ARCHITECTURE_DEFINITION">
self.${element.name} = OneHot(size=${size})
<#include "OutputShape.ftl">
</#if>
<#if mode == "FORWARD_FUNCTION">
${element.name} = self.${element.name}(${input})
</#if>
......@@ -7,6 +7,8 @@
self.last_layers['${element.name}'] = 'sigmoid'
<#elseif element.linearRegressionOutput>
self.last_layers['${element.name}'] = 'linear'
<#elseif element.oneHotOutput>
self.last_layers['${element.name}'] = 'softmax'
</#if>
</#if>
<#if mode == "FORWARD_FUNCTION">
......
......@@ -2,6 +2,16 @@ import mxnet as mx
import numpy as np
from mxnet import gluon
class OneHot(gluon.HybridBlock):
def __init__(self, size, **kwargs):
super(OneHot, self).__init__(**kwargs)
with self.name_scope():
self.size = size
def hybrid_forward(self, F, x):
return F.one_hot(indices=F.argmax(data=x, axis=1), depth=self.size)
class Softmax(gluon.HybridBlock):
def __init__(self, **kwargs):
super(Softmax, self).__init__(**kwargs)
......
......@@ -2,6 +2,16 @@ import mxnet as mx
import numpy as np
from mxnet import gluon
class OneHot(gluon.HybridBlock):
def __init__(self, size, **kwargs):
super(OneHot, self).__init__(**kwargs)
with self.name_scope():
self.size = size
def hybrid_forward(self, F, x):
return F.one_hot(indices=F.argmax(data=x, axis=1), depth=self.size)
class Softmax(gluon.HybridBlock):
def __init__(self, **kwargs):
super(Softmax, self).__init__(**kwargs)
......
......@@ -2,6 +2,16 @@ import mxnet as mx
import numpy as np
from mxnet import gluon
class OneHot(gluon.HybridBlock):
def __init__(self, size, **kwargs):
super(OneHot, self).__init__(**kwargs)
with self.name_scope():
self.size = size
def hybrid_forward(self, F, x):
return F.one_hot(indices=F.argmax(data=x, axis=1), depth=self.size)
class Softmax(gluon.HybridBlock):
def __init__(self, **kwargs):
super(Softmax, self).__init__(**kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment