Commit 45906f7b authored by Sebastian Nickels's avatar Sebastian Nickels

Added CNNArchGenerator.check() to allow EMADL2CPP use...

Added CNNArchGenerator.check() to allow EMADL2CPP use ArchitectureSupportChecker and LayerSupportChecker, splitted serial and parallel CompositeElementSymbol into seperate classes and added an isNetwork() method to distinguish between actual network streams and basic assignments
parent c9f47ac6
......@@ -67,6 +67,8 @@ public abstract class CNNArchGenerator {
generate(scope, rootModelName);
}
public abstract boolean check(ArchitectureSymbol architecture);
public abstract void generate(Scope scope, String rootModelName);
//check cocos with CNNArchCocos.checkAll(architecture) before calling this method.
......
......@@ -21,10 +21,7 @@
package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch._ast.ASTLayerDeclaration;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerDeclarationSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
......
......@@ -37,7 +37,7 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
public static final ArchitectureKind KIND = new ArchitectureKind();
private List<CompositeElementSymbol> streams;
private List<SerialCompositeElementSymbol> streams;
private List<IOSymbol> inputs = new ArrayList<>();
private List<IOSymbol> outputs = new ArrayList<>();
private Map<String, IODeclarationSymbol> ioDeclarationMap = new HashMap<>();
......@@ -48,11 +48,11 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
super("", KIND);
}
public List<CompositeElementSymbol> getStreams() {
public List<SerialCompositeElementSymbol> getStreams() {
return streams;
}
public void setStreams(List<CompositeElementSymbol> streams) {
public void setStreams(List<SerialCompositeElementSymbol> streams) {
this.streams = streams;
}
......@@ -177,9 +177,9 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
}
}
List<CompositeElementSymbol> copyStreams = new ArrayList<>();
for (CompositeElementSymbol stream : streams) {
CompositeElementSymbol copyStream = stream.preResolveDeepCopy();
List<SerialCompositeElementSymbol> copyStreams = new ArrayList<>();
for (SerialCompositeElementSymbol stream : streams) {
SerialCompositeElementSymbol copyStream = stream.preResolveDeepCopy();
copyStream.putInScope(copy.getSpannedScope());
copyStreams.add(copyStream);
}
......
......@@ -145,10 +145,10 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
}
public void endVisit(final ASTArchitecture node) {
List<CompositeElementSymbol> streams = new ArrayList<>();
List<SerialCompositeElementSymbol> streams = new ArrayList<>();
for (ASTInstruction astInstruction : node.getInstructionsList()){
ASTStream astStream = (ASTStream)astInstruction; // TODO: For now all instructions are streams
streams.add((CompositeElementSymbol) astStream.getSymbolOpt().get());
streams.add((SerialCompositeElementSymbol) astStream.getSymbolOpt().get());
}
architecture.setStreams(streams);
......@@ -230,7 +230,7 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
@Override
public void endVisit(ASTLayerDeclaration ast) {
LayerDeclarationSymbol layerDeclaration = (LayerDeclarationSymbol) ast.getSymbolOpt().get();
layerDeclaration.setBody((CompositeElementSymbol) ast.getBody().getSymbolOpt().get());
layerDeclaration.setBody((SerialCompositeElementSymbol) ast.getBody().getSymbolOpt().get());
List<VariableSymbol> parameters = new ArrayList<>(4);
for (ASTLayerParameter astParam : ast.getParametersList()){
......@@ -320,18 +320,17 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
@Override
public void visit(ASTParallelBlock node) {
CompositeElementSymbol compositeElement = new CompositeElementSymbol();
compositeElement.setParallel(true);
ParallelCompositeElementSymbol compositeElement = new ParallelCompositeElementSymbol();
addToScopeAndLinkWithNode(compositeElement, node);
}
@Override
public void endVisit(ASTParallelBlock node) {
CompositeElementSymbol compositeElement = (CompositeElementSymbol) node.getSymbolOpt().get();
ParallelCompositeElementSymbol compositeElement = (ParallelCompositeElementSymbol) node.getSymbolOpt().get();
List<ArchitectureElementSymbol> elements = new ArrayList<>();
for (ASTStream astStream : node.getGroupsList()){
elements.add((CompositeElementSymbol) astStream.getSymbolOpt().get());
elements.add((SerialCompositeElementSymbol) astStream.getSymbolOpt().get());
}
compositeElement.setElements(elements);
......@@ -340,14 +339,13 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
@Override
public void visit(ASTStream ast) {
CompositeElementSymbol compositeElement = new CompositeElementSymbol();
compositeElement.setParallel(false);
SerialCompositeElementSymbol compositeElement = new SerialCompositeElementSymbol();
addToScopeAndLinkWithNode(compositeElement, ast);
}
@Override
public void endVisit(ASTStream ast) {
CompositeElementSymbol compositeElement = (CompositeElementSymbol) ast.getSymbolOpt().get();
SerialCompositeElementSymbol compositeElement = (SerialCompositeElementSymbol) ast.getSymbolOpt().get();
List<ArchitectureElementSymbol> elements = new ArrayList<>();
for (ASTArchitectureElement astElement : ast.getElementsList()){
elements.add((ArchitectureElementSymbol) astElement.getSymbolOpt().get());
......
......@@ -20,122 +20,45 @@
*/
package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.monticore.symboltable.Scope;
import de.monticore.symboltable.Symbol;
import de.se_rwth.commons.logging.Log;
import java.util.*;
public class CompositeElementSymbol extends ArchitectureElementSymbol {
public abstract class CompositeElementSymbol extends ArchitectureElementSymbol {
private boolean parallel;
private List<ArchitectureElementSymbol> elements;
protected List<ArchitectureElementSymbol> elements = new ArrayList<>();
protected CompositeElementSymbol() {
public CompositeElementSymbol() {
super("");
setResolvedThis(this);
}
public boolean isParallel() {
return parallel;
}
protected void setParallel(boolean parallel) {
this.parallel = parallel;
}
public List<ArchitectureElementSymbol> getElements() {
return elements;
}
protected void setElements(List<ArchitectureElementSymbol> elements) {
ArchitectureElementSymbol previous = null;
for (ArchitectureElementSymbol current : elements){
if (previous != null && !isParallel()){
current.setInputElement(previous);
previous.setOutputElement(current);
}
else {
if (getInputElement().isPresent()){
current.setInputElement(getInputElement().get());
}
if (getOutputElement().isPresent()){
current.setOutputElement(getOutputElement().get());
}
}
previous = current;
}
this.elements = elements;
}
@Override
public boolean isAtomic() {
return getElements().isEmpty();
}
abstract protected void setElements(List<ArchitectureElementSymbol> elements);
@Override
public void setInputElement(ArchitectureElementSymbol inputElement) {
super.setInputElement(inputElement);
if (isParallel()){
for (ArchitectureElementSymbol current : getElements()){
current.setInputElement(inputElement);
}
}
else {
if (!getElements().isEmpty()){
getElements().get(0).setInputElement(inputElement);
}
}
}
public boolean isNetwork() {
boolean isNetwork = false;
@Override
public void setOutputElement(ArchitectureElementSymbol outputElement) {
super.setOutputElement(outputElement);
if (isParallel()){
for (ArchitectureElementSymbol current : getElements()){
current.setOutputElement(outputElement);
for (ArchitectureElementSymbol element : elements) {
if (element instanceof CompositeElementSymbol) {
isNetwork |= ((CompositeElementSymbol) element).isNetwork();
}
}
else {
if (!getElements().isEmpty()){
getElements().get(getElements().size()-1).setOutputElement(outputElement);
else if (element instanceof LayerSymbol)
{
isNetwork |= ((LayerSymbol) element).getDeclaration().isNetworkLayer();
}
}
}
@Override
public List<ArchitectureElementSymbol> getFirstAtomicElements() {
if (getElements().isEmpty()){
return Collections.singletonList(this);
}
else if (isParallel()){
List<ArchitectureElementSymbol> firstElements = new ArrayList<>();
for (ArchitectureElementSymbol element : getElements()){
firstElements.addAll(element.getFirstAtomicElements());
}
return firstElements;
}
else {
return getElements().get(0).getFirstAtomicElements();
}
return isNetwork;
}
@Override
public List<ArchitectureElementSymbol> getLastAtomicElements() {
if (getElements().isEmpty()){
return Collections.singletonList(this);
}
else if (isParallel()){
List<ArchitectureElementSymbol> lastElements = new ArrayList<>();
for (ArchitectureElementSymbol element : getElements()){
lastElements.addAll(element.getLastAtomicElements());
}
return lastElements;
}
else {
return getElements().get(getElements().size()-1).getLastAtomicElements();
}
public boolean isAtomic() {
return getElements().isEmpty();
}
@Override
......@@ -177,73 +100,6 @@ public class CompositeElementSymbol extends ArchitectureElementSymbol {
}
}
@Override
public List<ArchTypeSymbol> computeOutputTypes() {
if (getElements().isEmpty()){
if (getInputElement().isPresent()){
return getInputElement().get().getOutputTypes();
}
else {
return Collections.emptyList();
}
}
else {
if (isParallel()){
List<ArchTypeSymbol> outputShapes = new ArrayList<>(getElements().size());
for (ArchitectureElementSymbol element : getElements()){
if (element.getOutputTypes().size() != 0){
outputShapes.add(element.getOutputTypes().get(0));
}
}
return outputShapes;
}
else {
for (ArchitectureElementSymbol element : getElements()){
element.getOutputTypes();
}
return getElements().get(getElements().size() - 1).getOutputTypes();
}
}
}
@Override
public void checkInput() {
if (!getElements().isEmpty()){
for (ArchitectureElementSymbol element : getElements()){
element.checkInput();
}
if (isParallel()){
for (ArchitectureElementSymbol element : getElements()){
if (element.getOutputTypes().size() > 1){
Log.error("0" + ErrorCodes.MISSING_MERGE + " Missing merge layer (Add(), Concatenate() or [i]). " +
"Each stream at the end of a parallelization block can only have one output stream. "
, getSourcePosition());
}
}
}
}
}
@Override
public Optional<Integer> getParallelLength() {
if (isParallel()){
return Optional.of(getElements().size());
}
else {
return Optional.of(1);
}
}
@Override
public Optional<List<Integer>> getSerialLengths() {
if (isParallel()){
return Optional.of(Collections.nCopies(getElements().size(), 1));
}
else {
return Optional.of(Collections.singletonList(getElements().size()));
}
}
@Override
protected void putInScope(Scope scope) {
Collection<Symbol> symbolsInScope = scope.getLocalSymbols().get(getName());
......@@ -254,48 +110,4 @@ public class CompositeElementSymbol extends ArchitectureElementSymbol {
}
}
}
@Override
protected CompositeElementSymbol preResolveDeepCopy() {
CompositeElementSymbol copy = new CompositeElementSymbol();
copy.setParallel(isParallel());
if (getAstNode().isPresent()){
copy.setAstNode(getAstNode().get());
}
List<ArchitectureElementSymbol> elements = new ArrayList<>(getElements().size());
for (ArchitectureElementSymbol element : getElements()){
ArchitectureElementSymbol elementCopy = element.preResolveDeepCopy();
elements.add(elementCopy);
}
copy.setElements(elements);
return copy;
}
public static class Builder{
private boolean parallel = false;
private List<ArchitectureElementSymbol> elements = new ArrayList<>();
public Builder parallel(boolean parallel){
this.parallel = parallel;
return this;
}
public Builder elements(List<ArchitectureElementSymbol> elements){
this.elements = elements;
return this;
}
public Builder elements(ArchitectureElementSymbol... elements){
this.elements = Arrays.asList(elements);
return this;
}
public CompositeElementSymbol build(){
CompositeElementSymbol sym = new CompositeElementSymbol();
sym.setParallel(parallel);
sym.setElements(elements);
return sym;
}
}
}
......@@ -110,10 +110,8 @@ public class IOSymbol extends ArchitectureElementSymbol {
if (!getArrayAccess().isPresent() && getDefinition().getArrayLength() > 1){
//transform io array into parallel composite
List<ArchitectureElementSymbol> parallelElements = createExpandedParallelElements();
CompositeElementSymbol composite = new CompositeElementSymbol.Builder()
.parallel(true)
.elements(parallelElements)
.build();
ParallelCompositeElementSymbol composite = new ParallelCompositeElementSymbol();
composite.setElements(parallelElements);
getSpannedScope().getAsMutableScope().add(composite);
composite.setAstNode(getAstNode().get());
......@@ -157,8 +155,7 @@ public class IOSymbol extends ArchitectureElementSymbol {
}
else {
for (int i = 0; i < getDefinition().getArrayLength(); i++){
CompositeElementSymbol serialComposite = new CompositeElementSymbol();
serialComposite.setParallel(false);
SerialCompositeElementSymbol serialComposite = new SerialCompositeElementSymbol();
IOSymbol ioElement = new IOSymbol(getName());
ioElement.setArrayAccess(i);
......
......@@ -35,7 +35,7 @@ public class LayerDeclarationSymbol extends CommonScopeSpanningSymbol {
public static final LayerDeclarationKind KIND = new LayerDeclarationKind();
private List<VariableSymbol> parameters;
private CompositeElementSymbol body;
private SerialCompositeElementSymbol body;
protected LayerDeclarationSymbol(String name) {
......@@ -75,11 +75,15 @@ public class LayerDeclarationSymbol extends CommonScopeSpanningSymbol {
}
}
public CompositeElementSymbol getBody() {
public SerialCompositeElementSymbol getBody() {
return body;
}
protected void setBody(CompositeElementSymbol body) {
public boolean isNetworkLayer() {
return body.isNetwork();
}
protected void setBody(SerialCompositeElementSymbol body) {
this.body = body;
}
......@@ -109,7 +113,7 @@ public class LayerDeclarationSymbol extends CommonScopeSpanningSymbol {
reset();
set(layer.getArguments());
CompositeElementSymbol copy = getBody().preResolveDeepCopy();
SerialCompositeElementSymbol copy = getBody().preResolveDeepCopy();
copy.putInScope(getSpannedScope());
copy.resolveOrError();
getSpannedScope().remove(copy);
......
......@@ -142,7 +142,7 @@ public class LayerSymbol extends ArchitectureElementSymbol {
if (!isActive() || maxSerialLength == 0) {
//set resolvedThis to empty composite to remove the layer.
setResolvedThis(new CompositeElementSymbol.Builder().build());
setResolvedThis(new SerialCompositeElementSymbol());
}
else if (parallelLength == 1 && maxSerialLength == 1) {
//resolve the layer call
......@@ -186,10 +186,8 @@ public class LayerSymbol extends ArchitectureElementSymbol {
for (List<ArchitectureElementSymbol> serialElements : elements) {
serialComposites.add(createSerialSequencePart(serialElements));
}
CompositeElementSymbol parallelElement = new CompositeElementSymbol.Builder()
.parallel(true)
.elements(serialComposites)
.build();
ParallelCompositeElementSymbol parallelElement = new ParallelCompositeElementSymbol();
parallelElement.setElements(serialComposites);
if (getAstNode().isPresent()) {
parallelElement.setAstNode(getAstNode().get());
......@@ -203,10 +201,8 @@ public class LayerSymbol extends ArchitectureElementSymbol {
return elements.get(0);
}
else {
CompositeElementSymbol serialComposite = new CompositeElementSymbol.Builder()
.parallel(false)
.elements(elements)
.build();
SerialCompositeElementSymbol serialComposite = new SerialCompositeElementSymbol();
serialComposite.setElements(elements);
if (getAstNode().isPresent()){
serialComposite.setAstNode(getAstNode().get());
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import java.util.*;
public class ParallelCompositeElementSymbol extends CompositeElementSymbol {
protected void setElements(List<ArchitectureElementSymbol> elements) {
for (ArchitectureElementSymbol current : elements){
if (getInputElement().isPresent()){
current.setInputElement(getInputElement().get());
}
if (getOutputElement().isPresent()){
current.setOutputElement(getOutputElement().get());
}
}
this.elements = elements;
}
@Override
public void setInputElement(ArchitectureElementSymbol inputElement) {
super.setInputElement(inputElement);
for (ArchitectureElementSymbol current : getElements()){
current.setInputElement(inputElement);
}
}
@Override
public void setOutputElement(ArchitectureElementSymbol outputElement) {
super.setOutputElement(outputElement);
for (ArchitectureElementSymbol current : getElements()){
current.setOutputElement(outputElement);
}
}
@Override
public List<ArchitectureElementSymbol> getFirstAtomicElements() {
if (getElements().isEmpty()){
return Collections.singletonList(this);
}
else {
List<ArchitectureElementSymbol> firstElements = new ArrayList<>();
for (ArchitectureElementSymbol element : getElements()){
firstElements.addAll(element.getFirstAtomicElements());
}
return firstElements;
}
}
@Override
public List<ArchitectureElementSymbol> getLastAtomicElements() {
if (getElements().isEmpty()){
return Collections.singletonList(this);
}
else {
List<ArchitectureElementSymbol> lastElements = new ArrayList<>();
for (ArchitectureElementSymbol element : getElements()){
lastElements.addAll(element.getLastAtomicElements());
}
return lastElements;
}
}
@Override
public List<ArchTypeSymbol> computeOutputTypes() {
if (getElements().isEmpty()){
if (getInputElement().isPresent()){
return getInputElement().get().getOutputTypes();
}
else {
return Collections.emptyList();
}
}
else {
List<ArchTypeSymbol> outputShapes = new ArrayList<>(getElements().size());
for (ArchitectureElementSymbol element : getElements()){
if (element.getOutputTypes().size() != 0){
outputShapes.add(element.getOutputTypes().get(0));
}
}
return outputShapes;
}
}
@Override
public void checkInput() {
if (!getElements().isEmpty()){
for (ArchitectureElementSymbol element : getElements()){
element.checkInput();
}
for (ArchitectureElementSymbol element : getElements()){
if (element.getOutputTypes().size() > 1){
Log.error("0" + ErrorCodes.MISSING_MERGE + " Missing merge layer (Add(), Concatenate() or [i]). " +
"Each stream at the end of a parallelization block can only have one output stream. "
, getSourcePosition());
}
}
}
}
@Override
public Optional<Integer> getParallelLength() {
return Optional.of(getElements().size());
}
@Override
public Optional<List<Integer>> getSerialLengths() {
return Optional.of(Collections.nCopies(getElements().size(), 1));
}
@Override
protected ParallelCompositeElementSymbol preResolveDeepCopy() {
ParallelCompositeElementSymbol copy = new ParallelCompositeElementSymbol();
if (getAstNode().isPresent()){
copy.setAstNode(getAstNode().get());
}
List<ArchitectureElementSymbol> elements = new ArrayList<>(getElements().size());
for (ArchitectureElementSymbol element : getElements()){
ArchitectureElementSymbol elementCopy = element.preResolveDeepCopy();
elements.add(elementCopy);
}
copy.setElements(elements);
return copy;
}
}
......@@ -53,6 +53,16 @@ abstract public class PredefinedLayerDeclaration extends LayerDeclarationSymbol
return true;
}
/**
* This method is used to distinguish between neural networks like "source -> FullyConnected() -> target" and
* basic assignments like "1 -> OneHot() -> target". The generators use this to avoid creating an own
* network for each assignment. Override by predefined layers which are trainable.
*/
@Override
public boolean isNetworkLayer() {
return false;
}
abstract public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer);
abstract public void checkInput(List<ArchTypeSymbol> inputTypes, LayerSymbol layer);
......
/**
*
* ******************************************************************************