Skip to content
Snippets Groups Projects
Commit a5e9d403 authored by Sebastian Nickels's avatar Sebastian Nickels
Browse files

Outputs now can be used as inputs

parent 94d3b4a8
Branches
No related tags found
1 merge request!21Added new layers
...@@ -37,4 +37,9 @@ public class CNNArch2GluonArchitectureSupportChecker extends ArchitectureSupport ...@@ -37,4 +37,9 @@ public class CNNArch2GluonArchitectureSupportChecker extends ArchitectureSupport
return true; return true;
} }
@Override
protected boolean checkOutputAsInput(ArchitectureSymbol architecture) {
return true;
}
} }
...@@ -173,7 +173,7 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController { ...@@ -173,7 +173,7 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
Map<String, List<String>> inputs = new LinkedHashMap<>(); Map<String, List<String>> inputs = new LinkedHashMap<>();
for (ArchitectureElementSymbol element : stream.getFirstAtomicElements()) { for (ArchitectureElementSymbol element : stream.getFirstAtomicElements()) {
if (element.isInput()) { if (element.isInput() || element.isOutput()) {
List<Integer> intDimensions = element.getOutputTypes().get(0).getDimensions(); List<Integer> intDimensions = element.getOutputTypes().get(0).getDimensions();
List<String> dimensions = new ArrayList<>(); List<String> dimensions = new ArrayList<>();
......
<#if element.inputs?size gte 1>
<#assign input = element.inputs[0]> <#assign input = element.inputs[0]>
<#if mode == "FORWARD_FUNCTION"> <#if mode == "FORWARD_FUNCTION">
${element.name} = ${input} ${element.name} = ${input}
...@@ -6,3 +7,4 @@ ...@@ -6,3 +7,4 @@
<#elseif mode == "CPP_INLINE"> <#elseif mode == "CPP_INLINE">
${element.name} = ${input}; ${element.name} = ${input};
</#if> </#if>
</#if>
\ No newline at end of file
<#list tc.getLayerVariableMembers("batch_size")?keys as member> <#list tc.getLayerVariableMembers("batch_size")?keys as member>
${member} = mx.nd.zeroes((${tc.join(tc.getLayerVariableMembers("batch_size")[member], ", ")},), ctx=mx_context) ${member} = mx.nd.zeroes((${tc.join(tc.getLayerVariableMembers("batch_size")[member], ", ")},), ctx=mx_context)
</#list> </#list>
<#list tc.architecture.outputs as output>
${tc.getName(output)} = mx.nd.zeroes(((${tc.join(output.ioDeclaration.type.dimensions, ", ")},), ctx=mx_context)
</#list>
<#list tc.architecture.streams as stream> <#list tc.architecture.streams as stream>
<#if stream.isTrainable()> <#if stream.isTrainable()>
......
...@@ -136,6 +136,7 @@ class CNNSupervisedTrainer_Alexnet: ...@@ -136,6 +136,7 @@ class CNNSupervisedTrainer_Alexnet:
predictions_label = batch.label[0].as_in_context(mx_context) predictions_label = batch.label[0].as_in_context(mx_context)
with autograd.record(): with autograd.record():
predictions_ = mx.nd.zeroes(((10,), ctx=mx_context)
predictions_ = self._networks[0](data_) predictions_ = self._networks[0](data_)
...@@ -172,6 +173,7 @@ class CNNSupervisedTrainer_Alexnet: ...@@ -172,6 +173,7 @@ class CNNSupervisedTrainer_Alexnet:
] ]
if True: if True:
predictions_ = mx.nd.zeroes(((10,), ctx=mx_context)
predictions_ = self._networks[0](data_) predictions_ = self._networks[0](data_)
...@@ -192,6 +194,7 @@ class CNNSupervisedTrainer_Alexnet: ...@@ -192,6 +194,7 @@ class CNNSupervisedTrainer_Alexnet:
] ]
if True: if True:
predictions_ = mx.nd.zeroes(((10,), ctx=mx_context)
predictions_ = self._networks[0](data_) predictions_ = self._networks[0](data_)
......
...@@ -136,6 +136,7 @@ class CNNSupervisedTrainer_CifarClassifierNetwork: ...@@ -136,6 +136,7 @@ class CNNSupervisedTrainer_CifarClassifierNetwork:
softmax_label = batch.label[0].as_in_context(mx_context) softmax_label = batch.label[0].as_in_context(mx_context)
with autograd.record(): with autograd.record():
softmax_ = mx.nd.zeroes(((10,), ctx=mx_context)
softmax_ = self._networks[0](data_) softmax_ = self._networks[0](data_)
...@@ -172,6 +173,7 @@ class CNNSupervisedTrainer_CifarClassifierNetwork: ...@@ -172,6 +173,7 @@ class CNNSupervisedTrainer_CifarClassifierNetwork:
] ]
if True: if True:
softmax_ = mx.nd.zeroes(((10,), ctx=mx_context)
softmax_ = self._networks[0](data_) softmax_ = self._networks[0](data_)
...@@ -192,6 +194,7 @@ class CNNSupervisedTrainer_CifarClassifierNetwork: ...@@ -192,6 +194,7 @@ class CNNSupervisedTrainer_CifarClassifierNetwork:
] ]
if True: if True:
softmax_ = mx.nd.zeroes(((10,), ctx=mx_context)
softmax_ = self._networks[0](data_) softmax_ = self._networks[0](data_)
......
...@@ -136,6 +136,7 @@ class CNNSupervisedTrainer_VGG16: ...@@ -136,6 +136,7 @@ class CNNSupervisedTrainer_VGG16:
predictions_label = batch.label[0].as_in_context(mx_context) predictions_label = batch.label[0].as_in_context(mx_context)
with autograd.record(): with autograd.record():
predictions_ = mx.nd.zeroes(((1000,), ctx=mx_context)
predictions_ = self._networks[0](data_) predictions_ = self._networks[0](data_)
...@@ -172,6 +173,7 @@ class CNNSupervisedTrainer_VGG16: ...@@ -172,6 +173,7 @@ class CNNSupervisedTrainer_VGG16:
] ]
if True: if True:
predictions_ = mx.nd.zeroes(((1000,), ctx=mx_context)
predictions_ = self._networks[0](data_) predictions_ = self._networks[0](data_)
...@@ -192,6 +194,7 @@ class CNNSupervisedTrainer_VGG16: ...@@ -192,6 +194,7 @@ class CNNSupervisedTrainer_VGG16:
] ]
if True: if True:
predictions_ = mx.nd.zeroes(((1000,), ctx=mx_context)
predictions_ = self._networks[0](data_) predictions_ = self._networks[0](data_)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment