Commit abe79bcb authored by Christian Fuß's avatar Christian Fuß
Browse files

added classes for ReduceSum, ExpandDims, Dot and Repeat layers in CNNNet

parent 3440c1a0
Pipeline #188337 failed with stages
in 28 seconds
......@@ -40,6 +40,40 @@ class Concatenate(gluon.HybridBlock):
def hybrid_forward(self, F, *x):
return F.concat(*x, dim=self.dim)
class Repeat(gluon.HybridBlock):
def __init__(self, repeats, axis=1, **kwargs):
super(Repeat, self).__init__(**kwargs)
with self.name_scope():
self.axis = axis
self.repeats = repeats
def hybrid_forward(self, F, x):
return F.repeat(data=x, axis=self.axis, repeats=self.repeats)
class Dot(gluon.HybridBlock):
def __init__(self, **kwargs):
super(Dot, self).__init__(**kwargs)
def hybrid_forward(self, F, *x):
return F.dot(*x)
class ExpandDims(gluon.HybridBlock):
def __init__(self, dim=1, **kwargs):
super(ExpandDims, self).__init__(**kwargs)
with self.name_scope():
self.dim = dim
def hybrid_forward(self, F, x):
return F.expand_dims(data=x, axis=self.dim)
class ReduceSum(gluon.HybridBlock):
def __init__(self, axis=1, **kwargs):
super(ReduceSum, self).__init__(**kwargs)
with self.name_scope():
self.axis = axis
def hybrid_forward(self, F, x):
return F.sum(data=x, axis=self.axis)
class ZScoreNormalization(gluon.HybridBlock):
def __init__(self, data_mean, data_std, **kwargs):
......
<#if mode == "FORWARD_FUNCTION">
${element.name} = mx.symbol.dot(${element.inputs[0]}, ${element.inputs[1]})
<#elseif mode == "PYTHON_INLINE">
${element.name} = mx.symbol.dot(${element.inputs[0]}, ${element.inputs[1]})
<#if mode == "ARCHITECTURE_DEFINITION">
self.${element.name} = Dot()
<#include "OutputShape.ftl">
<#elseif mode == "FORWARD_FUNCTION">
${element.name} = self.${element.name}(${tc.join(element.inputs, ", ")})
</#if>
\ No newline at end of file
<#assign dim = element.dim?c>
<#if mode == "FORWARD_FUNCTION">
${element.name} = mx.symbol.expand_dims(data = ${element.inputs[0]}, axis=${dim})
<#elseif mode == "PYTHON_INLINE">
${element.name} = mx.symbol.expand_dims(data = ${element.inputs[0]}, axis=${dim})
<#if mode == "ARCHITECTURE_DEFINITION">
self.${element.name} = ExpandDims(dim=${dim})
<#include "OutputShape.ftl">
<#elseif mode == "FORWARD_FUNCTION">
${element.name} = self.${element.name}(${element.inputs[0]})
</#if>
\ No newline at end of file
<#assign axis = element.axis?c>
<#if mode == "FORWARD_FUNCTION">
${element.name} = mx.symbol.sum(data = ${element.inputs[0]}, axis=${axis})
<#elseif mode == "PYTHON_INLINE">
${element.name} = mx.symbol.sum(data = ${element.inputs[0]}, axis=${axis})
<#if mode == "ARCHITECTURE_DEFINITION">
self.${element.name} = ReduceSum(axis=${axis})
<#include "OutputShape.ftl">
<#elseif mode == "FORWARD_FUNCTION">
${element.name} = self.${element.name}(${element.inputs[0]})
</#if>
\ No newline at end of file
<#assign repeats = element.repeats?c>
<#assign axis = element.axis?c>
<#if mode == "FORWARD_FUNCTION">
${element.name} = mx.symbol.repeat(repeats=${repeats}, axis=${axis})
<#elseif mode == "PYTHON_INLINE">
${element.name} = mx.symbol.repeat(repeats=${repeats}, axis=${axis})
<#assign repeats = element.repeats?c>
<#if mode == "ARCHITECTURE_DEFINITION">
self.${element.name} = Repeat(repeats=${repeats}, axis=${axis})
<#include "OutputShape.ftl">
<#elseif mode == "FORWARD_FUNCTION">
${element.name} = self.${element.name}(${element.inputs[0]})
</#if>
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment