Commit 071ebc54 authored by Sebastian Nickels's avatar Sebastian Nickels

Split of serial and parallel CompositeElementSymbol

parent d4f356a9
......@@ -46,6 +46,11 @@ public class CNNArch2MxNet extends CNNArchGenerator {
setGenerationTargetPath("./target/generated-sources-cnnarch/");
}
// TODO: Rewrite so that CNNArchSymbolCompiler is used in EMADL2CPP instead of this method
public boolean check(ArchitectureSymbol architecture) {
return architectureSupportChecker.check(architecture) && layerSupportChecker.check(architecture);
}
public void generate(Scope scope, String rootModelName){
CNNArchSymbolCompiler symbolCompiler = new CNNArchSymbolCompiler(architectureSupportChecker, layerSupportChecker);
ArchitectureSymbol architectureSymbol = symbolCompiler.compileArchitectureSymbol(scope, rootModelName);
......
......@@ -34,7 +34,7 @@ public class LayerNameCreator {
public LayerNameCreator(ArchitectureSymbol architecture) {
int stage = 1;
for (CompositeElementSymbol stream : architecture.getStreams()) {
for (SerialCompositeElementSymbol stream : architecture.getStreams()) {
stage = name(stream, stage, new ArrayList<>());
}
}
......@@ -48,8 +48,10 @@ public class LayerNameCreator {
}
protected int name(ArchitectureElementSymbol architectureElement, int stage, List<Integer> streamIndices){
if (architectureElement instanceof CompositeElementSymbol){
return nameComposite((CompositeElementSymbol) architectureElement, stage, streamIndices);
if (architectureElement instanceof SerialCompositeElementSymbol) {
return nameSerialComposite((SerialCompositeElementSymbol) architectureElement, stage, streamIndices);
} else if (architectureElement instanceof ParallelCompositeElementSymbol){
return nameParallelComposite((ParallelCompositeElementSymbol) architectureElement, stage, streamIndices);
} else{
if (architectureElement.isAtomic()){
if (architectureElement.getMaxSerialLength().get() > 0){
......@@ -64,27 +66,27 @@ public class LayerNameCreator {
}
}
protected int nameComposite(CompositeElementSymbol compositeElement, int stage, List<Integer> streamIndices){
if (compositeElement.isParallel()){
int startStage = stage + 1;
streamIndices.add(1);
int lastIndex = streamIndices.size() - 1;
protected int nameSerialComposite(SerialCompositeElementSymbol compositeElement, int stage, List<Integer> streamIndices){
int endStage = stage;
for (ArchitectureElementSymbol subElement : compositeElement.getElements()){
endStage = name(subElement, endStage, streamIndices);
}
return endStage;
}
List<Integer> endStages = new ArrayList<>();
for (ArchitectureElementSymbol subElement : compositeElement.getElements()){
endStages.add(name(subElement, startStage, streamIndices));
streamIndices.set(lastIndex, streamIndices.get(lastIndex) + 1);
}
protected int nameParallelComposite(ParallelCompositeElementSymbol compositeElement, int stage, List<Integer> streamIndices){
int startStage = stage + 1;
streamIndices.add(1);
int lastIndex = streamIndices.size() - 1;
streamIndices.remove(lastIndex);
return Collections.max(endStages) + 1;
} else {
int endStage = stage;
for (ArchitectureElementSymbol subElement : compositeElement.getElements()){
endStage = name(subElement, endStage, streamIndices);
}
return endStage;
List<Integer> endStages = new ArrayList<>();
for (ArchitectureElementSymbol subElement : compositeElement.getElements()){
endStages.add(name(subElement, startStage, streamIndices));
streamIndices.set(lastIndex, streamIndices.get(lastIndex) + 1);
}
streamIndices.remove(lastIndex);
return Collections.max(endStages) + 1;
}
protected int add(ArchitectureElementSymbol architectureElement, int stage, List<Integer> streamIndices){
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment