Commit a4ae949c authored by Christian Fuß's avatar Christian Fuß

small changes to naming of layers in unrolls

parent 358fb624
Pipeline #177575 canceled with stages
...@@ -38,11 +38,13 @@ public class ArchitectureElementData { ...@@ -38,11 +38,13 @@ public class ArchitectureElementData {
private String name; private String name;
private ArchitectureElementSymbol element; private ArchitectureElementSymbol element;
private CNNArchTemplateController templateController; private CNNArchTemplateController templateController;
private boolean partOfUnroll;
public ArchitectureElementData(String name, ArchitectureElementSymbol element, CNNArchTemplateController templateController) { public ArchitectureElementData(String name, ArchitectureElementSymbol element, CNNArchTemplateController templateController) {
this.name = name; this.name = name;
this.element = element; this.element = element;
this.templateController = templateController; this.templateController = templateController;
this.partOfUnroll = partOfUnroll;
} }
public String getName() { public String getName() {
...@@ -69,6 +71,14 @@ public class ArchitectureElementData { ...@@ -69,6 +71,14 @@ public class ArchitectureElementData {
this.templateController = templateController; this.templateController = templateController;
} }
public boolean getPartOfUnroll() {
return partOfUnroll;
}
public void setPartOfUnroll(boolean partOfUnroll) {
this.partOfUnroll= partOfUnroll;
}
private LayerSymbol getLayerSymbol() { private LayerSymbol getLayerSymbol() {
if (getElement() instanceof VariableSymbol) { if (getElement() instanceof VariableSymbol) {
return ((VariableSymbol) getElement()).getLayerVariableDeclaration().getLayer(); return ((VariableSymbol) getElement()).getLayerVariableDeclaration().getLayer();
......
...@@ -157,13 +157,20 @@ public abstract class CNNArchTemplateController { ...@@ -157,13 +157,20 @@ public abstract class CNNArchTemplateController {
for (VariableSymbol element : getArchitecture().getInputs()){ for (VariableSymbol element : getArchitecture().getInputs()){
list.add(nameManager.getName(element)); list.add(nameManager.getName(element));
} }
for (UnrollSymbol unroll : getArchitecture().getUnrolls()){
for (SerialCompositeElementSymbol element: unroll.getBodiesForAllTimesteps()) {
list.add(nameManager.getName(element.getFirstAtomicElements().get(0)));
}
}
list.removeAll(Collections.singleton(null));
System.err.println("555555555_list: " + list);
return list; return list;
} }
public List<String> getArchitectureOutputs(){ public List<String> getArchitectureOutputs(){
List<String> list = new ArrayList<>(); List<String> list = new ArrayList<>();
for (VariableSymbol element : getArchitecture().getOutputs()){ for (VariableSymbol element : getArchitecture().getOutputs()){
if(nameManager.getName(element) != null) { if(nameManager.getName(element) != null && !list.contains(nameManager.getName(element))) {
list.add(nameManager.getName(element)); list.add(nameManager.getName(element));
} }
} }
......
...@@ -21,7 +21,6 @@ ...@@ -21,7 +21,6 @@
package de.monticore.lang.monticar.cnnarch.generator; package de.monticore.lang.monticar.cnnarch.generator;
import de.monticore.lang.monticar.cnnarch._symboltable.*; import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers;
import de.monticore.lang.monticar.cnnarch.predefined.Convolution; import de.monticore.lang.monticar.cnnarch.predefined.Convolution;
import de.monticore.lang.monticar.cnnarch.predefined.FullyConnected; import de.monticore.lang.monticar.cnnarch.predefined.FullyConnected;
import de.monticore.lang.monticar.cnnarch.predefined.Pooling; import de.monticore.lang.monticar.cnnarch.predefined.Pooling;
...@@ -33,15 +32,22 @@ public class LayerNameCreator { ...@@ -33,15 +32,22 @@ public class LayerNameCreator {
private Map<ArchitectureElementSymbol, String> elementToName = new HashMap<>(); private Map<ArchitectureElementSymbol, String> elementToName = new HashMap<>();
private Map<String, ArchitectureElementSymbol> nameToElement = new HashMap<>(); private Map<String, ArchitectureElementSymbol> nameToElement = new HashMap<>();
private boolean partOfUnroll = false; private boolean partOfUnroll = false;
private boolean inFirstUnrollTimestep = true;
public LayerNameCreator(ArchitectureSymbol architecture) { public LayerNameCreator(ArchitectureSymbol architecture) {
int stage = 1; int stage = 1;
for (SerialCompositeElementSymbol stream : architecture.getStreams()) { for (SerialCompositeElementSymbol stream : architecture.getStreams()) {
stage = name(stream, stage, new ArrayList<>()); stage = name(stream, stage, new ArrayList<>());
} }
stage = 1;
for (UnrollSymbol unroll : architecture.getUnrolls()) { for (UnrollSymbol unroll : architecture.getUnrolls()) {
partOfUnroll = true; partOfUnroll = true;
stage = name(unroll.createUnrollForBackend().getBody(), stage, new ArrayList<>()); for(int index = 0; index < unroll.getBodiesForAllTimesteps().size(); index++) {
if(index > 0){
inFirstUnrollTimestep = false;
}
stage = name(unroll.getBodiesForAllTimesteps().get(index), stage, new ArrayList<>());
}
} }
} }
...@@ -102,12 +108,23 @@ public class LayerNameCreator { ...@@ -102,12 +108,23 @@ public class LayerNameCreator {
return Collections.max(endStages) + 1; return Collections.max(endStages) + 1;
} }
protected int add(ArchitectureElementSymbol architectureElement, int stage, List<Integer> streamIndices){ protected int add(ArchitectureElementSymbol architectureElement, int stage, List<Integer> streamIndices){
int endStage = stage; int endStage = stage;
if (!elementToName.containsKey(architectureElement) || partOfUnroll) { if (!elementToName.containsKey(architectureElement) || (partOfUnroll)) {
String name = createName(architectureElement, endStage, streamIndices); String name = createName(architectureElement, endStage, streamIndices);
while (nameToElement.containsKey(name)) { while (nameToElement.containsKey(name)) {
// The element is already registered, just in a different scope now and thus unrecognized (technically a different symbol)
if(architectureElement instanceof VariableSymbol && ((VariableSymbol) architectureElement).getType() == VariableSymbol.Type.IO){
elementToName.put(architectureElement, name);
return endStage;
}else if(partOfUnroll && !inFirstUnrollTimestep){
elementToName.put(architectureElement, name);
return endStage;
}
endStage++; endStage++;
name = createName(architectureElement, endStage, streamIndices); name = createName(architectureElement, endStage, streamIndices);
} }
...@@ -136,11 +153,9 @@ public class LayerNameCreator { ...@@ -136,11 +153,9 @@ public class LayerNameCreator {
String name = createBaseName(architectureElement); String name = createBaseName(architectureElement);
if (element.getType() == VariableSymbol.Type.IO) { if (element.getType() == VariableSymbol.Type.IO) {
if (element.getArrayAccess().isPresent() && !partOfUnroll){ if (element.getArrayAccess().isPresent()){
int arrayAccess = element.getArrayAccess().get().getIntValue().get(); int arrayAccess = element.getArrayAccess().get().getIntValue().get();
name = name + "_" + arrayAccess + "_"; name = name + "_" + arrayAccess + "_";
} else if(element.getArrayAccess().isPresent() && partOfUnroll) {
name = name + "_" + stage + "_";
} else { } else {
name = name + "_"; name = name + "_";
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment