diff --git a/src/main/resources/templates/gluon/elements/OneHot.ftl b/src/main/resources/templates/gluon/elements/OneHot.ftl index 369bbeaa674e24a9984cdad0511b396144408bf6..66a2fce5d9680a4a73a92ae978e870f90d248ebd 100644 --- a/src/main/resources/templates/gluon/elements/OneHot.ftl +++ b/src/main/resources/templates/gluon/elements/OneHot.ftl @@ -1,13 +1,13 @@ <#assign input = element.inputs[0]> <#assign size = element.size> <#if mode == "ARCHITECTURE_DEFINITION"> - self.${element.name} = OneHot(size=${size}) - <#include "OutputShape.ftl"> + self.${element.name} = OneHot(size=${element.element.outputTypes[0].dimensions[0]}) + <#include "OutputShape.ftl"> <#elseif mode == "FORWARD_FUNCTION"> ${element.name} = self.${element.name}(${input}) <#elseif mode == "PYTHON_INLINE"> - ${element.name} = nd.one_hot(indices=${input}, depth=${size}) + ${element.name} = nd.one_hot(indices=${input}, depth=${element.element.outputTypes[0].dimensions[0]}) <#elseif mode == "CPP_INLINE"> - vector<float> ${element.name}(${size}, 0); + vector<float> ${element.name}(${element.element.outputTypes[0].dimensions[0]}, 0); ${element.name}[${input}[0]] = 1; </#if> \ No newline at end of file