Adapted the Convolution layer template to use the correct dim_in when the previous layer is the data layer

parent 28de4da4
......@@ -16,10 +16,6 @@
<#else>
<#assign kernelParameter = "kernel=[${kernelHeight},${kernelWidth}]">
</#if>
<#if input = tc.architectureInputs[0]> <#-- TODO: CHECK COMPARISON -->
${element.name} = brew.conv(model, ${input}, '${element.name}', dim_in=1, dim_out=${element.channels?c}, ${kernelParameter}, ${strideParameter})
<#else>
${element.name} = brew.conv(model, ${input}, '${element.name}', dim_in=${element.element.inputTypes[0].channels?c}, dim_out=${element.channels?c}, ${kernelParameter}, ${strideParameter})
</#if>
<#-- TODO: check how to adapt CNNArchLang argument no_bias=${element.noBias?string("True","False")} -->
<#include "OutputShape.ftl">
\ No newline at end of file
......@@ -58,8 +58,7 @@ class CNNCreator_Alexnet:
data = data
# data, output shape: {[3,224,224]}
conv1_ = brew.conv(model, data, 'conv1_', dim_in=1, dim_out=96, kernel=11, stride=4)
conv1_ = brew.conv(model, data, 'conv1_', dim_in=3, dim_out=96, kernel=11, stride=4)
# conv1_, output shape: {[96,55,55]}
lrn1_ = mx.symbol.LRN(data=conv1_,
alpha=0.0001,
......
......@@ -58,8 +58,7 @@ class CNNCreator_CifarClassifierNetwork:
data = data
# data, output shape: {[3,32,32]}
conv2_1_ = brew.conv(model, data, 'conv2_1_', dim_in=1, dim_out=8, kernel=3, stride=1)
conv2_1_ = brew.conv(model, data, 'conv2_1_', dim_in=3, dim_out=8, kernel=3, stride=1)
# conv2_1_, output shape: {[8,32,32]}
batchnorm2_1_ = mx.symbol.BatchNorm(data=conv2_1_,
fix_gamma=True,
......
......@@ -58,8 +58,7 @@ class CNNCreator_VGG16:
data = data
# data, output shape: {[3,224,224]}
conv1_ = brew.conv(model, data, 'conv1_', dim_in=1, dim_out=64, kernel=3, stride=1)
conv1_ = brew.conv(model, data, 'conv1_', dim_in=3, dim_out=64, kernel=3, stride=1)
# conv1_, output shape: {[64,224,224]}
relu1_ = brew.relu(model, conv1_, conv1_)
conv2_ = brew.conv(model, relu1_, 'conv2_', dim_in=64, dim_out=64, kernel=3, stride=1)
......
architecture LeNet(img_height=28, img_width=28, img_channels=3, classes=10){
architecture LeNet(img_height=28, img_width=28, img_channels=1, classes=10){
def input Z(0:255)^{img_channels, img_height, img_width} image
def output Q(0:1)^{classes} predictions
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment