Aufgrund einer Wartung wird GitLab am 18.01. zwischen 8:00 und 9:00 Uhr kurzzeitig nicht zur Verfügung stehen. / Due to maintenance, GitLab will be temporarily unavailable on 18.01. between 8:00 and 9:00 am.

Commit 75291479 authored by lr119628's avatar lr119628
Browse files

[update]: updated naming logic

parent 24cdb964
......@@ -3,6 +3,7 @@ package de.monticore.lang.monticar.cnnarch.generator;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.predefined.*;
import sun.nio.ch.Net;
import java.util.*;
......@@ -10,7 +11,6 @@ public class LayerNameCreator {
private Map<ArchitectureElementSymbol, String> elementToName = new HashMap<>();
private Set<String> names = new HashSet<>();
//List<String> artifical_layer = new ArrayList<String>(); // name list for the defined layers within cnnarch
public LayerNameCreator(ArchitectureSymbol architecture) {
int stage = 1;
for (NetworkInstructionSymbol networkInstruction : architecture.getNetworkInstructions()) {
......@@ -26,79 +26,80 @@ public class LayerNameCreator {
}
}
public String getName(ArchitectureElementSymbol architectureElement){
public String getName(ArchitectureElementSymbol architectureElement) {
return elementToName.get(architectureElement);
}
protected int name(ArchitectureElementSymbol architectureElement, int stage, List<Integer> streamIndices){
protected int name(ArchitectureElementSymbol architectureElement, int stage, List<Integer> streamIndices) {
if (architectureElement instanceof SerialCompositeElementSymbol) {
return nameSerialComposite((SerialCompositeElementSymbol) architectureElement, stage, streamIndices);
} else if (architectureElement instanceof ParallelCompositeElementSymbol) {
return nameParallelComposite((ParallelCompositeElementSymbol) architectureElement, stage, streamIndices);
} else {
if (architectureElement.isAtomic()&& !architectureElement.isArtificial()) {
boolean noAdaNet = !architectureElement.containsAdaNet();
// flag which is true if there is no AdaNet inside the architecture
if (architectureElement.isAtomic() && (!architectureElement.isArtificial() || noAdaNet)) {
if (architectureElement.getMaxSerialLength().get() > 0) {
return add(architectureElement, stage, streamIndices);
} else {
return stage;
}
}else{
} else {
ArchitectureElementSymbol resolvedElement = (ArchitectureElementSymbol) architectureElement.getResolvedThis().get();
int final_stage = name(resolvedElement, stage, streamIndices);
if (architectureElement.isArtificial()){
// if the element is artificial the name needs to be added
final_stage = add(architectureElement,final_stage,streamIndices);
if (architectureElement.isArtificial() && !noAdaNet) {
// if the element is artificial the name needs to be added only if an adaNet layer is present
final_stage = add(architectureElement, final_stage, streamIndices);
}
return final_stage;
}
}
}
protected int nameSerialComposite(SerialCompositeElementSymbol compositeElement, int stage, List<Integer> streamIndices){
protected int nameAdaNetBlock(String target,ArchitectureElementSymbol subElement,int endStage, List<Integer> streamIndices){
ArchitectureElementSymbol currentBlock = ((AdaNet) ((LayerSymbol) subElement).getDeclaration()).getBlock(target).get();
if (currentBlock.isArtificial()) {
boolean oldState = currentBlock.containsAdaNet();
currentBlock.setAdaNet(true);
endStage = name(currentBlock, endStage, streamIndices);
currentBlock.setAdaNet(oldState);
}
return endStage;
}
protected int nameSerialComposite(SerialCompositeElementSymbol compositeElement, int stage, List<Integer> streamIndices) {
int endStage = stage;
for (ArchitectureElementSymbol subElement : compositeElement.getElements()){
if (subElement.isArtificial()) {
for (ArchitectureElementSymbol subElement : compositeElement.getElements()) {
if (subElement.isArtificial() && compositeElement.containsAdaNet()) {
endStage = name(subElement, endStage, streamIndices);
}else if(subElement.getName().equals(AllPredefinedLayers.AdaNet_Name)){
ArchitectureElementSymbol currentBlock;
// get OutBlock and name it
currentBlock = ((AdaNet)((LayerSymbol) subElement).getDeclaration()).getBlock(AllPredefinedLayers.Out).get();
if(currentBlock.isArtificial()){
endStage = name(currentBlock,endStage,streamIndices);
}
// get inBlock and name it
currentBlock = ((AdaNet)((LayerSymbol) subElement).getDeclaration()).getBlock(AllPredefinedLayers.In).get();
if(currentBlock.isArtificial()){
endStage = name(currentBlock,endStage,streamIndices);
}
// get buildingBlock and name it
currentBlock = ((AdaNet)((LayerSymbol) subElement).getDeclaration()).getBlock(AllPredefinedLayers.Block).get();
if(currentBlock.isArtificial()){
endStage = name(currentBlock,endStage,streamIndices);
}
} else if (subElement.getName().equals(AllPredefinedLayers.AdaNet_Name)) {
// name outBlock
endStage = nameAdaNetBlock(AllPredefinedLayers.Out,subElement,endStage,streamIndices);
// name inBlock
endStage = nameAdaNetBlock(AllPredefinedLayers.In,subElement,endStage,streamIndices);
// name buildBlock
endStage = nameAdaNetBlock(AllPredefinedLayers.Block,subElement,endStage,streamIndices);
endStage = name(subElement,endStage,streamIndices);
}else{
endStage = name(subElement, endStage, streamIndices);
} else {
endStage = name(subElement, endStage, streamIndices);
}
}
for (List<ArchitectureElementSymbol> subNetwork : compositeElement.getEpisodicSubNetworks()){
for (ArchitectureElementSymbol subElement : subNetwork){
for (List<ArchitectureElementSymbol> subNetwork : compositeElement.getEpisodicSubNetworks()) {
for (ArchitectureElementSymbol subElement : subNetwork) {
endStage = name(subElement, endStage, streamIndices);
}
}
return endStage;
}
protected int nameParallelComposite(ParallelCompositeElementSymbol compositeElement, int stage, List<Integer> streamIndices){
protected int nameParallelComposite(ParallelCompositeElementSymbol compositeElement, int stage, List<Integer> streamIndices) {
int startStage = stage + 1;
streamIndices.add(1);
int lastIndex = streamIndices.size() - 1;
List<Integer> endStages = new ArrayList<>();
for (ArchitectureElementSymbol subElement : compositeElement.getElements()){
for (ArchitectureElementSymbol subElement : compositeElement.getElements()) {
endStages.add(name(subElement, startStage, streamIndices));
streamIndices.set(lastIndex, streamIndices.get(lastIndex) + 1);
}
......@@ -107,7 +108,7 @@ public class LayerNameCreator {
return Collections.max(endStages) + 1;
}
protected int add(ArchitectureElementSymbol architectureElement, int stage, List<Integer> streamIndices){
protected int add(ArchitectureElementSymbol architectureElement, int stage, List<Integer> streamIndices) {
int endStage = stage;
if (!elementToName.containsKey(architectureElement)) {
String name = createName(architectureElement, endStage, streamIndices);
......@@ -125,7 +126,7 @@ public class LayerNameCreator {
return endStage;
}
protected String createName(ArchitectureElementSymbol architectureElement, int stage, List<Integer> streamIndices){
protected String createName(ArchitectureElementSymbol architectureElement, int stage, List<Integer> streamIndices) {
if (architectureElement instanceof VariableSymbol) {
VariableSymbol element = (VariableSymbol) architectureElement;
......@@ -139,7 +140,7 @@ public class LayerNameCreator {
}
}
if (element.getArrayAccess().isPresent()){
if (element.getArrayAccess().isPresent()) {
int arrayAccess = element.getArrayAccess().get().getIntValue().get();
name = name + arrayAccess + "_";
}
......@@ -151,7 +152,7 @@ public class LayerNameCreator {
}
protected String createBaseName(ArchitectureElementSymbol architectureElement){
protected String createBaseName(ArchitectureElementSymbol architectureElement) {
if (architectureElement instanceof LayerSymbol) {
LayerDeclarationSymbol layerDeclaration = ((LayerSymbol) architectureElement).getDeclaration();
if (layerDeclaration instanceof Convolution) {
......@@ -165,16 +166,16 @@ public class LayerNameCreator {
} else {
return layerDeclaration.getName().toLowerCase();
}
} else if (architectureElement instanceof CompositeElementSymbol){
} else if (architectureElement instanceof CompositeElementSymbol) {
return "group";
} else {
return architectureElement.getName();
}
}
protected String createStreamPostfix(List<Integer> streamIndices){
protected String createStreamPostfix(List<Integer> streamIndices) {
StringBuilder stringBuilder = new StringBuilder();
for (int streamIndex : streamIndices){
for (int streamIndex : streamIndices) {
stringBuilder.append("_");
stringBuilder.append(streamIndex);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment