Commit 2333d2a3 authored by Christian Fuß's avatar Christian Fuß

progress on Unroll feature. Made 'size' parameter for OneHot Layer optional.

parent c23000c3
Pipeline #153406 passed with stages
in 19 minutes and 2 seconds
......@@ -71,7 +71,9 @@ grammar CNNArch extends de.monticore.CommonExpressions, de.monticore.lang.Math,
Layer implements ArchitectureElement = Name "(" arguments:(ArchArgument || ",")* ")";
Unroll implements ArchitectureElement = Name "(" arguments:(ArchArgument || ",")* ")";
Unroll implements ArchitectureElement = "unroll" "<" timeParameter:LayerParameter ">"
Name "(" arguments:(ArchArgument || ",")* ")"
"{" body:Stream "}";
ParallelBlock implements ArchitectureElement = "("
groups:Stream "|"
......
......@@ -23,6 +23,9 @@ package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch._ast.ASTArchArgument;
import de.monticore.lang.monticar.cnnarch._symboltable.ArgumentSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerDeclarationSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.UnrollDeclarationSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.UnrollSymbol;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.Joiners;
import de.se_rwth.commons.logging.Log;
......@@ -32,12 +35,24 @@ public class CheckArgument implements CNNArchASTArchArgumentCoCo {
@Override
public void check(ASTArchArgument node) {
ArgumentSymbol argument = (ArgumentSymbol) node.getSymbolOpt().get();
LayerDeclarationSymbol layerDeclaration = argument.getLayer().getDeclaration();
if (layerDeclaration != null && argument.getParameter() == null){
Log.error("0"+ ErrorCodes.UNKNOWN_ARGUMENT + " Unknown Argument. " +
"Parameter with name '" + node.getName() + "' does not exist. " +
"Possible arguments are: " + Joiners.COMMA.join(layerDeclaration.getParameters())
, node.get_SourcePositionStart());
if(argument.getEnclosingScope().getSpanningSymbol().get() instanceof LayerSymbol) {
LayerDeclarationSymbol layerDeclaration = argument.getLayer().getDeclaration();
if (layerDeclaration != null && argument.getParameter() == null){
Log.error("0"+ ErrorCodes.UNKNOWN_ARGUMENT + " Unknown Argument. " +
"Parameter with name '" + node.getName() + "' does not exist. " +
"Possible arguments are: " + Joiners.COMMA.join(layerDeclaration.getParameters())
, node.get_SourcePositionStart());
}
}else if(argument.getEnclosingScope().getSpanningSymbol().get() instanceof UnrollSymbol){
UnrollDeclarationSymbol layerDeclaration = argument.getUnroll().getDeclaration();
if (layerDeclaration != null && argument.getParameter() == null){
Log.error("0"+ ErrorCodes.UNKNOWN_ARGUMENT + " Unknown Argument. " +
"Parameter with name '" + node.getName() + "' does not exist. " +
"Possible arguments are: " + Joiners.COMMA.join(layerDeclaration.getParameters())
, node.get_SourcePositionStart());
}
}
}
......
......@@ -20,8 +20,10 @@
*/
package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.commonexpressions._ast.ASTArguments;
import de.monticore.lang.monticar.cnnarch._ast.ASTArchArgument;
import de.monticore.lang.monticar.cnnarch._ast.ASTLayer;
import de.monticore.lang.monticar.cnnarch._ast.ASTLayerParameter;
import de.monticore.lang.monticar.cnnarch._ast.ASTUnroll;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
......
......@@ -59,6 +59,10 @@ public class ArgumentSymbol extends CommonSymbol {
return (LayerSymbol) getEnclosingScope().getSpanningSymbol().get();
}
public UnrollSymbol getUnroll() {
return (UnrollSymbol) getEnclosingScope().getSpanningSymbol().get();
}
public ArchExpressionSymbol getRhs() {
return rhs;
}
......
......@@ -340,6 +340,39 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
sym.setElements(elements);
}
@Override
public void visit(ASTUnroll ast) {
ast.setName("BeamSearchStart");
UnrollSymbol layer = new UnrollSymbol("BeamSearchStart");
addToScopeAndLinkWithNode(layer, ast);
}
@Override
public void endVisit(ASTUnroll ast) {
UnrollSymbol layer = (UnrollSymbol) ast.getSymbolOpt().get();
List<ArgumentSymbol> arguments = new ArrayList<>(6);
for (ASTArchArgument astArgument : ast.getArgumentsList()){
Optional<ArgumentSymbol> optArgument = astArgument.getSymbolOpt().map(e -> (ArgumentSymbol)e);
optArgument.ifPresent(arguments::add);
}
layer.setArguments(arguments);
/*List<ArchitectureElementSymbol> elements = new ArrayList<>();
for (ASTStream astStream : ast.getGroupsList()){
elements.add((SerialCompositeElementSymbol) astStream.getSymbolOpt().get());
}
compositeElement.setElements(elements);
*/
try{
System.err.println("############################" + layer.getIntValue(AllPredefinedLayers.BEAMSEARCH_MAX_LENGTH).get());
}catch(Exception e){};
removeCurrentScope();
}
@Override
public void visit(ASTParallelBlock node) {
ParallelCompositeElementSymbol compositeElement = new ParallelCompositeElementSymbol();
......
......@@ -116,6 +116,22 @@ abstract public class PredefinedLayerDeclaration extends LayerDeclarationSymbol
}
}
//check input for onehot layer
protected static void errorIfInputSizeUnequalToOnehotSize(List<ArchTypeSymbol> inputTypes, LayerSymbol layer){
if (!inputTypes.isEmpty() && layer.getIntValue(AllPredefinedLayers.ONE_HOT_SIZE_NAME).get() != 0) {
int inputChannels = inputTypes.get(0).getChannels();
int onehotSize = layer.getIntValue(AllPredefinedLayers.ONE_HOT_SIZE_NAME).get();
if (onehotSize != inputChannels){
Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_SHAPE +
"The size of the onehot vector is not equal to the output size of the previous layer." +
"This is usually not intended."
, layer.getSourcePosition());
}
}
}
//output type function for convolution and pooling
protected static List<ArchTypeSymbol> computeConvAndPoolOutputShape(ArchTypeSymbol inputType, LayerSymbol method, int channels) {
String borderModeSetting = method.getStringValue(AllPredefinedLayers.PADDING_NAME).get();
......
......@@ -65,7 +65,7 @@ abstract public class PredefinedUnrollDeclaration extends UnrollDeclarationSymbo
//the following methods are only here to avoid duplication. They are used by multiple subclasses.
//check if inputTypes is of size 1
protected void errorIfInputSizeIsNotOne(List<ArchTypeSymbol> inputTypes, LayerSymbol layer){
protected void errorIfInputSizeIsNotOne(List<ArchTypeSymbol> inputTypes, UnrollSymbol layer){
if (inputTypes.size() != 1){
Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_SHAPE + " Invalid layer input. " +
getName() + " layer can only handle one input stream. " +
......
......@@ -45,6 +45,7 @@ public class AllPredefinedLayers {
public static final String CONCATENATE_NAME = "Concatenate";
public static final String FLATTEN_NAME = "Flatten";
public static final String ONE_HOT_NAME = "OneHot";
public static final String BEAMSEARCH_NAME = "BeamSearchStart";
//predefined argument names
public static final String KERNEL_NAME = "kernel";
......@@ -63,6 +64,8 @@ public class AllPredefinedLayers {
public static final String PADDING_NAME = "padding";
public static final String POOL_TYPE_NAME = "pool_type";
public static final String ONE_HOT_SIZE_NAME = "size";
public static final String BEAMSEARCH_MAX_LENGTH = "max_length";
public static final String BEAMSEARCH_WIDTH_NAME = "width";
//possible String values
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch.predefined;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
public class BeamSearchStart extends PredefinedLayerDeclaration {
private BeamSearchStart() {
super(AllPredefinedLayers.BEAMSEARCH_NAME);
}
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer) {
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(100) // TODO
.height(1)
.width(1)
.elementType("0", "1")
.build());
}
@Override
public void checkInput(List<ArchTypeSymbol> inputTypes, LayerSymbol layer) {
errorIfInputSizeIsNotOne(inputTypes, layer);
}
public static BeamSearchStart create(){
BeamSearchStart declaration = new BeamSearchStart();
List<VariableSymbol> parameters = new ArrayList<>(Arrays.asList(
new VariableSymbol.Builder()
.name(AllPredefinedLayers.BEAMSEARCH_MAX_LENGTH)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.build(),
new VariableSymbol.Builder()
.name(AllPredefinedLayers.BEAMSEARCH_WIDTH_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.build()));
declaration.setParameters(parameters);
return declaration;
}
}
......@@ -29,6 +29,8 @@ import java.util.List;
public class OneHot extends PredefinedLayerDeclaration {
private static int channels;
private OneHot() {
super(AllPredefinedLayers.ONE_HOT_NAME);
}
......@@ -37,6 +39,8 @@ public class OneHot extends PredefinedLayerDeclaration {
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer) {
channels=layer.getIntValue(AllPredefinedLayers.ONE_HOT_SIZE_NAME).get();
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(layer.getIntValue(AllPredefinedLayers.ONE_HOT_SIZE_NAME).get())
.height(1)
......@@ -48,6 +52,7 @@ public class OneHot extends PredefinedLayerDeclaration {
@Override
public void checkInput(List<ArchTypeSymbol> inputTypes, LayerSymbol layer) {
errorIfInputSizeIsNotOne(inputTypes, layer);
errorIfInputSizeUnequalToOnehotSize(inputTypes, layer);
}
public static OneHot create(){
......@@ -56,6 +61,7 @@ public class OneHot extends PredefinedLayerDeclaration {
new VariableSymbol.Builder()
.name(AllPredefinedLayers.ONE_HOT_SIZE_NAME)
.constraints(Constraints.POSITIVE, Constraints.INTEGER)
.defaultValue(channels)
.build()));
declaration.setParameters(parameters);
return declaration;
......
......@@ -39,6 +39,11 @@ architecture Alexnet(img_height=224, img_width=224, img_channels=3, classes=10){
Convolution(kernel=(11,11), channels=96, stride=(4,4), padding="no_loss") ->
Lrn(nsize=5, alpha=0.0001, beta=0.75) ->
Pooling(pool_type="max", kernel=(3,3), stride=(2,2), padding="no_loss") ->
unroll<t> BeamSearchStart (width=5, max_length=50){
FullyConnected(units=4096) ->
Relu() ->
Dropout()
} ->
Relu() ->
Split(n=2) ->
split1(i=[0|1]) ->
......@@ -50,5 +55,6 @@ architecture Alexnet(img_height=224, img_width=224, img_channels=3, classes=10){
Concatenate() ->
FullyConnected(units=10) ->
Softmax() ->
OneHot() ->
predictions;
}
\ No newline at end of file
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment