Commit 38bd40be authored by Christian Fuß's avatar Christian Fuß
Browse files

resolved merge conflicts

parents a19ac91b 0d4c4530
Pipeline #159142 failed with stages
......@@ -392,7 +392,14 @@ All predefined methods start with a capital letter and all constructed methods h
Opposite of *Concatenate*. Handles a single input stream and splits it into *n* output streams.
The output streams have the same height and width as the input stream and a number channels which is in general `input_channels / n`.
The last output stream will have a higher number of channels than the other if `input_channels` is not divisible by `n`.
* **n** (integer > 0, required): The number of output streams. Cannot be higher than the number of input channels.
* **OneHot(size)**
Creates a OneHot vector of a given size, given a scalar in the previous layer that determines the OneHot-Index (the index at which the *1* in the vector will be placed).
* **size** (integer > 0, optional): The OneHot-vector's size. Can be omitted to automatically use the output size of the architecture.
......@@ -30,7 +30,7 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnn-arch</artifactId>
<version>0.3.1-SNAPSHOT</version>
<version>0.3.2-SNAPSHOT</version>
......
......@@ -69,6 +69,8 @@ grammar CNNArch extends de.monticore.CommonExpressions, de.monticore.lang.Math,
IOElement implements ArchitectureElement = Name ("[" index:ArchSimpleExpression "]")?;
Constant implements ArchitectureElement = ArchSimpleExpression;
Layer implements ArchitectureElement = Name "(" arguments:(ArchArgument || ",")* ")";
Unroll implements ArchitectureElement = "unroll" "<" timeParameter:LayerParameter ">"
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch;
import de.monticore.io.paths.ModelPath;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchLanguage;
import de.monticore.symboltable.GlobalScope;
import de.monticore.symboltable.Scope;
import de.se_rwth.commons.logging.Log;
import java.io.IOException;
import java.nio.file.Path;
import java.util.Map;
public abstract class CNNArchGenerator {
private String generationTargetPath;
private String modelsDirPath;
public static void quitGeneration(){
Log.error("Code generation is aborted");
System.exit(1);
}
public boolean isCMakeRequired() {
return true;
}
public String getGenerationTargetPath(){
if (generationTargetPath.charAt(generationTargetPath.length() - 1) != '/') {
this.generationTargetPath = generationTargetPath + "/";
}
return generationTargetPath;
}
public void setGenerationTargetPath(String generationTargetPath){
this.generationTargetPath = generationTargetPath;
}
protected String getModelsDirPath() {
return this.modelsDirPath;
}
public void generate(Path modelsDirPath, String rootModelName){
this.modelsDirPath = modelsDirPath.toString();
final ModelPath mp = new ModelPath(modelsDirPath);
GlobalScope scope = new GlobalScope(mp, new CNNArchLanguage());
generate(scope, rootModelName);
}
public abstract boolean check(ArchitectureSymbol architecture);
public abstract void generate(Scope scope, String rootModelName);
//check cocos with CNNArchCocos.checkAll(architecture) before calling this method.
public abstract Map<String, String> generateStrings(ArchitectureSymbol architecture);
//check cocos with CNNArchCocos.checkAll(architecture) before calling this method.
public void generateFiles(ArchitectureSymbol architecture) throws IOException{
Map<String, String> fileContentMap = generateStrings(architecture);
generateFromFilecontentsMap(fileContentMap);
}
public abstract void generateFromFilecontentsMap(Map<String, String> fileContentMap) throws IOException;
public void generateCMake(String rootModelName){
Map<String, String> fileContentMap = generateCMakeContent(rootModelName);
try {
generateFromFilecontentsMap(fileContentMap);
} catch (IOException e) {
Log.error("CMake file could not be generated" + e.getMessage());
}
}
public abstract Map<String, String> generateCMakeContent(String rootModelName);
}
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch;
import de.se_rwth.commons.logging.Log;
import java.io.*;
import java.net.URL;
import java.util.Objects;
import java.util.Properties;
public class DataPathConfigParser{
private String configTargetPath;
private String configFileName;
private Properties properties;
public DataPathConfigParser(String configPath) {
setConfigPath(configPath);
properties = new Properties();
try
{
properties.load(new FileInputStream(configTargetPath));
} catch(IOException e)
{
Log.error("Config file " + configPath + " could not be found");
}
}
public String getConfigPath() {
if (configTargetPath.charAt(configTargetPath.length() - 1) != '/') {
this.configTargetPath = configTargetPath + "/";
}
return configTargetPath;
}
public void setConfigPath(String configTargetPath){
this.configTargetPath = configTargetPath;
}
public String getDataPath(String modelName) {
String path = properties.getProperty(modelName);
if(path == null) {
Log.error("Data path config file did not specify a path for component '" + modelName + "'");
}
return path;
}
}
......@@ -20,7 +20,6 @@
*/
package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.expressionsbasis._ast.ASTExpression;
import de.monticore.lang.math._symboltable.MathSymbolTableCreator;
import de.monticore.lang.math._symboltable.expression.*;
......@@ -455,6 +454,17 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
addToScopeAndLinkWithNode(argument, node);
}
public void visit(ASTConstant node) {
ConstantSymbol constant = new ConstantSymbol();
addToScopeAndLinkWithNode(constant, node);
}
public void endVisit(ASTConstant node) {
ConstantSymbol constant = (ConstantSymbol) node.getSymbolOpt().get();
constant.setExpression((ArchSimpleExpressionSymbol) node.getArchSimpleExpression().getSymbolOpt().get());
removeCurrentScope();
}
public void visit(ASTIOElement node) {
IOSymbol ioElement = new IOSymbol(node.getName());
addToScopeAndLinkWithNode(ioElement, node);
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.lang.math._symboltable.expression.MathExpressionSymbol;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.monticore.lang.monticar.cnnarch.helper.Utils;
import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers;
import de.monticore.lang.monticar.ranges._ast.ASTRange;
import de.monticore.lang.monticar.types2._ast.ASTElementType;
import de.monticore.symboltable.Scope;
import de.monticore.symboltable.Symbol;
import de.se_rwth.commons.logging.Log;
import java.util.*;
public class ConstantSymbol extends ArchitectureElementSymbol {
private ArchSimpleExpressionSymbol expression = null;
protected ConstantSymbol() {
super("const");
}
public ArchSimpleExpressionSymbol getExpression() {
return expression;
}
protected void setExpression(ArchSimpleExpressionSymbol expression) {
this.expression = expression;
}
@Override
public boolean isResolvable() {
return super.isResolvable();
}
@Override
public boolean isAtomic() {
return getResolvedThis().isPresent() && getResolvedThis().get() == this;
}
@Override
public List<ArchitectureElementSymbol> getFirstAtomicElements() {
if (getResolvedThis().isPresent() && getResolvedThis().get() != this) {
return getResolvedThis().get().getFirstAtomicElements();
}
else {
return Collections.singletonList(this);
}
}
@Override
public List<ArchitectureElementSymbol> getLastAtomicElements() {
if (getResolvedThis().isPresent() && getResolvedThis().get() != this) {
return getResolvedThis().get().getLastAtomicElements();
}
else {
return Collections.singletonList(this);
}
}
@Override
public Set<VariableSymbol> resolve() throws ArchResolveException {
if (!isResolved()) {
if (isResolvable()) {
resolveExpressions();
setResolvedThis(this);
}
}
return getUnresolvableVariables();
}
@Override
protected void computeUnresolvableVariables(Set<VariableSymbol> unresolvableVariables, Set<VariableSymbol> allVariables) {
getExpression().checkIfResolvable(allVariables);
unresolvableVariables.addAll(getExpression().getUnresolvableVariables());
}
@Override
public List<ArchTypeSymbol> computeOutputTypes() {
List<ArchTypeSymbol> outputShapes;
if (isAtomic()) {
ArchTypeSymbol outputShape = new ArchTypeSymbol();
// Since symbol is resolved at this point, it is safe to assume that the expression is an int
int value = getExpression().getIntValue().get();
ASTRange range = new ASTRange();
range.setStartValue(String.valueOf(value));
range.setEndValue(String.valueOf(value));
ASTElementType domain = new ASTElementType("Z", Optional.of(range));
outputShape.setDomain(domain);
outputShapes = Collections.singletonList(outputShape);
}
else {
if (!getResolvedThis().isPresent()){
throw new IllegalStateException("The architecture resolve() method was never called");
}
outputShapes = getResolvedThis().get().computeOutputTypes();
}
return outputShapes;
}
@Override
public void checkInput() {
if (isAtomic()) {
if (!getInputTypes().isEmpty()) {
Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_SHAPE + " Invalid number of input streams. "
, getSourcePosition());
}
}
else {
if (!getResolvedThis().isPresent()) {
throw new IllegalStateException("The architecture resolve() method was never called");
}
getResolvedThis().get().checkInput();
}
}
@Override
public Optional<Integer> getParallelLength() {
return Optional.of(1);
}
@Override
public Optional<List<Integer>> getSerialLengths() {
return Optional.of(Collections.nCopies(getParallelLength().get(), 1));
}
@Override
protected void putInScope(Scope scope) {
Collection<Symbol> symbolsInScope = scope.getLocalSymbols().get(getName());
if (symbolsInScope == null || !symbolsInScope.contains(this)) {
scope.getAsMutableScope().add(this);
getExpression().putInScope(getSpannedScope());
}
}
@Override
protected void resolveExpressions() throws ArchResolveException {
getExpression().resolveOrError();
if (!Constraints.INTEGER.check(getExpression(), getSourcePosition(), getName())) {
throw new ArchResolveException();
}
}
@Override
protected ArchitectureElementSymbol preResolveDeepCopy() {
ConstantSymbol copy = new ConstantSymbol();
if (getAstNode().isPresent()) {
copy.setAstNode(getAstNode().get());
}
copy.setExpression(getExpression());
return copy;
}
}
......@@ -329,6 +329,50 @@ public class LayerSymbol extends ArchitectureElementSymbol {
}
}
public void setIntValue(String parameterName, int value) {
setTValue(parameterName, value, ArchSimpleExpressionSymbol::of);
}
public void setIntTupleValue(String parameterName, List<Object> tupleValues) {
setTValue(parameterName, tupleValues, ArchSimpleExpressionSymbol::of);
}
public void setBooleanValue(String parameterName, boolean value) {
setTValue(parameterName, value, ArchSimpleExpressionSymbol::of);
}
public void setStringValue(String parameterName, String value) {
setTValue(parameterName, value, ArchSimpleExpressionSymbol::of);
}
public void setDoubleValue(String parameterName, double value) {
setTValue(parameterName, value, ArchSimpleExpressionSymbol::of);
}
public void setValue(String parameterName, Object value) {
ArchSimpleExpressionSymbol res = new ArchSimpleExpressionSymbol();
res.setValue(value);
setTValue(parameterName, res, Function.identity());
}
public <T> void setTValue(String parameterName, T value, Function<T, ArchSimpleExpressionSymbol> of) {
Optional<VariableSymbol> param = getDeclaration().getParameter(parameterName);
if (param.isPresent()) {
Optional<ArgumentSymbol> arg = getArgument(parameterName);
ArchSimpleExpressionSymbol expression = of.apply(value);
if (arg.isPresent()) {
arg.get().setRhs(expression);
}
else {
arg = Optional.of(new ArgumentSymbol(parameterName));
arg.get().setRhs(expression);
arguments.add(arg.get());
}
}
}
@Override
public Optional<Integer> getParallelLength(){
int length = -1;
......
......@@ -90,6 +90,36 @@ abstract public class PredefinedLayerDeclaration extends LayerDeclarationSymbol
}
}
protected void errorIfInputChannelSizeIsInvalid(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, int channels) {
for (ArchTypeSymbol inputType : inputTypes) {
if (inputType.getChannels() != channels) {
Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_SHAPE + " Invalid layer input. Input channel size is "
+ inputType.getChannels() + " but needs to be " + channels + "."
, layer.getSourcePosition());
}
}
}
protected void errorIfInputHeightIsInvalid(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, int height) {
for (ArchTypeSymbol inputType : inputTypes) {
if (inputType.getHeight() != height) {
Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_SHAPE + " Invalid layer input. Input height is "
+ inputType.getHeight() + " but needs to be " + height + "."
, layer.getSourcePosition());
}
}
}
protected void errorIfInputWidthIsInvalid(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, int width) {
for (ArchTypeSymbol inputType : inputTypes) {
if (inputType.getWidth() != width) {
Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_SHAPE + " Invalid layer input. Input width is "
+ inputType.getWidth() + " but needs to be " + width + "."
, layer.getSourcePosition());
}
}
}
//check input for convolution and pooling
protected static void errorIfInputSmallerThanKernel(List<ArchTypeSymbol> inputTypes, LayerSymbol layer){
if (!inputTypes.isEmpty()) {
......@@ -148,6 +178,16 @@ abstract public class PredefinedLayerDeclaration extends LayerDeclarationSymbol
}
}
protected static void computeOneHotOutputSize(LayerSymbol layer){
int outputChannels = 0;
if(layer.getOutputElement().get() instanceof IOSymbol && layer.getOutputElement().get().isOutput()) {
outputChannels = ((IOSymbol) layer.getOutputElement().get()).getDefinition().getType().getChannels();
}
layer.setIntValue(AllPredefinedLayers.SIZE_NAME, outputChannels);
}
//padding with border_mode=valid, no padding
private static List<ArchTypeSymbol> computeOutputShapeWithValidPadding(ArchTypeSymbol inputType, LayerSymbol method, int channels){
int strideHeight = method.getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get().get(0);
......
......@@ -21,26 +21,29 @@
package de.monticore.lang.monticar.cnnarch.predefined;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.monticore.lang.monticar.ranges._ast.ASTRange;
import de.monticore.lang.monticar.types2._ast.ASTElementType;
import de.se_rwth.commons.logging.Log;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.*;
public class OneHot extends PredefinedLayerDeclaration {
private static int channels;
private OneHot() {
super(AllPredefinedLayers.ONE_HOT_NAME);
}
int size;
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer) {
<<<<<<< HEAD
channels = inputTypes.get(0).getChannels();
=======
>>>>>>> develop
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(layer.getIntValue(AllPredefinedLayers.SIZE_NAME).get())
.height(1)
......@@ -51,8 +54,66 @@ public class OneHot extends PredefinedLayerDeclaration {
@Override
public void checkInput(List<ArchTypeSymbol> inputTypes, LayerSymbol layer) {
computeOneHotOutputSize(layer);
size = layer.getIntValue(AllPredefinedLayers.SIZE_NAME).get();
errorIfInputSizeIsNotOne(inputTypes, layer);
<<<<<<< HEAD
//errorIfInputSizeUnequalToOnehotSize(inputTypes, layer);
=======
errorIfInputChannelSizeIsInvalid(inputTypes, layer, 1);
errorIfInputHeightIsInvalid(inputTypes, layer, 1);
errorIfInputWidthIsInvalid(inputTypes, layer, 1);
// Check range of input
ASTElementType domain = inputTypes.get(0).getDomain();
if (