Commit 91b34c98 authored by lr119628's avatar lr119628
Browse files

[update]: formatting and rebasing

parent dbb3d28e
......@@ -106,7 +106,6 @@ public abstract class CNNArchTemplateController {
public boolean containsAdaNet(){
return this.architecture.containsAdaNet();
}
public String getName(ArchitectureElementSymbol layer){
return nameManager.getName(layer);
}
......@@ -133,8 +132,7 @@ public abstract class CNNArchTemplateController {
for (ArchitectureElementSymbol input : layer.getPrevious()) {
if(input.isArtificial()){
inputNames.add(getName(input));
}else
if (input.getOutputTypes().size() == 1) {
}else if (input.getOutputTypes().size() == 1) {
inputNames.add(getName(input));
} else {
for (int i = 0; i < input.getOutputTypes().size(); i++) {
......
......@@ -15,6 +15,7 @@ public class LayerNameCreator {
private Map<ArchitectureElementSymbol, String> elementToName = new HashMap<>();
private Set<String> names = new HashSet<>();
public LayerNameCreator(ArchitectureSymbol architecture) {
int stage = 1;
for (NetworkInstructionSymbol networkInstruction : architecture.getNetworkInstructions()) {
......@@ -30,11 +31,11 @@ public class LayerNameCreator {
}
}
public String getName(ArchitectureElementSymbol architectureElement) {
public String getName(ArchitectureElementSymbol architectureElement){
return elementToName.get(architectureElement);
}
protected int name(ArchitectureElementSymbol architectureElement, int stage, List<Integer> streamIndices) {
protected int name(ArchitectureElementSymbol architectureElement, int stage, List<Integer> streamIndices){
if (architectureElement instanceof SerialCompositeElementSymbol) {
return nameSerialComposite((SerialCompositeElementSymbol) architectureElement, stage, streamIndices);
} else if (architectureElement instanceof ParallelCompositeElementSymbol) {
......@@ -80,7 +81,6 @@ public class LayerNameCreator {
endStage = name(subElement, endStage, streamIndices);
} else if (subElement.getName().equals(AllPredefinedLayers.AdaNet_Name)) {
// name outBlock
endStage = nameAdaNetBlock(AllPredefinedLayers.Out,subElement,endStage,streamIndices);
// name inBlock
endStage = nameAdaNetBlock(AllPredefinedLayers.In,subElement,endStage,streamIndices);
......@@ -101,13 +101,13 @@ public class LayerNameCreator {
return endStage;
}
protected int nameParallelComposite(ParallelCompositeElementSymbol compositeElement, int stage, List<Integer> streamIndices) {
protected int nameParallelComposite(ParallelCompositeElementSymbol compositeElement, int stage, List<Integer> streamIndices){
int startStage = stage + 1;
streamIndices.add(1);
int lastIndex = streamIndices.size() - 1;
List<Integer> endStages = new ArrayList<>();
for (ArchitectureElementSymbol subElement : compositeElement.getElements()) {
for (ArchitectureElementSymbol subElement : compositeElement.getElements()){
endStages.add(name(subElement, startStage, streamIndices));
streamIndices.set(lastIndex, streamIndices.get(lastIndex) + 1);
}
......@@ -116,7 +116,7 @@ public class LayerNameCreator {
return Collections.max(endStages) + 1;
}
protected int add(ArchitectureElementSymbol architectureElement, int stage, List<Integer> streamIndices) {
protected int add(ArchitectureElementSymbol architectureElement, int stage, List<Integer> streamIndices){
int endStage = stage;
if (!elementToName.containsKey(architectureElement)) {
String name = createName(architectureElement, endStage, streamIndices);
......@@ -134,7 +134,7 @@ public class LayerNameCreator {
return endStage;
}
protected String createName(ArchitectureElementSymbol architectureElement, int stage, List<Integer> streamIndices) {
protected String createName(ArchitectureElementSymbol architectureElement, int stage, List<Integer> streamIndices){
if (architectureElement instanceof VariableSymbol) {
VariableSymbol element = (VariableSymbol) architectureElement;
......@@ -148,7 +148,7 @@ public class LayerNameCreator {
}
}
if (element.getArrayAccess().isPresent()) {
if (element.getArrayAccess().isPresent()){
int arrayAccess = element.getArrayAccess().get().getIntValue().get();
name = name + arrayAccess + "_";
}
......@@ -160,7 +160,7 @@ public class LayerNameCreator {
}
protected String createBaseName(ArchitectureElementSymbol architectureElement) {
protected String createBaseName(ArchitectureElementSymbol architectureElement){
if (architectureElement instanceof LayerSymbol) {
LayerDeclarationSymbol layerDeclaration = ((LayerSymbol) architectureElement).getDeclaration();
if (layerDeclaration instanceof Convolution) {
......@@ -174,16 +174,16 @@ public class LayerNameCreator {
} else {
return layerDeclaration.getName().toLowerCase();
}
} else if (architectureElement instanceof CompositeElementSymbol) {
} else if (architectureElement instanceof CompositeElementSymbol){
return "group";
} else {
return architectureElement.getName();
}
}
protected String createStreamPostfix(List<Integer> streamIndices) {
protected String createStreamPostfix(List<Integer> streamIndices){
StringBuilder stringBuilder = new StringBuilder();
for (int streamIndex : streamIndices) {
for (int streamIndex : streamIndices){
stringBuilder.append("_");
stringBuilder.append(streamIndex);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment