Skip to content
Snippets Groups Projects
Commit 904c654d authored by Sebastian Nickels's avatar Sebastian Nickels
Browse files

Changed templates

parent a132a0fa
No related branches found
No related tags found
1 merge request!23Added Unroll-related features and layers
Pipeline #180571 failed
......@@ -9,35 +9,35 @@
#include <CNNBufferFile.h>
<#list tc.architecture.streams as stream>
<#if stream.isTrainable()>
class ${tc.fileNameWithoutEnding}_${stream?index}{
<#list tc.architecture.networkInstructions as networkInstruction>
<#if networkInstruction.body.isTrainable()>
class ${tc.fileNameWithoutEnding}_${networkInstruction?index}{
public:
const std::string json_file = "model/${tc.componentName}/model_${stream?index}_newest-symbol.json";
const std::string param_file = "model/${tc.componentName}/model_${stream?index}_newest-0000.params";
const std::string json_file = "model/${tc.componentName}/model_${networkInstruction?index}_newest-symbol.json";
const std::string param_file = "model/${tc.componentName}/model_${networkInstruction?index}_newest-0000.params";
const std::vector<std::string> input_keys = {
<#if tc.getStreamInputNames(stream)?size == 1>
<#if tc.getStreamInputNames(networkInstruction.body)?size == 1>
"data"
<#else>
<#list tc.getStreamInputNames(stream) as variable>"data${variable?index}"<#sep>, </#list>
<#list tc.getStreamInputNames(networkInstruction.body) as variable>"data${variable?index}"<#sep>, </#list>
</#if>
};
const std::vector<std::vector<mx_uint>> input_shapes = {<#list tc.getStreamInputDimensions(stream) as dimensions>{${tc.join(dimensions, ", ")}}<#sep>, </#list>};
const std::vector<std::vector<mx_uint>> input_shapes = {<#list tc.getStreamInputDimensions(networkInstruction.body) as dimensions>{${tc.join(dimensions, ", ")}}<#sep>, </#list>};
const bool use_gpu = false;
PredictorHandle handle;
explicit ${tc.fileNameWithoutEnding}_${stream?index}(){
explicit ${tc.fileNameWithoutEnding}_${networkInstruction?index}(){
init(json_file, param_file, input_keys, input_shapes, use_gpu);
}
~${tc.fileNameWithoutEnding}_${stream?index}(){
~${tc.fileNameWithoutEnding}_${networkInstruction?index}(){
if(handle) MXPredFree(handle);
}
void predict(${tc.join(tc.getStreamInputNames(stream), ", ", "const std::vector<float> &in_", "")},
${tc.join(tc.getStreamOutputNames(stream), ", ", "std::vector<float> &out_", "")}){
<#list tc.getStreamInputNames(stream) as variable>
void predict(${tc.join(tc.getStreamInputNames(networkInstruction.body), ", ", "const std::vector<float> &in_", "")},
${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ", "std::vector<float> &out_", "")}){
<#list tc.getStreamInputNames(networkInstruction.body) as variable>
MXPredSetInput(handle, input_keys[${variable?index}].c_str(), in_${variable}.data(), static_cast<mx_uint>(in_${variable}.size()));
</#list>
......@@ -48,7 +48,7 @@ public:
mx_uint shape_len;
size_t size;
<#list tc.getStreamOutputNames(stream) as variable>
<#list tc.getStreamOutputNames(networkInstruction.body) as variable>
output_index = ${variable?index?c};
MXPredGetOutputShape(handle, output_index, &shape, &shape_len);
size = 1;
......
......@@ -3,18 +3,10 @@
<#if element.padding??>
self.${element.name}padding = Padding(padding=(${tc.join(element.padding, ",")}))
</#if>
<#if element.partOfUnroll>
self.${element.name} = gluon.nn.Conv2D(channels=${element.channels?c},
kernel_size=(${tc.join(element.kernel, ",")}),
strides=(${tc.join(element.stride, ",")}),
use_bias=${element.noBias?string("False", "True")},
params=Net_${element.unrollIndex + tc.architecture.streams?size}().${element.name}.collect_params())
<#else>
self.${element.name} = gluon.nn.Conv2D(channels=${element.channels?c},
kernel_size=(${tc.join(element.kernel, ",")}),
strides=(${tc.join(element.stride, ",")}),
use_bias=${element.noBias?string("False", "True")})
</#if>
<#include "OutputShape.ftl">
<#elseif mode == "FORWARD_FUNCTION">
<#if element.padding??>
......
<#assign input = element.inputs[0]>
<#if mode == "ARCHITECTURE_DEFINITION">
<#if element.partOfUnroll>
self.${element.name} = gluon.nn.Embedding(input_dim=${element.inputDim?c}, output_dim=${element.outputDim?c},
params=Net_${element.unrollIndex + tc.architecture.streams?size}().${element.name}.collect_params())
<#else>
self.${element.name} = gluon.nn.Embedding(input_dim=${element.inputDim?c}, output_dim=${element.outputDim?c})
</#if>
self.${element.name} = gluon.nn.Embedding(input_dim=${element.inputDim?c}, output_dim=${element.outputDim?c})
<#include "OutputShape.ftl">
<#elseif mode == "FORWARD_FUNCTION">
${element.name} = self.${element.name}(${input})
......
......@@ -3,13 +3,8 @@
<#assign use_bias = element.noBias?string("False","True")>
<#assign flatten = element.flatten?string("True","False")>
<#if mode == "ARCHITECTURE_DEFINITION">
<#if element.partOfUnroll>
self.${element.name} = gluon.nn.Dense(units=${units}, use_bias=${use_bias}, flatten=${flatten},
params=Net_${element.unrollIndex + tc.architecture.streams?size}().${element.name}.collect_params())
<#else>
self.${element.name} = gluon.nn.Dense(units=${units}, use_bias=${use_bias}, flatten=${flatten})
</#if>
<#include "OutputShape.ftl">
<#include "OutputShape.ftl">
<#elseif mode == "FORWARD_FUNCTION">
${element.name} = self.${element.name}(${input})
</#if>
\ No newline at end of file
......@@ -13,11 +13,11 @@
</#if>
</#list>
<#list tc.architecture.streams as stream>
<#if stream.isTrainable()>
_predictor_${stream?index}_.predict(${tc.join(tc.getStreamInputNames(stream), ", ")}, ${tc.join(tc.getStreamOutputNames(stream), ", ")});
<#list tc.architecture.networkInstructions as networkInstruction>
<#if networkInstruction.body.isTrainable()>
_predictor_${networkInstruction?index}_.predict(${tc.join(tc.getStreamInputNames(networkInstruction.body), ", ")}, ${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")});
<#else>
${tc.include(stream, "CPP_INLINE")}
${tc.include(networkInstruction.body, "CPP_INLINE")}
</#if>
</#list>
......
......@@ -7,10 +7,10 @@
</#if>
</#list>
<#list tc.architecture.streams as stream>
<#if stream.isTrainable()>
${tc.join(tc.getStreamOutputNames(stream), ", ")} = self._networks[${stream?index}](${tc.join(tc.getStreamInputNames(stream), ", ")})
<#list tc.architecture.networkInstructions as networkInstruction>
<#if networkInstruction.body.isTrainable()>
${tc.join(tc.getStreamOutputNames(networkInstruction.body), ", ")} = self._networks[${networkInstruction?index}](${tc.join(tc.getStreamInputNames(networkInstruction.body), ", ")})
<#else>
${tc.include(stream, "PYTHON_INLINE")}
${tc.include(networkInstruction.body, "PYTHON_INLINE")}
</#if>
</#list>
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment