Commit 195c5712 authored by Christian Fuß's avatar Christian Fuß
Browse files

Merge branch 'rnn' into develop

parents 72e83df4 da864824
Pipeline #202178 passed with stages
in 23 minutes and 25 seconds
......@@ -359,6 +359,7 @@
<version>2.19.1</version>
<configuration>
<argLine>-Xmx1024m -Xms1024m -XX:MaxPermSize=512m -Djdk.net.URLClassPath.disableClassPathURLCheck=true</argLine>
<trimStackTrace>false</trimStackTrace>
</configuration>
</plugin>
<plugin>
......
......@@ -56,10 +56,18 @@ grammar CNNArch extends de.monticore.CommonExpressions, de.monticore.lang.Math,
Architecture = methodDeclaration:LayerDeclaration*
instructions:(Instruction || ";")+ ";";
Instruction = (LayerVariableDeclaration | Stream);
Instruction = (LayerVariableDeclaration | NetworkInstruction);
LayerVariableDeclaration = "layer" Layer Name;
interface NetworkInstruction;
StreamInstruction implements NetworkInstruction = body:Stream;
UnrollInstruction implements NetworkInstruction = "timed" "<" timeParameter:TimeParameter ">"
Name "(" arguments:(ArchArgument || ",")* ")"
"{" body:Stream "}";
Stream = elements:(ArchitectureElement || "->")+;
interface ArchitectureElement;
......@@ -70,6 +78,7 @@ grammar CNNArch extends de.monticore.CommonExpressions, de.monticore.lang.Math,
Layer implements ArchitectureElement = Name "(" arguments:(ArchArgument || ",")* ")";
ParallelBlock implements ArchitectureElement = "("
groups:Stream "|"
groups:(Stream || "|")+ ")";
......@@ -85,6 +94,8 @@ grammar CNNArch extends de.monticore.CommonExpressions, de.monticore.lang.Math,
LayerParameter implements ArchParameter = Name ("=" default:ArchSimpleExpression)? ;
TimeParameter implements ArchParameter = Name ("=" default:ArchSimpleExpression)? ;
interface ArchArgument;
ArchParameterArgument implements ArchArgument = Name "=" rhs:ArchExpression ;
......
......@@ -64,13 +64,15 @@ public class CNNArchCocos {
public static CNNArchSymbolCoCoChecker createCNNArchPostResolveSymbolChecker() {
return new CNNArchSymbolCoCoChecker()
.addCoCo(new CheckIOType())
.addCoCo(new CheckIOArrayLength())
.addCoCo(new CheckElementInputs())
.addCoCo(new CheckIOAccessAndIOMissing())
.addCoCo(new CheckArchitectureFinished())
.addCoCo(new CheckNetworkStreamMissing())
.addCoCo(new CheckVariableMember())
.addCoCo(new CheckLayerVariableDeclarationLayerType())
.addCoCo(new CheckLayerVariableDeclarationIsUsed());
.addCoCo(new CheckLayerVariableDeclarationIsUsed())
.addCoCo(new CheckConstants());
}
//checks cocos based on symbols before the resolve method of the ArchitectureSymbol is called
......@@ -78,6 +80,7 @@ public class CNNArchCocos {
return new CNNArchSymbolCoCoChecker()
.addCoCo(new CheckVariableDeclarationName())
.addCoCo(new CheckVariableName())
.addCoCo(new CheckArgmaxLayer())
.addCoCo(new CheckExpressions());
}
......
......@@ -33,6 +33,15 @@ public class CNNArchSymbolCoCo {
else if (sym instanceof LayerDeclarationSymbol){
check((LayerDeclarationSymbol) sym);
}
else if (sym instanceof UnrollDeclarationSymbol){
check((UnrollDeclarationSymbol) sym);
}
else if (sym instanceof UnrollInstructionSymbol){
check((UnrollInstructionSymbol) sym);
}
else if (sym instanceof StreamInstructionSymbol){
check((StreamInstructionSymbol) sym);
}
else if (sym instanceof ArchitectureElementSymbol){
check((ArchitectureElementSymbol) sym);
}
......@@ -72,6 +81,10 @@ public class CNNArchSymbolCoCo {
//Override if needed
}
public void check(UnrollDeclarationSymbol sym){
//Override if needed
}
public void check(ArchitectureElementSymbol sym){
//Override if needed
}
......@@ -103,4 +116,12 @@ public class CNNArchSymbolCoCo {
public void check(MathExpressionSymbol sym){
//Override if needed
}
public void check(UnrollInstructionSymbol sym){
//Override if needed
}
public void check(StreamInstructionSymbol sym){
//Override if needed
}
}
......@@ -21,21 +21,23 @@
package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.NetworkInstructionSymbol;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.SourcePosition;
import de.se_rwth.commons.logging.Log;
public class CheckArchitectureFinished extends CNNArchSymbolCoCo {
@Override
public void check(ArchitectureSymbol architecture) {
for (CompositeElementSymbol stream : architecture.getStreams()) {
if (!stream.getOutputTypes().isEmpty()){
for (NetworkInstructionSymbol networkInstruction : architecture.getNetworkInstructions()) {
if (!networkInstruction.getBody().getOutputTypes().isEmpty()) {
Log.error("0" + ErrorCodes.UNFINISHED_ARCHITECTURE + " The architecture is not finished. " +
"There are still open streams at the end of the architecture. "
, stream.getSourcePosition());
, networkInstruction.getSourcePosition());
}
}
if (architecture.getInputs().isEmpty()){
Log.error("0" + ErrorCodes.UNFINISHED_ARCHITECTURE + " The architecture has no inputs. "
, architecture.getSourcePosition());
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers;
import de.se_rwth.commons.logging.Log;
public class CheckArgmaxLayer extends CNNArchSymbolCoCo {
@Override
public void check(ArchitectureElementSymbol symbol) {
if (symbol instanceof LayerSymbol && symbol.getName().equals(AllPredefinedLayers.ARG_MAX_NAME)) {
checkArgmaxBeforeOutput((LayerSymbol) symbol);
}
}
public void checkArgmaxBeforeOutput(LayerSymbol layer) {
if(!(layer.getOutputElement().get() instanceof VariableSymbol)){
Log.error("0" + ErrorCodes.ILLEGAL_LAYER_USE + " ArgMax Layer must be applied directly before an output symbol.");
}
}
}
......@@ -21,9 +21,7 @@
package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch._ast.ASTArchArgument;
import de.monticore.lang.monticar.cnnarch._symboltable.ArgumentSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerDeclarationSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.Joiners;
import de.se_rwth.commons.logging.Log;
......@@ -42,7 +40,14 @@ public class CheckArgument implements CNNArchASTArchArgumentCoCo {
"Possible arguments are: " + Joiners.COMMA.join(layerDeclaration.getParameters())
, node.get_SourcePositionStart());
}
}else if(argument.getEnclosingScope().getSpanningSymbol().get() instanceof UnrollInstructionSymbol){
UnrollDeclarationSymbol layerDeclaration = argument.getUnroll().getDeclaration();
if (layerDeclaration != null && argument.getParameter() == null){
Log.error("0"+ ErrorCodes.UNKNOWN_ARGUMENT + " Unknown Argument. " +
"Parameter with name '" + node.getName() + "' does not exist. " +
"Possible arguments are: " + Joiners.COMMA.join(layerDeclaration.getParameters())
, node.get_SourcePositionStart());
}
}
}
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ConstantSymbol;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import java.util.Optional;
public class CheckConstants extends CNNArchSymbolCoCo {
@Override
public void check(ArchitectureElementSymbol sym) {
if (sym instanceof ConstantSymbol) {
checkConstant((ConstantSymbol) sym);
}
}
public void checkConstant(ConstantSymbol constant) {
Optional<Boolean> isInt = constant.getExpression().isInt();
if (!isInt.isPresent() || !isInt.get()) {
Log.error("0" + ErrorCodes.INVALID_CONSTANT + " Invalid constant, only integers allowed. ",
constant.getSourcePosition());
}
}
}
......@@ -21,14 +21,14 @@
package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.NetworkInstructionSymbol;
public class CheckElementInputs extends CNNArchSymbolCoCo {
@Override
public void check(ArchitectureSymbol architecture) {
for (CompositeElementSymbol stream : architecture.getStreams()) {
stream.checkInput();
for (NetworkInstructionSymbol networkInstruction : architecture.getNetworkInstructions()) {
networkInstruction.getBody().checkInput();
}
}
}
......@@ -20,9 +20,7 @@
*/
package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.IODeclarationSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.VariableSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.Joiners;
import de.se_rwth.commons.logging.Log;
......@@ -63,7 +61,6 @@ public class CheckIOAccessAndIOMissing extends CNNArchSymbolCoCo {
}
}
private void checkIOArray(IODeclarationSymbol ioDeclaration){
List<Integer> unusedIndices = IntStream.range(0, ioDeclaration.getArrayLength()).boxed().collect(Collectors.toList());
......@@ -77,7 +74,7 @@ public class CheckIOAccessAndIOMissing extends CNNArchSymbolCoCo {
Log.error("0" + ErrorCodes.INVALID_ARRAY_ACCESS + " The IO array access value of '" + ioElement.getName() +
"' must be an integer between 0 and " + (ioDeclaration.getArrayLength()-1) + ". " +
"The current value is: " + ioElement.getArrayAccess().get().getValue().get().toString()
, ioElement.getSourcePosition());
, ioElement.getSourcePosition());
}
}
else{
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
public class CheckIOArrayLength extends CNNArchSymbolCoCo {
@Override
public void check(ArchitectureSymbol architecture) {
for (IODeclarationSymbol ioDeclaration : architecture.getIODeclarations()){
checkIO(ioDeclaration);
}
}
public void checkIO(IODeclarationSymbol ioDeclaration) {
if (ioDeclaration.getArrayLength() > IODeclarationSymbol.MAX_ARRAY_LENGTH) {
Log.error("0" + ErrorCodes.INVALID_IO_ARRAY_LENGTH + " Invalid IO array length. Length can not be bigger than " + IODeclarationSymbol.MAX_ARRAY_LENGTH
, ioDeclaration.getSourcePosition());
}
}
}
......@@ -49,7 +49,6 @@ public class CheckLayer implements CNNArchASTLayerCoCo{
nameSet.add(name);
}
}
LayerDeclarationSymbol layerDeclaration = ((LayerSymbol) node.getSymbolOpt().get()).getDeclaration();
if (layerDeclaration == null){
ArchitectureSymbol architecture = node.getSymbolOpt().get().getEnclosingScope().<ArchitectureSymbol>resolve("", ArchitectureSymbol.KIND).get();
......
......@@ -24,7 +24,10 @@ import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.Set;
public class CheckLayerVariableDeclarationIsUsed extends CNNArchSymbolCoCo {
......@@ -35,9 +38,12 @@ public class CheckLayerVariableDeclarationIsUsed extends CNNArchSymbolCoCo {
boolean isUsed = false;
for (SerialCompositeElementSymbol stream : layerVariableDeclaration.getLayer().getArchitecture().getStreams()) {
Collection<ArchitectureElementSymbol> elements =
stream.getSpannedScope().resolveMany(layerVariableDeclaration.getName(), ArchitectureElementSymbol.KIND);
Set<String> allowedUnusedLayers = new HashSet();
allowedUnusedLayers.add("attention");
for (NetworkInstructionSymbol networkInstruction : layerVariableDeclaration.getLayer().getArchitecture().getNetworkInstructions()) {
Collection<ArchitectureElementSymbol> elements
= networkInstruction.getBody().getSpannedScope().resolveMany(layerVariableDeclaration.getName(), ArchitectureElementSymbol.KIND);
for (ArchitectureElementSymbol element : elements) {
if (element instanceof VariableSymbol && ((VariableSymbol) element).getMember() == VariableSymbol.Member.NONE) {
......@@ -51,6 +57,10 @@ public class CheckLayerVariableDeclarationIsUsed extends CNNArchSymbolCoCo {
}
}
if(allowedUnusedLayers.contains(sym.getName())){
isUsed = true;
}
if (!isUsed) {
Log.error("0" + ErrorCodes.UNUSED_LAYER + " Unused layer. " +
"Declared layer variables need to be used as layer at least once.",
......
......@@ -21,7 +21,7 @@
package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.NetworkInstructionSymbol;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
......@@ -31,8 +31,8 @@ public class CheckNetworkStreamMissing extends CNNArchSymbolCoCo {
public void check(ArchitectureSymbol architecture) {
boolean hasTrainableStream = false;
for (CompositeElementSymbol stream : architecture.getStreams()) {
hasTrainableStream |= stream.isTrainable();
for (NetworkInstructionSymbol networkInstruction : architecture.getNetworkInstructions()) {
hasTrainableStream |= networkInstruction.getBody().isTrainable();
}
if (!hasTrainableStream) {
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch._ast.ASTArchArgument;
import de.monticore.lang.monticar.cnnarch._ast.ASTUnrollInstruction;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.Joiners;
import de.se_rwth.commons.logging.Log;
import java.util.HashSet;
import java.util.Set;
public class CheckUnroll implements CNNArchASTUnrollInstructionCoCo{
@Override
public void check(ASTUnrollInstruction node) {
Set<String> nameSet = new HashSet<>();
for (ASTArchArgument argument : node.getArgumentsList()){
String name = argument.getName();
if (nameSet.contains(name)){
Log.error("0" + ErrorCodes.DUPLICATED_ARG + " Duplicated name: " + name +
". Multiple values assigned to the same argument."
, argument.get_SourcePositionStart());
}
else {
nameSet.add(name);
}
}
UnrollDeclarationSymbol layerDeclaration = ((UnrollInstructionSymbol) node.getSymbolOpt().get()).getDeclaration();
if (layerDeclaration == null){
ArchitectureSymbol architecture = node.getSymbolOpt().get().getEnclosingScope().<ArchitectureSymbol>resolve("", ArchitectureSymbol.KIND).get();
Log.error("0" + ErrorCodes.UNKNOWN_LAYER + " Unknown layer. " +
"Layer with name '" + node.getName() + "' does not exist. " +
"Existing layers: " + Joiners.COMMA.join(architecture.getUnrollDeclarations()) + "."
, node.get_SourcePositionStart());
}
else {
Set<String> requiredArguments = new HashSet<>();
for (ParameterSymbol param : layerDeclaration.getParameters()){
if (!param.getDefaultExpression().isPresent()){
requiredArguments.add(param.getName());
}
}
for (ASTArchArgument argument : node.getArgumentsList()){
requiredArguments.remove(argument.getName());
}
for (String missingArgumentName : requiredArguments){
Log.error("0"+ErrorCodes.MISSING_ARGUMENT + " Missing argument. " +
"The argument '" + missingArgumentName + "' is required."
, node.get_SourcePositionStart());
}
}
}
}
......@@ -27,6 +27,8 @@ import de.monticore.lang.monticar.cnnarch._symboltable.VariableSymbol;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import java.util.Optional;
public class CheckVariableMember extends CNNArchSymbolCoCo {
@Override
......@@ -40,14 +42,32 @@ public class CheckVariableMember extends CNNArchSymbolCoCo {
if (variable.getType() == VariableSymbol.Type.LAYER) {
LayerDeclarationSymbol layerDeclaration = variable.getLayerVariableDeclaration().getLayer().getDeclaration();
if (layerDeclaration.isPredefined() && !((PredefinedLayerDeclaration) layerDeclaration).isValidMember(variable.getMember())) {
if (layerDeclaration.isPredefined() && ((PredefinedLayerDeclaration) layerDeclaration).getArrayLength(variable.getMember()) == 0) {
Log.error("0" + ErrorCodes.INVALID_MEMBER + " Layer has no member " + variable.getMember().toString().toLowerCase() + ". ",
variable.getSourcePosition());
}
if (variable.getArrayAccess().isPresent()) {
Log.error("0" + ErrorCodes.INVALID_MEMBER + " Currently layer variable array access is not implemented. ",
variable.getSourcePosition());
Optional<Integer> arrayAccess = variable.getArrayAccess().get().getIntValue();
int arrayLength = 0;
if (layerDeclaration.isPredefined()) {
arrayLength = ((PredefinedLayerDeclaration) layerDeclaration).getArrayLength(variable.getMember());
}
String name = variable.getName() + "." + variable.getMember().toString().toLowerCase();
if (arrayAccess.isPresent() && arrayLength == 1) {
Log.error("0" + ErrorCodes.INVALID_ARRAY_ACCESS + " The layer variable '" + name +
"' does not support array access. "
, variable.getSourcePosition());
} else if (!arrayAccess.isPresent() || arrayAccess.get() < 0 || arrayAccess.get() >= arrayLength) {
Log.error("0" + ErrorCodes.INVALID_ARRAY_ACCESS + " The layer variable array access value of '" + name +
"' must be an integer between 0 and " + (arrayLength - 1) + ". " +
"The current value is: " + variable.getArrayAccess().get().getValue().get().toString()
, variable.getSourcePosition());
}
//
}
}
......
......@@ -222,9 +222,9 @@ public class ArchTypeSymbol extends CommonSymbol {
}
public static class Builder{
private int height = 0;
private int width = 0;
private int channels = 0;
private int height = 1;
private int width = 1;
private int channels = 1;
private ASTElementType domain = null;
public Builder height(int height){
......@@ -252,29 +252,22 @@ public class ArchTypeSymbol extends CommonSymbol {
domain.setRange(range);
return this;
}
public Builder elementType(String name, String start, String end){
domain = new ASTElementType();
domain.setName(name); //("Q(" + start + ":" + end +")");
ASTRange range = new ASTRange();
range.setStartValue(start);
range.setEndValue(end);
domain.setRange(range);