Skip to content
Snippets Groups Projects
Commit 12b7f4ef authored by Christian Fuß's avatar Christian Fuß
Browse files

fixed naming of some elements in unrolls

parent 97212579
No related branches found
No related tags found
1 merge request!23Added Unroll-related features and layers
Pipeline #179274 failed
......@@ -48,7 +48,7 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
getTemplateConfiguration().processTemplate(ftlContext, templatePath, writer);
}
public void include(VariableSymbol element, boolean partOfUnroll, Writer writer, NetDefinitionMode netDefinitionMode){
public void include(VariableSymbol element, boolean partOfUnroll, int unrollIndex, Writer writer, NetDefinitionMode netDefinitionMode){
ArchitectureElementData previousElement = getCurrentElement();
setCurrentElement(element);
......@@ -66,13 +66,13 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
}
}
else {
include(element.getResolvedThis().get(), partOfUnroll, writer, netDefinitionMode);
include(element.getResolvedThis().get(), partOfUnroll, unrollIndex, writer, netDefinitionMode);
}
setCurrentElement(previousElement);
}
public void include(ConstantSymbol constant, boolean partOfUnroll, Writer writer, NetDefinitionMode netDefinitionMode) {
public void include(ConstantSymbol constant, boolean partOfUnroll, int unrollIndex, Writer writer, NetDefinitionMode netDefinitionMode) {
ArchitectureElementData previousElement = getCurrentElement();
setCurrentElement(constant);
......@@ -80,85 +80,88 @@ public class CNNArch2GluonTemplateController extends CNNArchTemplateController {
include(TEMPLATE_ELEMENTS_DIR_PATH, "Const", writer, netDefinitionMode);
}
else {
include(constant.getResolvedThis().get(), partOfUnroll, writer, netDefinitionMode);
include(constant.getResolvedThis().get(), partOfUnroll, unrollIndex, writer, netDefinitionMode);
}
setCurrentElement(previousElement);
}
public void include(LayerSymbol layer, boolean partOfUnroll, Writer writer, NetDefinitionMode netDefinitionMode){
public void include(LayerSymbol layer, boolean partOfUnroll, int unrollIndex, Writer writer, NetDefinitionMode netDefinitionMode){
ArchitectureElementData previousElement = getCurrentElement();
setCurrentElement(layer);
getCurrentElement().setPartOfUnroll(partOfUnroll);
getCurrentElement().setUnrollIndex(unrollIndex);
if (layer.isAtomic()){
String templateName = layer.getDeclaration().getName();
include(TEMPLATE_ELEMENTS_DIR_PATH, templateName, writer, netDefinitionMode);
}
else {
include(layer.getResolvedThis().get(), partOfUnroll, writer, netDefinitionMode);
include(layer.getResolvedThis().get(), partOfUnroll, unrollIndex, writer, netDefinitionMode);
}
setCurrentElement(previousElement);
}
public void include(UnrollSymbol unrollElement, boolean partOfUnroll, Writer writer, NetDefinitionMode netDefinitionMode){
include(unrollElement.getBody(), partOfUnroll, writer, netDefinitionMode);
public void include(UnrollSymbol unrollElement, boolean partOfUnroll, int unrollIndex, Writer writer, NetDefinitionMode netDefinitionMode){
include(unrollElement.getBody(), partOfUnroll, unrollIndex, writer, netDefinitionMode);
String templateName = unrollElement.getDeclaration().getName();
include(TEMPLATE_ELEMENTS_DIR_PATH, templateName, writer, netDefinitionMode);
}
public void include(CompositeElementSymbol compositeElement, boolean partOfUnroll, Writer writer, NetDefinitionMode netDefinitionMode){
public void include(CompositeElementSymbol compositeElement, boolean partOfUnroll, int unrollIndex, Writer writer, NetDefinitionMode netDefinitionMode){
ArchitectureElementData previousElement = getCurrentElement();
setCurrentElement(compositeElement);
for (ArchitectureElementSymbol element : compositeElement.getElements()){
include(element, partOfUnroll, writer, netDefinitionMode);
include(element, partOfUnroll, unrollIndex, writer, netDefinitionMode);
}
setCurrentElement(previousElement);
}
public void include(ArchitectureElementSymbol architectureElement, boolean partOfUnroll, Writer writer, NetDefinitionMode netDefinitionMode){
public void include(ArchitectureElementSymbol architectureElement, boolean partOfUnroll, int unrollIndex, Writer writer, NetDefinitionMode netDefinitionMode){
if (architectureElement instanceof CompositeElementSymbol){
include((CompositeElementSymbol) architectureElement, partOfUnroll, writer, netDefinitionMode);
include((CompositeElementSymbol) architectureElement, partOfUnroll, unrollIndex, writer, netDefinitionMode);
}
else if (architectureElement instanceof LayerSymbol){
include((LayerSymbol) architectureElement, partOfUnroll, writer, netDefinitionMode);
include((LayerSymbol) architectureElement, partOfUnroll, unrollIndex, writer, netDefinitionMode);
}
else if (architectureElement instanceof ConstantSymbol) {
include((ConstantSymbol) architectureElement, partOfUnroll, writer, netDefinitionMode);
include((ConstantSymbol) architectureElement, partOfUnroll, unrollIndex, writer, netDefinitionMode);
}
else {
include((VariableSymbol) architectureElement, partOfUnroll, writer, netDefinitionMode);
include((VariableSymbol) architectureElement, partOfUnroll, unrollIndex, writer, netDefinitionMode);
}
}
public void include(ArchitectureElementSymbol architectureElementSymbol, String netDefinitionMode) {
include(architectureElementSymbol, false, NetDefinitionMode.fromString(netDefinitionMode));
include(architectureElementSymbol, false, -1, NetDefinitionMode.fromString(netDefinitionMode));
}
public void include(ArchitectureElementSymbol architectureElementSymbol, boolean partOfUnroll, String netDefinitionMode) {
include(architectureElementSymbol, partOfUnroll, NetDefinitionMode.fromString(netDefinitionMode));
public void include(ArchitectureElementSymbol architectureElementSymbol, boolean partOfUnroll, int unrollIndex, String netDefinitionMode) {
int layerIndex = -1;
include(architectureElementSymbol, partOfUnroll, unrollIndex, NetDefinitionMode.fromString(netDefinitionMode));
}
public void include(UnrollSymbol unrollSymbol, boolean partOfUnroll, String netDefinitionMode) {
include(unrollSymbol, partOfUnroll, NetDefinitionMode.fromString(netDefinitionMode));
public void include(UnrollSymbol unrollSymbol, boolean partOfUnroll, int unrollIndex, String netDefinitionMode) {
include(unrollSymbol, partOfUnroll, unrollIndex, NetDefinitionMode.fromString(netDefinitionMode));
}
public void include(ArchitectureElementSymbol architectureElement, boolean partOfUnroll, NetDefinitionMode netDefinitionMode){
public void include(ArchitectureElementSymbol architectureElement, boolean partOfUnroll, int unrollIndex, NetDefinitionMode netDefinitionMode){
if (getWriter() == null){
throw new IllegalStateException("missing writer");
}
include(architectureElement, partOfUnroll, getWriter(), netDefinitionMode);
include(architectureElement, partOfUnroll, unrollIndex, getWriter(), netDefinitionMode);
}
public void include(UnrollSymbol unroll, boolean partOfUnroll, NetDefinitionMode netDefinitionMode){
public void include(UnrollSymbol unroll, boolean partOfUnroll, int unrollIndex, NetDefinitionMode netDefinitionMode){
if (getWriter() == null){
throw new IllegalStateException("missing writer");
}
include(unroll, partOfUnroll, getWriter(), netDefinitionMode);
include(unroll, partOfUnroll, unrollIndex, getWriter(), netDefinitionMode);
}
public Set<String> getStreamInputNames(SerialCompositeElementSymbol stream) {
......
......@@ -97,15 +97,20 @@ ${tc.include(stream, "FORWARD_FUNCTION")}
<#list tc.architecture.unrolls as unroll>
<#list unroll.getBodiesForAllTimesteps() as body>
<#if body.isTrainable()>
<#if body?index == 0>
<#assign partOfUnroll = false>
<#else>
<#assign partOfUnroll = true>
</#if>
class Net_${tc.architecture.streams?size + body?index}(gluon.HybridBlock):
def __init__(self, data_mean=None, data_std=None, **kwargs):
super(Net_${tc.architecture.streams?size + body?index}, self).__init__(**kwargs)
self.last_layers = {}
with self.name_scope():
${tc.include(body, false, "ARCHITECTURE_DEFINITION")}
${tc.include(body, partOfUnroll, unroll?index, "ARCHITECTURE_DEFINITION")}
def hybrid_forward(self, F, ${tc.join(tc.getStreamInputNames(body), ", ")}):
${tc.include(body, true, "FORWARD_FUNCTION")}
${tc.include(body, partOfUnroll, unroll?index, "FORWARD_FUNCTION")}
return ${tc.join(tc.getStreamOutputNames(body), ", ")}
</#if>
</#list>
......
......@@ -4,15 +4,12 @@
<#assign flatten = element.flatten?string("True","False")>
<#if mode == "ARCHITECTURE_DEFINITION">
<#if element.partOfUnroll>
${element.name} = Net_1().${element.name}(${input})
<#assign unrollIndex = element.unrollIndex>
self.${element.name} = gluon.nn.Dense(units=${units}, use_bias=${use_bias}, flatten=${flatten}, params=Net_${unrollIndex + tc.architecture.streams?size}().${element.name}.collect_params())
<#else>
self.${element.name} = gluon.nn.Dense(units=${units}, use_bias=${use_bias}, flatten=${flatten})
</#if>
<#include "OutputShape.ftl">
<#elseif mode == "FORWARD_FUNCTION">
<#if element.partOfUnroll>
${element.name} = Net_1().${element.name}(${input})
<#else>
${element.name} = self.${element.name}(${input})
</#if>
</#if>
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment