<#--
  Convert every architecture input into the vector<float> buffer handed to the
  predictors via CNNTranslator::translate. An input that carries an array
  access is read at that constant index of the port.
-->
<#list tc.architectureInputSymbols as input>
    vector<float> ${tc.getName(input)} = CNNTranslator::translate(${input.name}<#if input.arrayAccess.isPresent()>[${input.arrayAccess.get().intValue.get()?c}]</#if>);
</#list>

<#--
  Allocate the output buffers the predictors write into.
  More than one output symbol: the outputs are treated as one indexed array
  (named after the first symbol without its index) and every element gets a
  vector<float> sized to the product of the first symbol's dimensions.
  Otherwise each output gets its own vector<float> sized to the product of its
  own dimensions. vector<float>(n) value-initialises all n elements to 0.
-->
<#if tc.architectureOutputSymbols?size gt 1>
<#assign outputName = tc.getNameWithoutIndex(tc.getName(tc.architectureOutputSymbols[0]))>
    vector<vector<float>> ${outputName}(${tc.architectureOutputSymbols?size}, vector<float>(${tc.join(tc.architectureOutputSymbols[0].ioDeclaration.type.dimensions, " * ")}));
<#else>
<#list tc.architectureOutputSymbols as output>
    vector<float> ${tc.getName(output)}(${tc.join(output.ioDeclaration.type.dimensions, " * ")});
</#list>
</#if>

<#--
  Declare a zero-initialised vector<float> for every layer variable member,
  sized to the product of the member's dimensions. The map is fetched once and
  reused for both key iteration and lookup.
-->
<#assign layerVariableMembers = tc.getLayerVariableMembers()>
<#list layerVariableMembers?keys as member>
    vector<float> ${member}(${tc.join(layerVariableMembers[member], " * ")});
</#list>

<#--
  Materialise each architecture constant as a one-element vector<float>.
  The explicit float cast keeps the brace-initialisation valid for integer
  values that are not exactly representable as float (which would otherwise
  be an ill-formed narrowing conversion).
-->
<#list tc.architecture.constants as constant>
    vector<float> ${tc.getName(constant)}{static_cast<float>(${constant.intValue?c})};
</#list>

<#--
  Emit the C++ that runs every network instruction's predictor.

  Unrolled instructions generate a beam search of width
  tc.getBeamSearchWidth(...) up to tc.getBeamSearchMaxLength(...) steps:
    * the candidate set is seeded with the step-"1" unroll input whose
      unindexed name equals tc.outputName, with score 1.0;
    * each step writes the last element of every candidate back into that
      input, runs the predictor, keeps the k largest output entries, and
      extends the candidate with each entry's index (as a one-element
      vector<float>), multiplying the candidate's score by the entry's value;
    * only the k best-scoring extended candidates survive the step;
    * afterwards, element i of the best sequence is copied into the matching
      unroll output for i = 1 .. maxLength-1.
  NOTE(review): multiplying raw output values as the score presumably assumes
  the predictor emits probabilities (e.g. softmax) - confirm.

  Any other instruction calls its predictor once with its stream inputs and
  outputs.
-->
<#list tc.architecture.networkInstructions as networkInstruction>
<#if networkInstruction.isUnroll()>
    {
        int k = ${tc.getBeamSearchWidth(networkInstruction)};
<#list tc.getUnrollInputNames(networkInstruction, "1") as inputName>
<#if tc.getNameWithoutIndex(inputName) == tc.outputName>
        vector<pair<vector<vector<float>>, double>> sequences{make_pair(vector<vector<float>>{${inputName}}, 1.0)};
</#if>
</#list>

        for (size_t i = 1; i < ${tc.getBeamSearchMaxLength(networkInstruction)}; ++i) {
            vector<pair<vector<vector<float>>, double>> allCandidates;

            for (const pair<vector<vector<float>>, double>& candidate : sequences) {
                const vector<vector<float>>& seq = candidate.first;
                double score = candidate.second;

<#list tc.getUnrollInputNames(networkInstruction, "i") as inputName>
<#if tc.getNameWithoutIndex(inputName) == tc.outputName>
                ${inputName} = seq.back();
</#if>
</#list>
                _predictor_${networkInstruction?index}_.predict(${tc.join(tc.getUnrollInputNames(networkInstruction, "i"), ", ")}, ${tc.join(tc.getUnrollOutputNames(networkInstruction, "i"), ", ")});
<#list tc.getUnrollOutputNames(networkInstruction, "i") as outputName>
<#if tc.getNameWithoutIndex(outputName) == tc.outputName>
                vector<float>& out = ${outputName};
</#if>
</#list>

                // Pair every output entry with its index, then keep the k best.
                vector<pair<int, float>> topk;
                for (size_t j = 0; j < out.size(); ++j) {
                    topk.emplace_back(j, out[j]);
                }

                sort(topk.begin(), topk.end(), [] (const pair<int, float>& p1, const pair<int, float>& p2) {
                    return p1.second > p2.second;
                });
                topk = vector<pair<int, float>>(topk.begin(), topk.begin() + std::min<int>(k, topk.size()));

                for (const pair<int, float>& entry : topk) {
                    vector<vector<float>> currentSeq = seq;
                    // Explicit cast: brace-initialising float from a non-constant int is a narrowing error.
                    currentSeq.push_back(vector<float>{static_cast<float>(entry.first)});
                    allCandidates.emplace_back(currentSeq, score * entry.second);
                }
            }

            sort(allCandidates.begin(), allCandidates.end(), [] (const pair<vector<vector<float>>, double>& p1, const pair<vector<vector<float>>, double>& p2) {
                return p1.second > p2.second;
            });
            sequences = vector<pair<vector<vector<float>>, double>>(allCandidates.begin(), allCandidates.begin() + std::min<int>(k, allCandidates.size()));
        }

        for (size_t i = 1; i < ${tc.getBeamSearchMaxLength(networkInstruction)}; ++i) {
<#list tc.getUnrollOutputNames(networkInstruction, "i") as outputName>
<#if tc.getNameWithoutIndex(outputName) == tc.outputName>
            ${outputName} = sequences[0].first[i];
</#if>
</#list>
        }
    }
<#else>
    _predictor_${networkInstruction?index}_.predict(${tc.join(tc.getStreamInputNames(networkInstruction.body, true), ", ")}, ${tc.join(tc.getStreamOutputNames(networkInstruction.body, true), ", ")});
</#if>

</#list>
<#--
  Translate each predictor output buffer back into its output port.
  Shapes of rank 1, 2 and 3 use CNNTranslator::translateToCol, translateToMat
  and translateToCube respectively; outputs whose domain is natural or whole
  numbers use the Int variants (translateToIntCol, ...). Outputs of any other
  rank emit nothing. An output with an array access is assigned at that
  constant index.
-->
<#list tc.architectureOutputSymbols as output>
<#assign shape = output.ioDeclaration.type.dimensions>
<#if shape?size gte 1 && shape?size lte 3>
<#assign domain = output.ioDeclaration.type.domain>
<#assign isIntegral = domain.isNaturalNumber() || domain.isWholeNumber()>
<#assign shapeKind = ["Col", "Mat", "Cube"][shape?size - 1]>
<#assign target = output.name>
<#if output.arrayAccess.isPresent()>
<#assign target = target + "[" + output.arrayAccess.get().intValue.get()?c + "]">
</#if>
    ${target} = CNNTranslator::translateTo<#if isIntegral>Int</#if>${shapeKind}(${tc.getNameAsArray(tc.getName(output))}, std::vector<size_t> {<#list shape as dim>${dim?c}<#sep>, </#list>});
</#if>
</#list>