Commit 1c385279 authored by Evgeny Kusmenko's avatar Evgeny Kusmenko
Browse files

Merge branch 'develop' into 'master'

Implemented unrolling of RNNs and various new layers

See merge request !28
parents 235254ec b284eb3c
Pipeline #213689 passed with stages
in 15 minutes and 34 seconds
......@@ -410,6 +410,70 @@ All predefined methods start with a capital letter and all constructed methods h
* **size** (integer > 0, optional): The OneHot-vector's size. Can be omitted to automatically use the output size of the architecture.
* **ArgMax()**
Computes the index of the maximal value of its input vector. Useful for recurrent networks, when the output of a timestep should be used as integer input for the next timestep.
* **BeamSearch(max_length, width)**
Must be used together with a recurrent network. Uses Beamsearch as search algorithm over the timesteps of the RNN.
* **max_length** (integer > 0, required): The maximum number of timesteps to run the RNN, and thus the maximum length of the generated sequence.
* **width** (integer > 0, required): The number of candidates to consider in each timestep. Sometimes called k.
* **BroadcastAdd()**
Takes multiple tensors as input, and broadcasts them to the same shape (Copies values along one axis until it has the size of the largest axis along all inputs). Then performs element-wise addition.
* **BroadcastMultiply()**
Takes multiple tensors as input, and broadcasts them to the same shape (Copies values along one axis until it has the size of the largest axis along all inputs). Then performs element-wise multiplication.
* **Dot()**
Performs the dot product (matrix multiplication) for two input matrices.
* **ExpandDims(axis)**
Creates a new, empty axis for a given input tensor.
* **axis** (0 <= integer <= 1, required): The axis to expand.
* **GreedySearch(max_length)**
Must be used together with a recurrent network. Uses Greedysearch as search algorithm over the timesteps of the RNN, so that only the best output for each timestep is considered.
* **max_length** (integer > 0, required): The maximum number of timesteps to run the RNN, and thus the maximum length of the generated sequence.
* **ReduceSum(axis)**
Sums all values along a given axis, and reduces the dimension of the axis afterwards, making a scalar out of a one-entry vector etc.
* **axis** (0 <= integer <= 1, optional, default=-1): The axis to sum over. Uses the last axis (-1) by default.
* **Repeat(n, axis)**
Copies the entries of an axis n times in the same axis.
* **n** (integer > 0, required): How often to copy the entries of the given axis
* **axis** (-1 <= integer <= 2, optional, default=-1): The axis to use for copying. Uses the last axis (-1) by default.
* **Reshape(shape)**
Transforms the input tensor into a different shape, while keeping the number of total entries in the tensor.
* **shape** (integer tuple, required): New shape of the tensor.
* **UpConvolution(kernel, channels, stride=(1,1), no_bias=false, padding="same")**
Creates an up-convolutional layer (also known as transposed convolution). Is currently only supported in the TensorFlow backend.
......@@ -420,3 +484,4 @@ All predefined methods start with a capital letter and all constructed methods h
* **padding** ({"valid", "same", "no_loss"}, optional, default="same"): One of "valid", "same" or "no_loss". "valid" means no padding. "same" results in padding the input such that the output has the same length as the original input divided by the stride (rounded up). "no_loss" results in minimal padding such that each input is used by at least one filter (identical to "valid" if *stride* equals 1).
* **no_bias** (boolean, optional, default=false): Whether to disable the bias parameter.
......@@ -19,7 +19,7 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnn-arch</artifactId>
<version>0.3.3-SNAPSHOT</version>
<version>0.3.4-SNAPSHOT</version>
......@@ -348,6 +348,7 @@
<version>2.19.1</version>
<configuration>
<argLine>-Xmx1024m -Xms1024m -XX:MaxPermSize=512m -Djdk.net.URLClassPath.disableClassPathURLCheck=true</argLine>
<trimStackTrace>false</trimStackTrace>
</configuration>
</plugin>
<plugin>
......
......@@ -57,10 +57,18 @@ grammar CNNArch extends de.monticore.CommonExpressions, de.monticore.lang.Math,
Architecture = methodDeclaration:LayerDeclaration*
instructions:(Instruction || ";")+ ";";
Instruction = (LayerVariableDeclaration | Stream);
Instruction = (LayerVariableDeclaration | NetworkInstruction);
LayerVariableDeclaration = "layer" Layer Name;
interface NetworkInstruction;
StreamInstruction implements NetworkInstruction = body:Stream;
UnrollInstruction implements NetworkInstruction = "timed" "<" timeParameter:TimeParameter ">"
Name "(" arguments:(ArchArgument || ",")* ")"
"{" body:Stream "}";
Stream = elements:(ArchitectureElement || "->")+;
interface ArchitectureElement;
......@@ -71,6 +79,7 @@ grammar CNNArch extends de.monticore.CommonExpressions, de.monticore.lang.Math,
Layer implements ArchitectureElement = Name "(" arguments:(ArchArgument || ",")* ")";
ParallelBlock implements ArchitectureElement = "("
groups:Stream "|"
groups:(Stream || "|")+ ")";
......@@ -86,6 +95,8 @@ grammar CNNArch extends de.monticore.CommonExpressions, de.monticore.lang.Math,
LayerParameter implements ArchParameter = Name ("=" default:ArchSimpleExpression)? ;
TimeParameter implements ArchParameter = Name ("=" default:ArchSimpleExpression)? ;
interface ArchArgument;
ArchParameterArgument implements ArchArgument = Name "=" rhs:ArchExpression ;
......
......@@ -52,13 +52,15 @@ public class CNNArchCocos {
public static CNNArchSymbolCoCoChecker createCNNArchPostResolveSymbolChecker() {
return new CNNArchSymbolCoCoChecker()
.addCoCo(new CheckIOType())
.addCoCo(new CheckIOArrayLength())
.addCoCo(new CheckElementInputs())
.addCoCo(new CheckIOAccessAndIOMissing())
.addCoCo(new CheckArchitectureFinished())
.addCoCo(new CheckNetworkStreamMissing())
.addCoCo(new CheckVariableMember())
.addCoCo(new CheckLayerVariableDeclarationLayerType())
.addCoCo(new CheckLayerVariableDeclarationIsUsed());
.addCoCo(new CheckLayerVariableDeclarationIsUsed())
.addCoCo(new CheckConstants())
.addCoCo(new CheckUnrollInputsOutputsTooMany());
}
//checks cocos based on symbols before the resolve method of the ArchitectureSymbol is called
......@@ -66,6 +68,7 @@ public class CNNArchCocos {
return new CNNArchSymbolCoCoChecker()
.addCoCo(new CheckVariableDeclarationName())
.addCoCo(new CheckVariableName())
.addCoCo(new CheckArgmaxLayer())
.addCoCo(new CheckExpressions());
}
......
......@@ -21,6 +21,15 @@ public class CNNArchSymbolCoCo {
else if (sym instanceof LayerDeclarationSymbol){
check((LayerDeclarationSymbol) sym);
}
else if (sym instanceof UnrollDeclarationSymbol){
check((UnrollDeclarationSymbol) sym);
}
else if (sym instanceof UnrollInstructionSymbol){
check((UnrollInstructionSymbol) sym);
}
else if (sym instanceof StreamInstructionSymbol){
check((StreamInstructionSymbol) sym);
}
else if (sym instanceof ArchitectureElementSymbol){
check((ArchitectureElementSymbol) sym);
}
......@@ -60,6 +69,10 @@ public class CNNArchSymbolCoCo {
//Override if needed
}
/**
 * Hook for {@link UnrollDeclarationSymbol}s; the symbol dispatch of this class calls it
 * for every symbol of that type. Does nothing by default.
 *
 * @param sym the unroll declaration symbol to check
 */
public void check(UnrollDeclarationSymbol sym){
//Override if needed
}
public void check(ArchitectureElementSymbol sym){
//Override if needed
}
......@@ -91,4 +104,12 @@ public class CNNArchSymbolCoCo {
public void check(MathExpressionSymbol sym){
//Override if needed
}
/**
 * Hook for {@link UnrollInstructionSymbol}s; the symbol dispatch of this class calls it
 * for every symbol of that type. Does nothing by default.
 *
 * @param sym the unroll instruction symbol to check
 */
public void check(UnrollInstructionSymbol sym){
//Override if needed
}
/**
 * Hook for {@link StreamInstructionSymbol}s; the symbol dispatch of this class calls it
 * for every symbol of that type. Does nothing by default.
 *
 * @param sym the stream instruction symbol to check
 */
public void check(StreamInstructionSymbol sym){
//Override if needed
}
}
......@@ -9,21 +9,23 @@
package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.NetworkInstructionSymbol;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.SourcePosition;
import de.se_rwth.commons.logging.Log;
public class CheckArchitectureFinished extends CNNArchSymbolCoCo {
@Override
public void check(ArchitectureSymbol architecture) {
for (CompositeElementSymbol stream : architecture.getStreams()) {
if (!stream.getOutputTypes().isEmpty()){
for (NetworkInstructionSymbol networkInstruction : architecture.getNetworkInstructions()) {
if (!networkInstruction.getBody().getOutputTypes().isEmpty()) {
Log.error("0" + ErrorCodes.UNFINISHED_ARCHITECTURE + " The architecture is not finished. " +
"There are still open streams at the end of the architecture. "
, stream.getSourcePosition());
, networkInstruction.getSourcePosition());
}
}
if (architecture.getInputs().isEmpty()){
Log.error("0" + ErrorCodes.UNFINISHED_ARCHITECTURE + " The architecture has no inputs. "
, architecture.getSourcePosition());
......
/**
*
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers;
import de.se_rwth.commons.logging.Log;
public class CheckArgmaxLayer extends CNNArchSymbolCoCo {
@Override
public void check(ArchitectureElementSymbol symbol) {
if (symbol instanceof LayerSymbol && symbol.getName().equals(AllPredefinedLayers.ARG_MAX_NAME)) {
checkArgmaxBeforeOutput((LayerSymbol) symbol);
}
}
public void checkArgmaxBeforeOutput(LayerSymbol layer) {
if(!(layer.getOutputElement().get() instanceof VariableSymbol)){
Log.error("0" + ErrorCodes.ILLEGAL_LAYER_USE + " ArgMax Layer must be applied directly before an output symbol.");
}
}
}
......@@ -9,9 +9,7 @@
package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch._ast.ASTArchArgument;
import de.monticore.lang.monticar.cnnarch._symboltable.ArgumentSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerDeclarationSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.Joiners;
import de.se_rwth.commons.logging.Log;
......@@ -30,7 +28,14 @@ public class CheckArgument implements CNNArchASTArchArgumentCoCo {
"Possible arguments are: " + Joiners.COMMA.join(layerDeclaration.getParameters())
, node.get_SourcePositionStart());
}
}else if(argument.getEnclosingScope().getSpanningSymbol().get() instanceof UnrollInstructionSymbol){
UnrollDeclarationSymbol layerDeclaration = argument.getUnroll().getDeclaration();
if (layerDeclaration != null && argument.getParameter() == null){
Log.error("0"+ ErrorCodes.UNKNOWN_ARGUMENT + " Unknown Argument. " +
"Parameter with name '" + node.getName() + "' does not exist. " +
"Possible arguments are: " + Joiners.COMMA.join(layerDeclaration.getParameters())
, node.get_SourcePositionStart());
}
}
}
......
/**
*
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ConstantSymbol;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import java.util.Optional;
/**
 * Context condition for constants used as architecture elements: only integer
 * constants are accepted.
 */
public class CheckConstants extends CNNArchSymbolCoCo {

    /**
     * Validates the element if it is a {@link ConstantSymbol}; other elements are skipped.
     *
     * @param sym the architecture element currently visited
     */
    @Override
    public void check(ArchitectureElementSymbol sym) {
        if (!(sym instanceof ConstantSymbol)) {
            return;
        }
        checkConstant((ConstantSymbol) sym);
    }

    /**
     * Logs {@code INVALID_CONSTANT} when the constant's expression is not known to be an integer.
     *
     * @param constant the constant to validate
     */
    public void checkConstant(ConstantSymbol constant) {
        // An absent result is treated exactly like "not an integer".
        boolean knownInteger = constant.getExpression().isInt().orElse(false);
        if (knownInteger) {
            return;
        }
        Log.error("0" + ErrorCodes.INVALID_CONSTANT + " Invalid constant, only integers allowed. ",
                constant.getSourcePosition());
    }
}
......@@ -9,14 +9,14 @@
package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.NetworkInstructionSymbol;
public class CheckElementInputs extends CNNArchSymbolCoCo {
@Override
public void check(ArchitectureSymbol architecture) {
for (CompositeElementSymbol stream : architecture.getStreams()) {
stream.checkInput();
for (NetworkInstructionSymbol networkInstruction : architecture.getNetworkInstructions()) {
networkInstruction.getBody().checkInput();
}
}
}
......@@ -8,9 +8,7 @@
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.IODeclarationSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.VariableSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.Joiners;
import de.se_rwth.commons.logging.Log;
......@@ -32,6 +30,10 @@ public class CheckIOAccessAndIOMissing extends CNNArchSymbolCoCo {
else {
checkIOArray(ioDeclaration);
}
if (ioDeclaration.isOutput()) {
checkOutputWrittenToOnce(ioDeclaration);
}
}
}
......@@ -51,7 +53,6 @@ public class CheckIOAccessAndIOMissing extends CNNArchSymbolCoCo {
}
}
private void checkIOArray(IODeclarationSymbol ioDeclaration){
List<Integer> unusedIndices = IntStream.range(0, ioDeclaration.getArrayLength()).boxed().collect(Collectors.toList());
......@@ -65,7 +66,7 @@ public class CheckIOAccessAndIOMissing extends CNNArchSymbolCoCo {
Log.error("0" + ErrorCodes.INVALID_ARRAY_ACCESS + " The IO array access value of '" + ioElement.getName() +
"' must be an integer between 0 and " + (ioDeclaration.getArrayLength()-1) + ". " +
"The current value is: " + ioElement.getArrayAccess().get().getValue().get().toString()
, ioElement.getSourcePosition());
, ioElement.getSourcePosition());
}
}
else{
......@@ -80,4 +81,63 @@ public class CheckIOAccessAndIOMissing extends CNNArchSymbolCoCo {
}
}
/**
 * Reports an error for every repeated write to the same index of the given output.
 * The last atomic elements of each stream body and of each resolved body of an unroll
 * instruction are inspected; an access without a resolvable integer index counts as index 0.
 *
 * @param ioDeclaration the output declaration whose writes are checked
 */
private void checkOutputWrittenToOnce(IODeclarationSymbol ioDeclaration) {
    // Indices already written; shared across all network instructions.
    List<Integer> written = new ArrayList<>();

    for (NetworkInstructionSymbol networkInstruction : ioDeclaration.getArchitecture().getNetworkInstructions()) {
        if (networkInstruction.isStream()) {
            checkBodyOutputWrites(networkInstruction.getBody(), ioDeclaration, written, networkInstruction);
        } else if (networkInstruction.isUnroll()) {
            for (SerialCompositeElementSymbol body : networkInstruction.toUnrollInstruction().getResolvedBodies()) {
                checkBodyOutputWrites(body, ioDeclaration, written, networkInstruction);
            }
        }
    }
}

/**
 * Records the writes to {@code ioDeclaration} found among the last atomic elements of
 * one body and logs {@code OUTPUT_WRITTEN_TO_MULTIPLE_TIMES} for each index seen before.
 *
 * @param body the body whose last atomic elements are inspected
 * @param ioDeclaration the output declaration being checked
 * @param written indices written so far; extended in place
 * @param networkInstruction the instruction owning the body, used for the error position
 */
private void checkBodyOutputWrites(SerialCompositeElementSymbol body, IODeclarationSymbol ioDeclaration,
                                   List<Integer> written, NetworkInstructionSymbol networkInstruction) {
    for (ArchitectureElementSymbol output : body.getLastAtomicElements()) {
        if (!(output instanceof VariableSymbol)) {
            continue;
        }
        VariableSymbol variable = (VariableSymbol) output;
        if (!variable.getName().equals(ioDeclaration.getName())) {
            continue;
        }

        int arrayAccess = 0;
        if (variable.getArrayAccess().isPresent()) {
            arrayAccess = variable.getArrayAccess().get().getIntValue().orElse(0);
        }

        if (!written.contains(arrayAccess)) {
            written.add(arrayAccess);
        } else {
            Log.error("0" + ErrorCodes.OUTPUT_WRITTEN_TO_MULTIPLE_TIMES + " " + variable.getName() + "["
                            + arrayAccess + "] is written to multiple times, this is currently not allowed."
                    , networkInstruction.getSourcePosition());
        }
    }
}
}
......@@ -8,24 +8,23 @@
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
public class CheckNetworkStreamMissing extends CNNArchSymbolCoCo {
public class CheckIOArrayLength extends CNNArchSymbolCoCo {
@Override
public void check(ArchitectureSymbol architecture) {
boolean hasTrainableStream = false;
for (CompositeElementSymbol stream : architecture.getStreams()) {
hasTrainableStream |= stream.isTrainable();
for (IODeclarationSymbol ioDeclaration : architecture.getIODeclarations()){
checkIO(ioDeclaration);
}
}
if (!hasTrainableStream) {
Log.error("0" + ErrorCodes.MISSING_TRAINABLE_STREAM + " The architecture has no trainable stream. "
, architecture.getSourcePosition());
public void checkIO(IODeclarationSymbol ioDeclaration) {
if (ioDeclaration.getArrayLength() > IODeclarationSymbol.MAX_ARRAY_LENGTH) {
Log.error("0" + ErrorCodes.INVALID_IO_ARRAY_LENGTH + " Invalid IO array length. Length can not be bigger than " + IODeclarationSymbol.MAX_ARRAY_LENGTH
, ioDeclaration.getSourcePosition());
}
}
......
......@@ -37,7 +37,6 @@ public class CheckLayer implements CNNArchASTLayerCoCo{
nameSet.add(name);
}
}
LayerDeclarationSymbol layerDeclaration = ((LayerSymbol) node.getSymbolOpt().get()).getDeclaration();
if (layerDeclaration == null){
ArchitectureSymbol architecture = node.getSymbolOpt().get().getEnclosingScope().<ArchitectureSymbol>resolve("", ArchitectureSymbol.KIND).get();
......
......@@ -12,7 +12,10 @@ import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.Set;
public class CheckLayerVariableDeclarationIsUsed extends CNNArchSymbolCoCo {
......@@ -23,9 +26,12 @@ public class CheckLayerVariableDeclarationIsUsed extends CNNArchSymbolCoCo {
boolean isUsed = false;
for (SerialCompositeElementSymbol stream : layerVariableDeclaration.getLayer().getArchitecture().getStreams()) {
Collection<ArchitectureElementSymbol> elements =
stream.getSpannedScope().resolveMany(layerVariableDeclaration.getName(), ArchitectureElementSymbol.KIND);
Set<String> allowedUnusedLayers = new HashSet();
allowedUnusedLayers.add("attention");
for (NetworkInstructionSymbol networkInstruction : layerVariableDeclaration.getLayer().getArchitecture().getNetworkInstructions()) {
Collection<ArchitectureElementSymbol> elements
= networkInstruction.getBody().getSpannedScope().resolveMany(layerVariableDeclaration.getName(), ArchitectureElementSymbol.KIND);
for (ArchitectureElementSymbol element : elements) {
if (element instanceof VariableSymbol && ((VariableSymbol) element).getMember() == VariableSymbol.Member.NONE) {
......@@ -39,6 +45,10 @@ public class CheckLayerVariableDeclarationIsUsed extends CNNArchSymbolCoCo {
}
}
if(allowedUnusedLayers.contains(sym.getName())){
isUsed = true;
}
if (!isUsed) {
Log.error("0" + ErrorCodes.UNUSED_LAYER + " Unused layer. " +
"Declared layer variables need to be used as layer at least once.",
......
/**
*
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch._ast.ASTArchArgument;
import de.monticore.lang.monticar.cnnarch._ast.ASTUnrollInstruction;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.Joiners;
import de.se_rwth.commons.logging.Log;
import java.util.HashSet;
import java.util.Set;
public class CheckUnroll implements CNNArchASTUnrollInstructionCoCo{
@Override
public void check(ASTUnrollInstruction node) {
Set<String> nameSet = new HashSet<>();
for (ASTArchArgument argument : node.getArgumentsList()){
String name = argument.getName();
if (nameSet.contains(name)){
Log.error("0" + ErrorCodes.DUPLICATED_ARG + " Duplicated name: " + name +
". Multiple values assigned to the same argument."
, argument.get_SourcePositionStart());
}
else {
nameSet.add(name);
}
}
UnrollDeclarationSymbol layerDeclaration = ((UnrollInstructionSymbol) node.getSymbolOpt().get()).getDeclaration();
if (layerDeclaration == null){
ArchitectureSymbol architecture = node.getSymbolOpt().get().getEnclosingScope().<ArchitectureSymbol>resolve("", ArchitectureSymbol.KIND).get();
Log.error("0" + ErrorCodes.UNKNOWN_LAYER + " Unknown layer. " +
"Layer with name '" + node.getName() + "' does not exist. " +
"Existing layers: " + Joiners.COMMA.join(architecture.getUnrollDeclarations()) + "."
, node.get_SourcePositionStart());
}
else {
Set<String> requiredArguments = new HashSet<>();
for (ParameterSymbol param : layerDeclaration.getParameters()){
if (!param.getDefaultExpression().isPresent()){
requiredArguments.add(param.getName());
}
}
for (ASTArchArgument argument : node.getArgumentsList()){
requiredArguments.remove(argument.getName());