Commit 358fb624 authored by Christian Fuß's avatar Christian Fuß

fixed naming for unrolls

parent 42fe0363
Pipeline #176014 failed with stages
in 3 minutes and 24 seconds
......@@ -163,7 +163,9 @@ public abstract class CNNArchTemplateController {
public List<String> getArchitectureOutputs(){
List<String> list = new ArrayList<>();
for (VariableSymbol element : getArchitecture().getOutputs()){
list.add(nameManager.getName(element));
if(nameManager.getName(element) != null) {
list.add(nameManager.getName(element));
}
}
return list;
}
......
......@@ -21,6 +21,7 @@
package de.monticore.lang.monticar.cnnarch.generator;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers;
import de.monticore.lang.monticar.cnnarch.predefined.Convolution;
import de.monticore.lang.monticar.cnnarch.predefined.FullyConnected;
import de.monticore.lang.monticar.cnnarch.predefined.Pooling;
......@@ -31,6 +32,7 @@ public class LayerNameCreator {
private Map<ArchitectureElementSymbol, String> elementToName = new HashMap<>();
private Map<String, ArchitectureElementSymbol> nameToElement = new HashMap<>();
private boolean partOfUnroll = false;
public LayerNameCreator(ArchitectureSymbol architecture) {
int stage = 1;
......@@ -38,7 +40,8 @@ public class LayerNameCreator {
stage = name(stream, stage, new ArrayList<>());
}
for (UnrollSymbol unroll : architecture.getUnrolls()) {
stage = name(unroll.getBody(), stage, new ArrayList<>());
partOfUnroll = true;
stage = name(unroll.createUnrollForBackend().getBody(), stage, new ArrayList<>());
}
}
......@@ -63,6 +66,13 @@ public class LayerNameCreator {
return stage;
}
} else {
if(!architectureElement.isResolved()){
try {
architectureElement.resolve();
} catch (ArchResolveException e) {
e.printStackTrace();
}
}
ArchitectureElementSymbol resolvedElement = architectureElement.getResolvedThis().get();
return name(resolvedElement, stage, streamIndices);
}
......@@ -92,17 +102,9 @@ public class LayerNameCreator {
return Collections.max(endStages) + 1;
}
protected int nameUnroll(UnrollSymbol unrollElement, int stage, List<Integer> streamIndices){
int endStage = stage;
for (ArchitectureElementSymbol subElement : unrollElement.getBody().getElements()){
endStage = name(subElement, endStage, streamIndices);
}
return endStage;
}
protected int add(ArchitectureElementSymbol architectureElement, int stage, List<Integer> streamIndices){
int endStage = stage;
if (!elementToName.containsKey(architectureElement)) {
if (!elementToName.containsKey(architectureElement) || partOfUnroll) {
String name = createName(architectureElement, endStage, streamIndices);
while (nameToElement.containsKey(name)) {
......@@ -134,9 +136,11 @@ public class LayerNameCreator {
String name = createBaseName(architectureElement);
if (element.getType() == VariableSymbol.Type.IO) {
if (element.getArrayAccess().isPresent()){
if (element.getArrayAccess().isPresent() && !partOfUnroll){
int arrayAccess = element.getArrayAccess().get().getIntValue().get();
name = name + "_" + arrayAccess + "_";
} else if(element.getArrayAccess().isPresent() && partOfUnroll) {
name = name + "_" + stage + "_";
} else {
name = name + "_";
}
......@@ -183,4 +187,3 @@ public class LayerNameCreator {
return stringBuilder.toString();
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment