Commit 5608cd10 authored by Carlos Alfredo Yeverino Rodriguez's avatar Carlos Alfredo Yeverino Rodriguez
Browse files

Corrected Input and Output layer templates. Minor esthetic change in OutputShape template

parent 61c0b334
...@@ -6,17 +6,16 @@ ...@@ -6,17 +6,16 @@
<#if heightIndex != 0><#assign indexList = indexList + [heightIndex]></#if> <#if heightIndex != 0><#assign indexList = indexList + [heightIndex]></#if>
<#if widthIndex != 0><#assign indexList = indexList + [widthIndex]></#if> <#if widthIndex != 0><#assign indexList = indexList + [widthIndex]></#if>
<#assign dimensions = element.element.outputTypes[0].dimensions> <#assign dimensions = element.element.outputTypes[0].dimensions>
#${element.name} = mx.sym.var("${element.name}", ${element.name}, label = AddInput(model, batch_size=100)
# shape=(0,${tc.join(dimensions, ",")}))
<#include "OutputShape.ftl"> <#include "OutputShape.ftl">
<#if heightIndex != channelIndex + 1 || widthIndex != heightIndex + 1> <#if heightIndex != channelIndex + 1 || widthIndex != heightIndex + 1>
#${element.name} = mx.symbol.transpose(data=${element.name}, ${element.name} = mx.symbol.transpose(data=${element.name},mx.sym.var <#-- TODO: check how to adapt CNNArchLang transpose case -->
# axes=(0,${tc.join(indexList, ",")})) axes=(0,${tc.join(indexList, ",")}))
</#if> </#if>
<#if indexList?size != 3> <#if indexList?size != 3>
#${element.name} = mx.symbol.reshape(data=${element.name}, ${element.name} = mx.symbol.reshape(data=${element.name}, <#-- TODO: check how to adapt CNNArchLang transpose case -->
# shape=(0,${element.element.outputTypes[0].channels?c},${element.element.outputTypes[0].height?c},${element.element.outputTypes[0].width?c})) shape=(0,${element.element.outputTypes[0].channels?c},${element.element.outputTypes[0].height?c},${element.element.outputTypes[0].width?c}))
</#if> </#if>
# ${element.name} = mx.symbol.broadcast_sub(${element.name}, _data_mean_) workspace.FeedBlob("${element.name}", ${element.name}, device_option=device_opts)
# ${element.name} = mx.symbol.broadcast_div(${element.name}, _data_std_) workspace.FeedBlob("label", label, device_option=device_opts)
<#assign input = element.inputs[0]>
<#if element.softmaxOutput> <#if element.softmaxOutput>
pred = brew.fc(model, ${element.inputs[0]}, 'pred', 500, 10) ${element.name} = brew.softmax(model, ${input}, '${element.name}')
${element.name} = brew.softmax(model, pred, '${element.name}')
<#elseif element.logisticRegressionOutput> <#elseif element.logisticRegressionOutput>
${element.name} = mx.symbol.LogisticRegressionOutput(data=${element.inputs[0]}, ${element.name} = mx.symbol.LogisticRegressionOutput(data=${element.inputs[0]},
name="${element.name}") name="${element.name}")
<#elseif element.linearRegressionOutput> <#elseif element.linearRegressionOutput>
${element.name} = mx.symbol.LinearRegressionOutput(data=${element.inputs[0]}, ${element.name} = mx.symbol.LinearRegressionOutput(data=${element.inputs[0]},
name="${element.name}") name="${element.name}")
</#if> </#if>
\ No newline at end of file
model.net.AddExternalOutput(${element.name})
return ${element.name}
\ No newline at end of file
# ${element.name}, output shape: {<#list element.element.outputTypes as type>[${tc.join(type.dimensions, ",")}]</#list>} # ${element.name}, output shape: {<#list element.element.outputTypes as type>[${tc.join(type.dimensions, ",")}]</#list>
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment