Commit 78009eaf authored by Christian Fuß's avatar Christian Fuß
Browse files

added OneHotLayer and some parts of the Unrolling

parent b3f4bc77
Pipeline #150120 failed with stages
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
......@@ -25,6 +25,10 @@ grammar CNNArch extends de.monticore.CommonExpressions, de.monticore.lang.Math,
parameters:(LayerParameter || ",")* ")" "{"
body:Stream "}";
UnrollDeclaration = "unroll" "<" timeParameter:LayerParameter ">"
Name "(" parameters:(LayerParameter || ",")* ")"
"{" body:Stream "}";
IODeclaration = "def"
(in:"input" | out:"output")
type:ArchType
......@@ -53,7 +57,8 @@ grammar CNNArch extends de.monticore.CommonExpressions, de.monticore.lang.Math,
@attribute body
The architecture of the neural network.
*/
Architecture = methodDeclaration:LayerDeclaration*
Architecture = unrollDeclarations:UnrollDeclaration*
methodDeclaration:LayerDeclaration*
instructions:(Instruction || ";")+ ";";
interface Instruction;
......@@ -66,6 +71,8 @@ grammar CNNArch extends de.monticore.CommonExpressions, de.monticore.lang.Math,
Layer implements ArchitectureElement = Name "(" arguments:(ArchArgument || ",")* ")";
Unroll implements ArchitectureElement = Name "(" arguments:(ArchArgument || ",")* ")";
ParallelBlock implements ArchitectureElement = "("
groups:Stream "|"
groups:(Stream || "|")+ ")";
......
......@@ -33,6 +33,9 @@ public class CNNArchSymbolCoCo {
else if (sym instanceof LayerDeclarationSymbol){
check((LayerDeclarationSymbol) sym);
}
else if (sym instanceof UnrollDeclarationSymbol){
check((UnrollDeclarationSymbol) sym);
}
else if (sym instanceof ArchitectureElementSymbol){
check((ArchitectureElementSymbol) sym);
}
......@@ -72,6 +75,10 @@ public class CNNArchSymbolCoCo {
//Override if needed
}
public void check(UnrollDeclarationSymbol sym){
//Override if needed
}
public void check(ArchitectureElementSymbol sym){
//Override if needed
}
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch._ast.ASTArchArgument;
import de.monticore.lang.monticar.cnnarch._ast.ASTLayer;
import de.monticore.lang.monticar.cnnarch._ast.ASTUnroll;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.Joiners;
import de.se_rwth.commons.logging.Log;
import java.util.HashSet;
import java.util.Set;
public class CheckUnroll implements CNNArchASTUnrollCoCo{
@Override
public void check(ASTUnroll node) {
Set<String> nameSet = new HashSet<>();
for (ASTArchArgument argument : node.getArgumentsList()){
String name = argument.getName();
if (nameSet.contains(name)){
Log.error("0" + ErrorCodes.DUPLICATED_ARG + " Duplicated name: " + name +
". Multiple values assigned to the same argument."
, argument.get_SourcePositionStart());
}
else {
nameSet.add(name);
}
}
LayerDeclarationSymbol layerDeclaration = ((LayerSymbol) node.getSymbolOpt().get()).getDeclaration();
if (layerDeclaration == null){
ArchitectureSymbol architecture = node.getSymbolOpt().get().getEnclosingScope().<ArchitectureSymbol>resolve("", ArchitectureSymbol.KIND).get();
Log.error("0" + ErrorCodes.UNKNOWN_LAYER + " Unknown layer. " +
"Layer with name '" + node.getName() + "' does not exist. " +
"Existing layers: " + Joiners.COMMA.join(architecture.getLayerDeclarations()) + "."
, node.get_SourcePositionStart());
}
else {
Set<String> requiredArguments = new HashSet<>();
for (VariableSymbol param : layerDeclaration.getParameters()){
if (!param.getDefaultExpression().isPresent()){
requiredArguments.add(param.getName());
}
}
for (ASTArchArgument argument : node.getArgumentsList()){
requiredArguments.remove(argument.getName());
}
for (String missingArgumentName : requiredArguments){
Log.error("0"+ErrorCodes.MISSING_ARGUMENT + " Missing argument. " +
"The argument '" + missingArgumentName + "' is required."
, node.get_SourcePositionStart());
}
}
}
}
......@@ -242,6 +242,29 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
removeCurrentScope();
}
@Override
public void visit(ASTUnrollDeclaration ast) {
UnrollDeclarationSymbol unrollDeclaration = new UnrollDeclarationSymbol(ast.getName());
addToScopeAndLinkWithNode(unrollDeclaration, ast);
}
@Override
public void endVisit(ASTUnrollDeclaration ast) {
UnrollDeclarationSymbol unrollDeclaration = (UnrollDeclarationSymbol) ast.getSymbolOpt().get();
unrollDeclaration.setBody((CompositeElementSymbol) ast.getBody().getSymbolOpt().get());
List<VariableSymbol> parameters = new ArrayList<>(4);
for (ASTLayerParameter astParam : ast.getParametersList()){
VariableSymbol parameter = (VariableSymbol) astParam.getSymbolOpt().get();
parameters.add(parameter);
}
unrollDeclaration.setParameters(parameters);
removeCurrentScope();
}
@Override
public void visit(ASTLayerParameter ast) {
VariableSymbol variable = new VariableSymbol(ast.getName());
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch._symboltable;
import com.google.gson.internal.Streams;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers;
import de.monticore.lang.monticar.ranges._ast.ASTRange;
import de.se_rwth.commons.logging.Log;
import org.jscience.mathematics.number.Rational;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.function.BinaryOperator;
import java.util.stream.Stream;
abstract public class PredefinedUnrollDeclaration extends UnrollDeclarationSymbol {
public PredefinedUnrollDeclaration(String name) {
super(name);
}
@Override
protected void setParameters(List<VariableSymbol> parameters) {
super.setParameters(parameters);
for (VariableSymbol param : parameters){
param.putInScope(getSpannedScope());
}
}
@Override
public boolean isPredefined() {
return true;
}
abstract public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, UnrollSymbol layer);
abstract public void checkInput(List<ArchTypeSymbol> inputTypes, UnrollSymbol layer);
@Override
public PredefinedUnrollDeclaration deepCopy() {
throw new IllegalStateException("Copy method should not be called for predefined layer declarations.");
}
//the following methods are only here to avoid duplication. They are used by multiple subclasses.
//check if inputTypes is of size 1
protected void errorIfInputSizeIsNotOne(List<ArchTypeSymbol> inputTypes, LayerSymbol layer){
if (inputTypes.size() != 1){
Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_SHAPE + " Invalid layer input. " +
getName() + " layer can only handle one input stream. " +
"Current number of input streams " + inputTypes.size() + "."
, layer.getSourcePosition());
}
}
protected void errorIfInputIsEmpty(List<ArchTypeSymbol> inputTypes, LayerSymbol layer){
if (inputTypes.size() == 0){
Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_SHAPE + " Invalid layer input. Number of input streams is 0"
, layer.getSourcePosition());
}
}
//check input for convolution and pooling
protected static void errorIfInputSmallerThanKernel(List<ArchTypeSymbol> inputTypes, LayerSymbol layer){
if (!inputTypes.isEmpty()) {
int inputHeight = inputTypes.get(0).getHeight();
int inputWidth = inputTypes.get(0).getWidth();
int kernelHeight = layer.getIntTupleValue(AllPredefinedLayers.KERNEL_NAME).get().get(0);
int kernelWidth = layer.getIntTupleValue(AllPredefinedLayers.KERNEL_NAME).get().get(1);
if (kernelHeight > inputHeight || kernelWidth > inputWidth){
if (layer.getStringValue(AllPredefinedLayers.PADDING_NAME).equals(AllPredefinedLayers.PADDING_VALID)){
Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_SHAPE + " Invalid layer input. " +
"The input resolution is smaller than the kernel and the padding mode is 'valid'." +
"This would result in an output resolution of 0x0."
, layer.getSourcePosition());
}
else {
Log.warn("The input resolution is smaller than the kernel. " +
"This results in an output resolution of 1x1. " +
"If this warning appears multiple times, consider changing your architecture"
, layer.getSourcePosition());
}
}
}
}
//output type function for convolution and pooling
protected static List<ArchTypeSymbol> computeConvAndPoolOutputShape(ArchTypeSymbol inputType, LayerSymbol method, int channels) {
String borderModeSetting = method.getStringValue(AllPredefinedLayers.PADDING_NAME).get();
if (borderModeSetting.equals(AllPredefinedLayers.PADDING_SAME)){
return computeOutputShapeWithSamePadding(inputType, method, channels);
}
else if (borderModeSetting.equals(AllPredefinedLayers.PADDING_VALID)){
return computeOutputShapeWithValidPadding(inputType, method, channels);
}
else if (borderModeSetting.equals(AllPredefinedLayers.PADDING_NO_LOSS)){
return computeOutputShapeWithNoLossPadding(inputType, method, channels);
}
else{
throw new IllegalStateException("border_mode is " + borderModeSetting + ". This should never happen.");
}
}
//padding with border_mode=valid, no padding
private static List<ArchTypeSymbol> computeOutputShapeWithValidPadding(ArchTypeSymbol inputType, LayerSymbol method, int channels){
int strideHeight = method.getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get().get(0);
int strideWidth = method.getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get().get(1);
int kernelHeight = method.getIntTupleValue(AllPredefinedLayers.KERNEL_NAME).get().get(0);
int kernelWidth = method.getIntTupleValue(AllPredefinedLayers.KERNEL_NAME).get().get(1);
int inputHeight = inputType.getHeight();
int inputWidth = inputType.getWidth();
int outputWidth;
int outputHeight;
if (inputWidth < kernelWidth || inputHeight < kernelHeight){
outputWidth = 0;
outputHeight = 0;
}
else {
outputWidth = 1 + (inputWidth - kernelWidth) / strideWidth;
outputHeight = 1 + (inputHeight - kernelHeight) / strideHeight;
}
return Collections.singletonList(new ArchTypeSymbol.Builder()
.height(outputHeight)
.width(outputWidth)
.channels(channels)
.elementType("-oo", "oo")
.build());
}
//padding until no data gets discarded, same as valid with a stride of 1
private static List<ArchTypeSymbol> computeOutputShapeWithNoLossPadding(ArchTypeSymbol inputType, LayerSymbol method, int channels){
int strideHeight = method.getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get().get(0);
int strideWidth = method.getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get().get(1);
int kernelHeight = method.getIntTupleValue(AllPredefinedLayers.KERNEL_NAME).get().get(0);
int kernelWidth = method.getIntTupleValue(AllPredefinedLayers.KERNEL_NAME).get().get(1);
int inputHeight = inputType.getHeight();
int inputWidth = inputType.getWidth();
int outputWidth = 1 + Math.max(0, ((inputWidth - kernelWidth + strideWidth - 1) / strideWidth));
int outputHeight = 1 + Math.max(0, ((inputHeight - kernelHeight + strideHeight - 1) / strideHeight));
return Collections.singletonList(new ArchTypeSymbol.Builder()
.height(outputHeight)
.width(outputWidth)
.channels(channels)
.elementType("-oo", "oo")
.build());
}
//padding with border_mode='same'
private static List<ArchTypeSymbol> computeOutputShapeWithSamePadding(ArchTypeSymbol inputType, LayerSymbol method, int channels){
int strideHeight = method.getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get().get(0);
int strideWidth = method.getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get().get(1);
int inputHeight = inputType.getHeight();
int inputWidth = inputType.getWidth();
int outputWidth = (inputWidth + strideWidth - 1) / strideWidth;
int outputHeight = (inputHeight + strideWidth - 1) / strideHeight;
return Collections.singletonList(new ArchTypeSymbol.Builder()
.height(outputHeight)
.width(outputWidth)
.channels(channels)
.elementType("-oo", "oo")
.build());
}
protected List<String> computeStartAndEndValue(List<ArchTypeSymbol> inputTypes, BinaryOperator<Rational> startValAccumulator, BinaryOperator<Rational> endValAccumulator){
Stream.Builder<Rational> startValues = Stream.builder();
Stream.Builder<Rational> endValues = Stream.builder();
String start = null;
String end = null;
for (ArchTypeSymbol inputType : inputTypes){
Optional<ASTRange> range = inputType.getDomain().getRangeOpt();
if (range.isPresent()) {
if (range.get().hasNoLowerLimit()){
start = "-oo";
}
else {
startValues.add(range.get().getStartValue());
}
if (range.get().hasNoUpperLimit()){
end = "oo";
}
else {
endValues.add(range.get().getEndValue());
}
}
}
if (start == null){
start = "" + startValues.build().reduce(startValAccumulator).get().doubleValue();
}
if (end == null){
end = "" + endValues.build().reduce(endValAccumulator).get().doubleValue();
}
return Arrays.asList(start, end);
}
}
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.symboltable.SymbolKind;
public class UnrollDeclarationKind implements SymbolKind {
private static final String NAME = "de.monticore.lang.monticar.cnnarch._symboltable.UnrollDeclarationKind";
@Override
public String getName() {
return NAME;
}
@Override
public boolean isKindOf(SymbolKind kind) {
return NAME.equals(kind.getName()) || SymbolKind.super.isKindOf(kind);
}
}
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.symboltable.MutableScope;
import de.monticore.symboltable.Symbol;
import java.util.Optional;
public class UnrollDeclarationScope extends de.monticore.symboltable.CommonScope {
public UnrollDeclarationScope() {
super(true);
}
public UnrollDeclarationScope(Optional<MutableScope> enclosingScope) {
super(enclosingScope, true);
}
@Override
public void add(Symbol symbol) {
super.add(symbol);
if (symbol instanceof ArchitectureElementSymbol){
ArchitectureElementScope subScope = ((ArchitectureElementSymbol) symbol).getSpannedScope();
addSubScope(subScope);
subScope.setResolvingFilters(getResolvingFilters());
}
}
}
\ No newline at end of file
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
/* generated by template symboltable.ScopeSpanningSymbol*/
package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedVariables;
import de.monticore.symboltable.CommonScopeSpanningSymbol;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
public class UnrollDeclarationSymbol extends CommonScopeSpanningSymbol {
public static final UnrollDeclarationKind KIND = new UnrollDeclarationKind();
private List<VariableSymbol> parameters;
private CompositeElementSymbol body;
protected UnrollDeclarationSymbol(String name) {
super(name, KIND);
}
@Override
protected UnrollDeclarationScope createSpannedScope() {
return new UnrollDeclarationScope();
}
@Override
public UnrollDeclarationScope getSpannedScope() {
return (UnrollDeclarationScope) super.getSpannedScope();
}
public List<VariableSymbol> getParameters() {
return parameters;
}
protected void setParameters(List<VariableSymbol> parameters) {
this.parameters = parameters;
if (!getParameter(AllPredefinedVariables.CONDITIONAL_ARG_NAME).isPresent()){
VariableSymbol ifParam = AllPredefinedVariables.createConditionalParameter();
this.parameters.add(ifParam);
ifParam.putInScope(getSpannedScope());
}
if (!getParameter(AllPredefinedVariables.SERIAL_ARG_NAME).isPresent()){
VariableSymbol forParam = AllPredefinedVariables.createSerialParameter();
this.parameters.add(forParam);
forParam.putInScope(getSpannedScope());
}
if (!getParameter(AllPredefinedVariables.PARALLEL_ARG_NAME).isPresent()){
VariableSymbol forParam = AllPredefinedVariables.createParallelParameter();