Commit 375a75c1 authored by Evgeny Kusmenko's avatar Evgeny Kusmenko

Merge branch 'merge' into 'master'

Merge

See merge request !22
parents 3d9809dd fa320623
Pipeline #158179 passed with stages
in 16 minutes and 46 seconds
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
......@@ -30,7 +30,7 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnn-arch</artifactId>
<version>0.3.0-SNAPSHOT</version>
<version>0.3.1-SNAPSHOT</version>
......
......@@ -23,7 +23,7 @@ grammar CNNArch extends de.monticore.CommonExpressions, de.monticore.lang.Math,
LayerDeclaration = "def"
Name "("
parameters:(LayerParameter || ",")* ")" "{"
body:ArchBody "}";
body:Stream "}";
IODeclaration = "def"
(in:"input" | out:"output")
......@@ -53,20 +53,24 @@ grammar CNNArch extends de.monticore.CommonExpressions, de.monticore.lang.Math,
@attribute body
The architecture of the neural network.
*/
Architecture = methodDeclaration:LayerDeclaration*
body:ArchBody ;
Architecture = methodDeclaration:LayerDeclaration*
instructions:(Instruction || ";")+ ";";
scope ArchBody = elements:(ArchitectureElement || "->")*;
interface Instruction;
Stream implements Instruction = elements:(ArchitectureElement || "->")+;
interface ArchitectureElement;
IOElement implements ArchitectureElement = Name ("[" index:ArchSimpleExpression "]")?;
Constant implements ArchitectureElement = ArchSimpleExpression;
Layer implements ArchitectureElement = Name "(" arguments:(ArchArgument || ",")* ")";
ParallelBlock implements ArchitectureElement = "("
groups:ArchBody "|"
groups:(ArchBody || "|")+ ")";
groups:Stream "|"
groups:(Stream || "|")+ ")";
ArrayAccessLayer implements ArchitectureElement = "[" index:ArchSimpleExpression "]";
......@@ -160,4 +164,4 @@ grammar CNNArch extends de.monticore.CommonExpressions, de.monticore.lang.Math,
ast ArchArgument = method String getName(){}
method ASTArchExpression getRhs(){};
}
\ No newline at end of file
}
......@@ -67,30 +67,13 @@ public abstract class CNNArchGenerator {
generate(scope, rootModelName);
}
public abstract boolean check(ArchitectureSymbol architecture);
public abstract void generate(Scope scope, String rootModelName);
//check cocos with CNNArchCocos.checkAll(architecture) before calling this method.
public abstract Map<String, String> generateStrings(ArchitectureSymbol architecture);
public void checkValidGeneration(ArchitectureSymbol architecture){
if (architecture.getInputs().size() > 1){
Log.error("This cnn architecture has multiple inputs, " +
"which is currently not supported by the code generator. "
, architecture.getSourcePosition());
}
if (architecture.getOutputs().size() > 1){
Log.error("This cnn architecture has multiple outputs, " +
"which is currently not supported by the code generator. "
, architecture.getSourcePosition());
}
if (architecture.getOutputs().get(0).getDefinition().getType().getWidth() != 1 ||
architecture.getOutputs().get(0).getDefinition().getType().getHeight() != 1){
Log.error("This cnn architecture has a multi-dimensional output, " +
"which is currently not supported by the code generator."
, architecture.getSourcePosition());
}
}
//check cocos with CNNArchCocos.checkAll(architecture) before calling this method.
public void generateFiles(ArchitectureSymbol architecture) throws IOException{
Map<String, String> fileContentMap = generateStrings(architecture);
......
......@@ -66,7 +66,8 @@ public class CNNArchCocos {
.addCoCo(new CheckIOType())
.addCoCo(new CheckElementInputs())
.addCoCo(new CheckIOAccessAndIOMissing())
.addCoCo(new CheckArchitectureFinished());
.addCoCo(new CheckArchitectureFinished())
.addCoCo(new CheckNetworkStreamMissing());
}
//checks cocos based on symbols before the resolve method of the ArchitectureSymbol is called
......
......@@ -21,6 +21,7 @@
package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
......@@ -28,10 +29,12 @@ public class CheckArchitectureFinished extends CNNArchSymbolCoCo {
@Override
public void check(ArchitectureSymbol architecture) {
if (!architecture.getBody().getOutputTypes().isEmpty()){
Log.error("0" + ErrorCodes.UNFINISHED_ARCHITECTURE + " The architecture is not finished. " +
"There are still open streams at the end of the architecture. "
, architecture.getSourcePosition());
for (CompositeElementSymbol stream : architecture.getStreams()) {
if (!stream.getOutputTypes().isEmpty()){
Log.error("0" + ErrorCodes.UNFINISHED_ARCHITECTURE + " The architecture is not finished. " +
"There are still open streams at the end of the architecture. "
, architecture.getSourcePosition());
}
}
if (architecture.getInputs().isEmpty()){
Log.error("0" + ErrorCodes.UNFINISHED_ARCHITECTURE + " The architecture has no inputs. "
......
......@@ -23,6 +23,7 @@ package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch._ast.ASTArchArgument;
import de.monticore.lang.monticar.cnnarch._symboltable.ArgumentSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerDeclarationSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.Joiners;
import de.se_rwth.commons.logging.Log;
......@@ -32,12 +33,16 @@ public class CheckArgument implements CNNArchASTArchArgumentCoCo {
@Override
public void check(ASTArchArgument node) {
ArgumentSymbol argument = (ArgumentSymbol) node.getSymbolOpt().get();
LayerDeclarationSymbol layerDeclaration = argument.getLayer().getDeclaration();
if (layerDeclaration != null && argument.getParameter() == null){
Log.error("0"+ ErrorCodes.UNKNOWN_ARGUMENT + " Unknown Argument. " +
"Parameter with name '" + node.getName() + "' does not exist. " +
"Possible arguments are: " + Joiners.COMMA.join(layerDeclaration.getParameters())
, node.get_SourcePositionStart());
if(argument.getEnclosingScope().getSpanningSymbol().get() instanceof LayerSymbol) {
LayerDeclarationSymbol layerDeclaration = argument.getLayer().getDeclaration();
if (layerDeclaration != null && argument.getParameter() == null){
Log.error("0"+ ErrorCodes.UNKNOWN_ARGUMENT + " Unknown Argument. " +
"Parameter with name '" + node.getName() + "' does not exist. " +
"Possible arguments are: " + Joiners.COMMA.join(layerDeclaration.getParameters())
, node.get_SourcePositionStart());
}
}
}
......
......@@ -21,11 +21,14 @@
package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol;
public class CheckElementInputs extends CNNArchSymbolCoCo {
@Override
public void check(ArchitectureSymbol architecture) {
architecture.getBody().checkInput();
for (CompositeElementSymbol stream : architecture.getStreams()) {
stream.checkInput();
}
}
}
......@@ -21,10 +21,7 @@
package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch._ast.ASTLayerDeclaration;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerDeclarationSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
public class CheckNetworkStreamMissing extends CNNArchSymbolCoCo {
@Override
public void check(ArchitectureSymbol architecture) {
boolean hasNetworkStream = false;
for (CompositeElementSymbol stream : architecture.getStreams()) {
hasNetworkStream |= stream.isNetwork();
}
if (!hasNetworkStream) {
Log.error("0" + ErrorCodes.MISSING_NETWORK_STREAM + " The architecture has no network stream. "
, architecture.getSourcePosition());
}
}
}
......@@ -37,7 +37,7 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
public static final ArchitectureKind KIND = new ArchitectureKind();
private ArchitectureElementSymbol body;
private List<SerialCompositeElementSymbol> streams = new ArrayList<>();
private List<IOSymbol> inputs = new ArrayList<>();
private List<IOSymbol> outputs = new ArrayList<>();
private Map<String, IODeclarationSymbol> ioDeclarationMap = new HashMap<>();
......@@ -48,12 +48,12 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
super("", KIND);
}
public ArchitectureElementSymbol getBody() {
return body;
public List<SerialCompositeElementSymbol> getStreams() {
return streams;
}
protected void setBody(ArchitectureElementSymbol body) {
this.body = body;
public void setStreams(List<SerialCompositeElementSymbol> streams) {
this.streams = streams;
}
public String getDataPath() {
......@@ -103,30 +103,44 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
return getSpannedScope().resolveLocally(LayerDeclarationSymbol.KIND);
}
public void resolve() {
for (CompositeElementSymbol stream : streams) {
stream.checkIfResolvable();
public void resolve(){
getBody().checkIfResolvable();
try{
getBody().resolveOrError();
}
catch (ArchResolveException e){
//do nothing; error is already logged
try {
stream.resolveOrError();
}
catch (ArchResolveException e) {
// Do nothing; error is already logged
}
}
}
public List<ArchitectureElementSymbol> getFirstElements(){
/*public List<ArchitectureElementSymbol> getFirstElements() {
if (!getBody().isResolved()){
resolve();
}
return getBody().getFirstAtomicElements();
}
}*/
public boolean isResolved(){
return getBody().isResolved();
boolean resolved = true;
for (CompositeElementSymbol stream : streams) {
resolved &= stream.isResolved();
}
return resolved;
}
public boolean isResolvable(){
return getBody().isResolvable();
boolean resolvable = true;
for (CompositeElementSymbol stream : streams) {
resolvable &= stream.isResolvable();
}
return resolvable;
}
public void putInScope(Scope scope){
......@@ -145,22 +159,32 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
*/
public ArchitectureSymbol preResolveDeepCopy(Scope enclosingScopeOfCopy){
ArchitectureSymbol copy = new ArchitectureSymbol();
copy.setBody(getBody().preResolveDeepCopy());
if (getAstNode().isPresent()){
copy.setAstNode(getAstNode().get());
}
copy.getSpannedScope().getAsMutableScope().add(AllPredefinedVariables.createTrueConstant());
copy.getSpannedScope().getAsMutableScope().add(AllPredefinedVariables.createFalseConstant());
for (LayerDeclarationSymbol layerDeclaration : AllPredefinedLayers.createList()){
copy.getSpannedScope().getAsMutableScope().add(layerDeclaration);
}
for (LayerDeclarationSymbol layerDeclaration : getSpannedScope().<LayerDeclarationSymbol>resolveLocally(LayerDeclarationSymbol.KIND)){
if (!layerDeclaration.isPredefined()) {
copy.getSpannedScope().getAsMutableScope().add(layerDeclaration.deepCopy());
}
}
copy.getBody().putInScope(copy.getSpannedScope());
List<SerialCompositeElementSymbol> copyStreams = new ArrayList<>();
for (SerialCompositeElementSymbol stream : streams) {
SerialCompositeElementSymbol copyStream = stream.preResolveDeepCopy();
copyStream.putInScope(copy.getSpannedScope());
copyStreams.add(copyStream);
}
copy.setStreams(copyStreams);
copy.putInScope(enclosingScopeOfCopy);
return copy;
}
......
......@@ -20,7 +20,6 @@
*/
package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.expressionsbasis._ast.ASTExpression;
import de.monticore.lang.math._symboltable.MathSymbolTableCreator;
import de.monticore.lang.math._symboltable.expression.*;
......@@ -145,8 +144,12 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
}
public void endVisit(final ASTArchitecture node) {
//ArchitectureSymbol architecture = (ArchitectureSymbol) node.getSymbolOpt().get();
architecture.setBody((ArchitectureElementSymbol) node.getBody().getSymbolOpt().get());
List<SerialCompositeElementSymbol> streams = new ArrayList<>();
for (ASTInstruction astInstruction : node.getInstructionsList()){
ASTStream astStream = (ASTStream)astInstruction; // TODO: For now all instructions are streams
streams.add((SerialCompositeElementSymbol) astStream.getSymbolOpt().get());
}
architecture.setStreams(streams);
removeCurrentScope();
}
......@@ -226,7 +229,7 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
@Override
public void endVisit(ASTLayerDeclaration ast) {
LayerDeclarationSymbol layerDeclaration = (LayerDeclarationSymbol) ast.getSymbolOpt().get();
layerDeclaration.setBody((CompositeElementSymbol) ast.getBody().getSymbolOpt().get());
layerDeclaration.setBody((SerialCompositeElementSymbol) ast.getBody().getSymbolOpt().get());
List<VariableSymbol> parameters = new ArrayList<>(4);
for (ASTLayerParameter astParam : ast.getParametersList()){
......@@ -316,18 +319,17 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
@Override
public void visit(ASTParallelBlock node) {
CompositeElementSymbol compositeElement = new CompositeElementSymbol();
compositeElement.setParallel(true);
ParallelCompositeElementSymbol compositeElement = new ParallelCompositeElementSymbol();
addToScopeAndLinkWithNode(compositeElement, node);
}
@Override
public void endVisit(ASTParallelBlock node) {
CompositeElementSymbol compositeElement = (CompositeElementSymbol) node.getSymbolOpt().get();
ParallelCompositeElementSymbol compositeElement = (ParallelCompositeElementSymbol) node.getSymbolOpt().get();
List<ArchitectureElementSymbol> elements = new ArrayList<>();
for (ASTArchBody astBody : node.getGroupsList()){
elements.add((CompositeElementSymbol) astBody.getSymbolOpt().get());
for (ASTStream astStream : node.getGroupsList()){
elements.add((SerialCompositeElementSymbol) astStream.getSymbolOpt().get());
}
compositeElement.setElements(elements);
......@@ -335,16 +337,14 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
}
@Override
public void visit(ASTArchBody ast) {
CompositeElementSymbol compositeElement = new CompositeElementSymbol();
compositeElement.setParallel(false);
public void visit(ASTStream ast) {
SerialCompositeElementSymbol compositeElement = new SerialCompositeElementSymbol();
addToScopeAndLinkWithNode(compositeElement, ast);
}
@Override
public void endVisit(ASTArchBody ast) {
CompositeElementSymbol compositeElement = (CompositeElementSymbol) ast.getSymbolOpt().get();
public void endVisit(ASTStream ast) {
SerialCompositeElementSymbol compositeElement = (SerialCompositeElementSymbol) ast.getSymbolOpt().get();
List<ArchitectureElementSymbol> elements = new ArrayList<>();
for (ASTArchitectureElement astElement : ast.getElementsList()){
elements.add((ArchitectureElementSymbol) astElement.getSymbolOpt().get());
......@@ -384,6 +384,17 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
addToScopeAndLinkWithNode(argument, node);
}
public void visit(ASTConstant node) {
ConstantSymbol constant = new ConstantSymbol();
addToScopeAndLinkWithNode(constant, node);
}
public void endVisit(ASTConstant node) {
ConstantSymbol constant = (ConstantSymbol) node.getSymbolOpt().get();
constant.setExpression((ArchSimpleExpressionSymbol) node.getArchSimpleExpression().getSymbolOpt().get());
removeCurrentScope();
}
public void visit(ASTIOElement node) {
IOSymbol ioElement = new IOSymbol(node.getName());
addToScopeAndLinkWithNode(ioElement, node);
......
......@@ -20,122 +20,44 @@
*/
package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.monticore.symboltable.Scope;
import de.monticore.symboltable.Symbol;
import de.se_rwth.commons.logging.Log;
import java.util.*;
public class CompositeElementSymbol extends ArchitectureElementSymbol {
public abstract class CompositeElementSymbol extends ArchitectureElementSymbol {
private boolean parallel;
private List<ArchitectureElementSymbol> elements;
protected List<ArchitectureElementSymbol> elements = new ArrayList<>();
protected CompositeElementSymbol() {
public CompositeElementSymbol() {
super("");
setResolvedThis(this);
}
public boolean isParallel() {
return parallel;
}
protected void setParallel(boolean parallel) {
this.parallel = parallel;
}
public List<ArchitectureElementSymbol> getElements() {
return elements;
}
protected void setElements(List<ArchitectureElementSymbol> elements) {
ArchitectureElementSymbol previous = null;
for (ArchitectureElementSymbol current : elements){
if (previous != null && !isParallel()){
current.setInputElement(previous);
previous.setOutputElement(current);
}
else {
if (getInputElement().isPresent()){
current.setInputElement(getInputElement().get());
}
if (getOutputElement().isPresent()){
current.setOutputElement(getOutputElement().get());
}
}
previous = current;
}
this.elements = elements;
}
@Override
public boolean isAtomic() {
return getElements().isEmpty();
}
abstract protected void setElements(List<ArchitectureElementSymbol> elements);
@Override
public void setInputElement(ArchitectureElementSymbol inputElement) {
super.setInputElement(inputElement);
if (isParallel()){
for (ArchitectureElementSymbol current : getElements()){
current.setInputElement(inputElement);
}
}
else {
if (!getElements().isEmpty()){
getElements().get(0).setInputElement(inputElement);
}
}
}
public boolean isNetwork() {
boolean isNetwork = false;
@Override
public void setOutputElement(ArchitectureElementSymbol outputElement) {
super.setOutputElement(outputElement);
if (isParallel()){
for (ArchitectureElementSymbol current : getElements()){
current.setOutputElement(outputElement);
for (ArchitectureElementSymbol element : elements) {
if (element instanceof CompositeElementSymbol) {
isNetwork |= ((CompositeElementSymbol) element).isNetwork();
}
}
else {
if (!getElements().isEmpty()){
getElements().get(getElements().size()-1).setOutputElement(outputElement);
else if (element instanceof LayerSymbol) {
isNetwork |= ((LayerSymbol) element).getDeclaration().isNetworkLayer();
}
}
}
@Override
public List<ArchitectureElementSymbol> getFirstAtomicElements() {
if (getElements().isEmpty()){
return Collections.singletonList(this);
}
else if (isParallel()){
List<ArchitectureElementSymbol> firstElements = new ArrayList<>();
for (ArchitectureElementSymbol element : getElements()){
firstElements.addAll(element.getFirstAtomicElements());
}
return firstElements;
}
else {
return getElements().get(0).getFirstAtomicElements();
}
return isNetwork;
}
@Override
public List<ArchitectureElementSymbol> getLastAtomicElements() {
if (getElements().isEmpty()){
return Collections.singletonList(this);
}
else if (isParallel()){
List<ArchitectureElementSymbol> lastElements = new ArrayList<>();
for (ArchitectureElementSymbol element : getElements()){
lastElements.addAll(element.getLastAtomicElements());
}
return lastElements;
}
else {
return getElements().get(getElements().size()-1).getLastAtomicElements();
}
public boolean isAtomic() {
return getElements().isEmpty();
}
@Override
......@@ -177,73 +99,6 @@ public class CompositeElementSymbol extends ArchitectureElementSymbol {
}
}
@Override
public List<ArchTypeSymbol> computeOutputTypes() {
if (getElements().isEmpty()){
if (getInputElement().isPresent()){
return getInputElement().get().getOutputTypes();
}
else {
return Collections.emptyList();
}
}
else {
if (isParallel()){
List<ArchTypeSymbol> outputShapes = new ArrayList<>(getElements().size());
for (ArchitectureElementSymbol element : getElements()){
if (element.getOutputTypes().size() != 0){
outputShapes.add(element.getOutputTypes().get(0));
}
}
return outputShapes;
}
else {
for (ArchitectureElementSymbol element : getElements()){
element.getOutputTypes();
}
return getElements().get(getElements().size() - 1).getOutputTypes();
}
}