Commit a5e9d403 authored by Sebastian N.'s avatar Sebastian N.
Browse files

Outputs now can be used as inputs

parent 94d3b4a8
......@@ -37,4 +37,9 @@ public class CNNArch2GluonArchitectureSupportChecker extends ArchitectureSupport
return true;
}
@Override
protected boolean checkOutputAsInput(ArchitectureSymbol architecture) {
return true;
}
}
......@@ -173,7 +173,7 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
Map<String, List<String>> inputs = new LinkedHashMap<>();
for (ArchitectureElementSymbol element : stream.getFirstAtomicElements()) {
if (element.isInput()) {
if (element.isInput() || element.isOutput()) {
List<Integer> intDimensions = element.getOutputTypes().get(0).getDimensions();
List<String> dimensions = new ArrayList<>();
......
<#if element.inputs?size gte 1>
<#assign input = element.inputs[0]>
<#if mode == "FORWARD_FUNCTION">
${element.name} = ${input}
......@@ -6,3 +7,4 @@
<#elseif mode == "CPP_INLINE">
${element.name} = ${input};
</#if>
</#if>
\ No newline at end of file
<#list tc.getLayerVariableMembers("batch_size")?keys as member>
${member} = mx.nd.zeroes((${tc.join(tc.getLayerVariableMembers("batch_size")[member], ", ")},), ctx=mx_context)
</#list>
<#list tc.architecture.outputs as output>
${tc.getName(output)} = mx.nd.zeroes(((${tc.join(output.ioDeclaration.type.dimensions, ", ")},), ctx=mx_context)
</#list>
<#list tc.architecture.streams as stream>
<#if stream.isTrainable()>
......
......@@ -136,6 +136,7 @@ class CNNSupervisedTrainer_Alexnet:
predictions_label = batch.label[0].as_in_context(mx_context)
with autograd.record():
predictions_ = mx.nd.zeroes(((10,), ctx=mx_context)
predictions_ = self._networks[0](data_)
......@@ -172,6 +173,7 @@ class CNNSupervisedTrainer_Alexnet:
]
if True:
predictions_ = mx.nd.zeroes(((10,), ctx=mx_context)
predictions_ = self._networks[0](data_)
......@@ -192,6 +194,7 @@ class CNNSupervisedTrainer_Alexnet:
]
if True:
predictions_ = mx.nd.zeroes(((10,), ctx=mx_context)
predictions_ = self._networks[0](data_)
......
......@@ -136,6 +136,7 @@ class CNNSupervisedTrainer_CifarClassifierNetwork:
softmax_label = batch.label[0].as_in_context(mx_context)
with autograd.record():
softmax_ = mx.nd.zeroes(((10,), ctx=mx_context)
softmax_ = self._networks[0](data_)
......@@ -172,6 +173,7 @@ class CNNSupervisedTrainer_CifarClassifierNetwork:
]
if True:
softmax_ = mx.nd.zeroes(((10,), ctx=mx_context)
softmax_ = self._networks[0](data_)
......@@ -192,6 +194,7 @@ class CNNSupervisedTrainer_CifarClassifierNetwork:
]
if True:
softmax_ = mx.nd.zeroes(((10,), ctx=mx_context)
softmax_ = self._networks[0](data_)
......
......@@ -136,6 +136,7 @@ class CNNSupervisedTrainer_VGG16:
predictions_label = batch.label[0].as_in_context(mx_context)
with autograd.record():
predictions_ = mx.nd.zeroes(((1000,), ctx=mx_context)
predictions_ = self._networks[0](data_)
......@@ -172,6 +173,7 @@ class CNNSupervisedTrainer_VGG16:
]
if True:
predictions_ = mx.nd.zeroes(((1000,), ctx=mx_context)
predictions_ = self._networks[0](data_)
......@@ -192,6 +194,7 @@ class CNNSupervisedTrainer_VGG16:
]
if True:
predictions_ = mx.nd.zeroes(((1000,), ctx=mx_context)
predictions_ = self._networks[0](data_)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment