Commit a4ae949c authored by Christian Fuß's avatar Christian Fuß

small changes to naming of layers in unrolls

parent 358fb624
Pipeline #177575 canceled with stages
......@@ -38,11 +38,13 @@ public class ArchitectureElementData {
private String name;
private ArchitectureElementSymbol element;
private CNNArchTemplateController templateController;
private boolean partOfUnroll;
public ArchitectureElementData(String name, ArchitectureElementSymbol element, CNNArchTemplateController templateController) {
this.name = name;
this.element = element;
this.templateController = templateController;
this.partOfUnroll = partOfUnroll;
}
public String getName() {
......@@ -69,6 +71,14 @@ public class ArchitectureElementData {
this.templateController = templateController;
}
public boolean getPartOfUnroll() {
return partOfUnroll;
}
public void setPartOfUnroll(boolean partOfUnroll) {
this.partOfUnroll= partOfUnroll;
}
private LayerSymbol getLayerSymbol() {
if (getElement() instanceof VariableSymbol) {
return ((VariableSymbol) getElement()).getLayerVariableDeclaration().getLayer();
......
......@@ -157,13 +157,20 @@ public abstract class CNNArchTemplateController {
for (VariableSymbol element : getArchitecture().getInputs()){
list.add(nameManager.getName(element));
}
for (UnrollSymbol unroll : getArchitecture().getUnrolls()){
for (SerialCompositeElementSymbol element: unroll.getBodiesForAllTimesteps()) {
list.add(nameManager.getName(element.getFirstAtomicElements().get(0)));
}
}
list.removeAll(Collections.singleton(null));
System.err.println("555555555_list: " + list);
return list;
}
public List<String> getArchitectureOutputs(){
List<String> list = new ArrayList<>();
for (VariableSymbol element : getArchitecture().getOutputs()){
if(nameManager.getName(element) != null) {
if(nameManager.getName(element) != null && !list.contains(nameManager.getName(element))) {
list.add(nameManager.getName(element));
}
}
......
......@@ -21,7 +21,6 @@
package de.monticore.lang.monticar.cnnarch.generator;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers;
import de.monticore.lang.monticar.cnnarch.predefined.Convolution;
import de.monticore.lang.monticar.cnnarch.predefined.FullyConnected;
import de.monticore.lang.monticar.cnnarch.predefined.Pooling;
......@@ -33,15 +32,22 @@ public class LayerNameCreator {
private Map<ArchitectureElementSymbol, String> elementToName = new HashMap<>();
private Map<String, ArchitectureElementSymbol> nameToElement = new HashMap<>();
private boolean partOfUnroll = false;
private boolean inFirstUnrollTimestep = true;
public LayerNameCreator(ArchitectureSymbol architecture) {
int stage = 1;
for (SerialCompositeElementSymbol stream : architecture.getStreams()) {
stage = name(stream, stage, new ArrayList<>());
}
stage = 1;
for (UnrollSymbol unroll : architecture.getUnrolls()) {
partOfUnroll = true;
stage = name(unroll.createUnrollForBackend().getBody(), stage, new ArrayList<>());
for(int index = 0; index < unroll.getBodiesForAllTimesteps().size(); index++) {
if(index > 0){
inFirstUnrollTimestep = false;
}
stage = name(unroll.getBodiesForAllTimesteps().get(index), stage, new ArrayList<>());
}
}
}
......@@ -102,12 +108,23 @@ public class LayerNameCreator {
return Collections.max(endStages) + 1;
}
protected int add(ArchitectureElementSymbol architectureElement, int stage, List<Integer> streamIndices){
int endStage = stage;
if (!elementToName.containsKey(architectureElement) || partOfUnroll) {
if (!elementToName.containsKey(architectureElement) || (partOfUnroll)) {
String name = createName(architectureElement, endStage, streamIndices);
while (nameToElement.containsKey(name)) {
// The element is already registered, just in a different scope now and thus unrecognized (technically a different symbol)
if(architectureElement instanceof VariableSymbol && ((VariableSymbol) architectureElement).getType() == VariableSymbol.Type.IO){
elementToName.put(architectureElement, name);
return endStage;
}else if(partOfUnroll && !inFirstUnrollTimestep){
elementToName.put(architectureElement, name);
return endStage;
}
endStage++;
name = createName(architectureElement, endStage, streamIndices);
}
......@@ -136,11 +153,9 @@ public class LayerNameCreator {
String name = createBaseName(architectureElement);
if (element.getType() == VariableSymbol.Type.IO) {
if (element.getArrayAccess().isPresent() && !partOfUnroll){
if (element.getArrayAccess().isPresent()){
int arrayAccess = element.getArrayAccess().get().getIntValue().get();
name = name + "_" + arrayAccess + "_";
} else if(element.getArrayAccess().isPresent() && partOfUnroll) {
name = name + "_" + stage + "_";
} else {
name = name + "_";
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment