Commit 6b545750 authored by Christian Fuß's avatar Christian Fuß

fixed naming for multiple layers of one kind in unrolls

parent 7dab14b2
Pipeline #179273 failed with stages
in 4 minutes and 35 seconds
......@@ -39,12 +39,14 @@ public class ArchitectureElementData {
private ArchitectureElementSymbol element;
private CNNArchTemplateController templateController;
private boolean partOfUnroll;
private int unrollIndex;
public ArchitectureElementData(String name, ArchitectureElementSymbol element, CNNArchTemplateController templateController) {
this.name = name;
this.element = element;
this.templateController = templateController;
this.partOfUnroll = partOfUnroll;
this.unrollIndex = unrollIndex;
}
public String getName() {
......@@ -79,6 +81,14 @@ public class ArchitectureElementData {
this.partOfUnroll= partOfUnroll;
}
public int getUnrollIndex(){
return unrollIndex;
}
public void setUnrollIndex(int unrollIndex){
this.unrollIndex = unrollIndex;
}
private LayerSymbol getLayerSymbol() {
if (getElement() instanceof VariableSymbol) {
return ((VariableSymbol) getElement()).getLayerVariableDeclaration().getLayer();
......
......@@ -31,6 +31,8 @@ public class LayerNameCreator {
private Map<ArchitectureElementSymbol, String> elementToName = new HashMap<>();
private Map<String, ArchitectureElementSymbol> nameToElement = new HashMap<>();
private ArrayList<String> currentUnrollElementNames = new ArrayList<>();
private int elementIndex = 0;
private boolean partOfUnroll = false;
private boolean inFirstUnrollTimestep = true;
......@@ -41,11 +43,13 @@ public class LayerNameCreator {
}
stage = 1;
for (UnrollSymbol unroll : architecture.getUnrolls()) {
currentUnrollElementNames = new ArrayList<>();
partOfUnroll = true;
for(int index = 0; index < unroll.getBodiesForAllTimesteps().size(); index++) {
if(index > 0){
inFirstUnrollTimestep = false;
}
elementIndex = 0;
stage = name(unroll.getBodiesForAllTimesteps().get(index), stage, new ArrayList<>());
}
}
......@@ -121,7 +125,8 @@ public class LayerNameCreator {
elementToName.put(architectureElement, name);
return endStage;
}else if(partOfUnroll && !inFirstUnrollTimestep){
elementToName.put(architectureElement, name);
elementToName.put(architectureElement, currentUnrollElementNames.get(elementIndex));
elementIndex++;
return endStage;
}
......@@ -142,6 +147,13 @@ public class LayerNameCreator {
if (!isLayerVariable) {
nameToElement.put(name, architectureElement);
}
if(inFirstUnrollTimestep){
currentUnrollElementNames.add(name);
}
if(partOfUnroll){
elementIndex++;
}
}
return endStage;
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment