Commit 12b7f4ef authored by Christian Fuß's avatar Christian Fuß
Browse files

fixed naming of some elements in unrolls

parent 97212579
Pipeline #179274 failed with stages
in 2 minutes and 38 seconds
......@@ -48,7 +48,7 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
getTemplateConfiguration().processTemplate(ftlContext, templatePath, writer);
}
public void include(VariableSymbol element, boolean partOfUnroll, Writer writer, NetDefinitionMode netDefinitionMode){
public void include(VariableSymbol element, boolean partOfUnroll, int unrollIndex, Writer writer, NetDefinitionMode netDefinitionMode){
ArchitectureElementData previousElement = getCurrentElement();
setCurrentElement(element);
......@@ -66,13 +66,13 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
}
}
else {
include(element.getResolvedThis().get(), partOfUnroll, writer, netDefinitionMode);
include(element.getResolvedThis().get(), partOfUnroll, unrollIndex, writer, netDefinitionMode);
}
setCurrentElement(previousElement);
}
public void include(ConstantSymbol constant, boolean partOfUnroll, Writer writer, NetDefinitionMode netDefinitionMode) {
public void include(ConstantSymbol constant, boolean partOfUnroll, int unrollIndex, Writer writer, NetDefinitionMode netDefinitionMode) {
ArchitectureElementData previousElement = getCurrentElement();
setCurrentElement(constant);
......@@ -80,85 +80,88 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
include(TEMPLATE_ELEMENTS_DIR_PATH, "Const", writer, netDefinitionMode);
}
else {
include(constant.getResolvedThis().get(), partOfUnroll, writer, netDefinitionMode);
include(constant.getResolvedThis().get(), partOfUnroll, unrollIndex, writer, netDefinitionMode);
}
setCurrentElement(previousElement);
}
public void include(LayerSymbol layer, boolean partOfUnroll, Writer writer, NetDefinitionMode netDefinitionMode){
public void include(LayerSymbol layer, boolean partOfUnroll, int unrollIndex, Writer writer, NetDefinitionMode netDefinitionMode){
ArchitectureElementData previousElement = getCurrentElement();
setCurrentElement(layer);
getCurrentElement().setPartOfUnroll(partOfUnroll);
getCurrentElement().setUnrollIndex(unrollIndex);
if (layer.isAtomic()){
String templateName = layer.getDeclaration().getName();
include(TEMPLATE_ELEMENTS_DIR_PATH, templateName, writer, netDefinitionMode);
}
else {
include(layer.getResolvedThis().get(), partOfUnroll, writer, netDefinitionMode);
include(layer.getResolvedThis().get(), partOfUnroll, unrollIndex, writer, netDefinitionMode);
}
setCurrentElement(previousElement);
}
public void include(UnrollSymbol unrollElement, boolean partOfUnroll, Writer writer, NetDefinitionMode netDefinitionMode){
include(unrollElement.getBody(), partOfUnroll, writer, netDefinitionMode);
public void include(UnrollSymbol unrollElement, boolean partOfUnroll, int unrollIndex, Writer writer, NetDefinitionMode netDefinitionMode){
include(unrollElement.getBody(), partOfUnroll, unrollIndex, writer, netDefinitionMode);
String templateName = unrollElement.getDeclaration().getName();
include(TEMPLATE_ELEMENTS_DIR_PATH, templateName, writer, netDefinitionMode);
}
public void include(CompositeElementSymbol compositeElement, boolean partOfUnroll, Writer writer, NetDefinitionMode netDefinitionMode){
public void include(CompositeElementSymbol compositeElement, boolean partOfUnroll, int unrollIndex, Writer writer, NetDefinitionMode netDefinitionMode){
ArchitectureElementData previousElement = getCurrentElement();
setCurrentElement(compositeElement);
for (ArchitectureElementSymbol element : compositeElement.getElements()){
include(element, partOfUnroll, writer, netDefinitionMode);
include(element, partOfUnroll, unrollIndex, writer, netDefinitionMode);
}
setCurrentElement(previousElement);
}
public void include(ArchitectureElementSymbol architectureElement, boolean partOfUnroll, Writer writer, NetDefinitionMode netDefinitionMode){
public void include(ArchitectureElementSymbol architectureElement, boolean partOfUnroll, int unrollIndex, Writer writer, NetDefinitionMode netDefinitionMode){
if (architectureElement instanceof CompositeElementSymbol){
include((CompositeElementSymbol) architectureElement, partOfUnroll, writer, netDefinitionMode);
include((CompositeElementSymbol) architectureElement, partOfUnroll, unrollIndex, writer, netDefinitionMode);
}
else if (architectureElement instanceof LayerSymbol){
include((LayerSymbol) architectureElement, partOfUnroll, writer, netDefinitionMode);
include((LayerSymbol) architectureElement, partOfUnroll, unrollIndex, writer, netDefinitionMode);
}
else if (architectureElement instanceof ConstantSymbol) {
include((ConstantSymbol) architectureElement, partOfUnroll, writer, netDefinitionMode);
include((ConstantSymbol) architectureElement, partOfUnroll, unrollIndex, writer, netDefinitionMode);
}
else {
include((VariableSymbol) architectureElement, partOfUnroll, writer, netDefinitionMode);
include((VariableSymbol) architectureElement, partOfUnroll, unrollIndex, writer, netDefinitionMode);
}
}
public void include(ArchitectureElementSymbol architectureElementSymbol, String netDefinitionMode) {
include(architectureElementSymbol, false, NetDefinitionMode.fromString(netDefinitionMode));
include(architectureElementSymbol, false, -1, NetDefinitionMode.fromString(netDefinitionMode));
}
public void include(ArchitectureElementSymbol architectureElementSymbol, boolean partOfUnroll, String netDefinitionMode) {
include(architectureElementSymbol, partOfUnroll, NetDefinitionMode.fromString(netDefinitionMode));
public void include(ArchitectureElementSymbol architectureElementSymbol, boolean partOfUnroll, int unrollIndex, String netDefinitionMode) {
int layerIndex = -1;
include(architectureElementSymbol, partOfUnroll, unrollIndex, NetDefinitionMode.fromString(netDefinitionMode));
}
public void include(UnrollSymbol unrollSymbol, boolean partOfUnroll, String netDefinitionMode) {
include(unrollSymbol, partOfUnroll, NetDefinitionMode.fromString(netDefinitionMode));
public void include(UnrollSymbol unrollSymbol, boolean partOfUnroll, int unrollIndex, String netDefinitionMode) {
include(unrollSymbol, partOfUnroll, unrollIndex, NetDefinitionMode.fromString(netDefinitionMode));
}
public void include(ArchitectureElementSymbol architectureElement, boolean partOfUnroll, NetDefinitionMode netDefinitionMode){
public void include(ArchitectureElementSymbol architectureElement, boolean partOfUnroll, int unrollIndex, NetDefinitionMode netDefinitionMode){
if (getWriter() == null){
throw new IllegalStateException("missing writer");
}
include(architectureElement, partOfUnroll, getWriter(), netDefinitionMode);
include(architectureElement, partOfUnroll, unrollIndex, getWriter(), netDefinitionMode);
}
public void include(UnrollSymbol unroll, boolean partOfUnroll, NetDefinitionMode netDefinitionMode){
public void include(UnrollSymbol unroll, boolean partOfUnroll, int unrollIndex, NetDefinitionMode netDefinitionMode){
if (getWriter() == null){
throw new IllegalStateException("missing writer");
}
include(unroll, partOfUnroll, getWriter(), netDefinitionMode);
include(unroll, partOfUnroll, unrollIndex, getWriter(), netDefinitionMode);
}
public Set<String> getStreamInputNames(SerialCompositeElementSymbol stream) {
......
......@@ -97,15 +97,20 @@ ${tc.include(stream, "FORWARD_FUNCTION")}
<#list tc.architecture.unrolls as unroll>
<#list unroll.getBodiesForAllTimesteps() as body>
<#if body.isTrainable()>
<#if body?index == 0>
<#assign partOfUnroll = false>
<#else>
<#assign partOfUnroll = true>
</#if>
class Net_${tc.architecture.streams?size + body?index}(gluon.HybridBlock):
def __init__(self, data_mean=None, data_std=None, **kwargs):
super(Net_${tc.architecture.streams?size + body?index}, self).__init__(**kwargs)
self.last_layers = {}
with self.name_scope():
${tc.include(body, false, "ARCHITECTURE_DEFINITION")}
${tc.include(body, partOfUnroll, unroll?index, "ARCHITECTURE_DEFINITION")}
def hybrid_forward(self, F, ${tc.join(tc.getStreamInputNames(body), ", ")}):
${tc.include(body, true, "FORWARD_FUNCTION")}
${tc.include(body, partOfUnroll, unroll?index, "FORWARD_FUNCTION")}
return ${tc.join(tc.getStreamOutputNames(body), ", ")}
</#if>
</#list>
......
......@@ -4,15 +4,12 @@
<#assign flatten = element.flatten?string("True","False")>
<#if mode == "ARCHITECTURE_DEFINITION">
<#if element.partOfUnroll>
${element.name} = Net_1().${element.name}(${input})
<#assign unrollIndex = element.unrollIndex>
self.${element.name} = gluon.nn.Dense(units=${units}, use_bias=${use_bias}, flatten=${flatten}, params=Net_${unrollIndex + tc.architecture.streams?size}().${element.name}.collect_params())
<#else>
self.${element.name} = gluon.nn.Dense(units=${units}, use_bias=${use_bias}, flatten=${flatten})
</#if>
<#include "OutputShape.ftl">
<#elseif mode == "FORWARD_FUNCTION">
<#if element.partOfUnroll>
${element.name} = Net_1().${element.name}(${input})
<#else>
${element.name} = self.${element.name}(${input})
</#if>
</#if>
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment