Commit 78009eaf authored by Christian Fuß's avatar Christian Fuß

added OneHotLayer and some parts of the Unrolling

parent b3f4bc77
Pipeline #150120 failed with stages
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
......@@ -25,6 +25,10 @@ grammar CNNArch extends de.monticore.CommonExpressions, de.monticore.lang.Math,
parameters:(LayerParameter || ",")* ")" "{"
body:Stream "}";
UnrollDeclaration = "unroll" "<" timeParameter:LayerParameter ">"
Name "(" parameters:(LayerParameter || ",")* ")"
"{" body:Stream "}";
IODeclaration = "def"
(in:"input" | out:"output")
type:ArchType
......@@ -53,7 +57,8 @@ grammar CNNArch extends de.monticore.CommonExpressions, de.monticore.lang.Math,
@attribute body
The architecture of the neural network.
*/
Architecture = methodDeclaration:LayerDeclaration*
Architecture = unrollDeclarations:UnrollDeclaration*
methodDeclaration:LayerDeclaration*
instructions:(Instruction || ";")+ ";";
interface Instruction;
......@@ -66,6 +71,8 @@ grammar CNNArch extends de.monticore.CommonExpressions, de.monticore.lang.Math,
Layer implements ArchitectureElement = Name "(" arguments:(ArchArgument || ",")* ")";
Unroll implements ArchitectureElement = Name "(" arguments:(ArchArgument || ",")* ")";
ParallelBlock implements ArchitectureElement = "("
groups:Stream "|"
groups:(Stream || "|")+ ")";
......
......@@ -33,6 +33,9 @@ public class CNNArchSymbolCoCo {
else if (sym instanceof LayerDeclarationSymbol){
check((LayerDeclarationSymbol) sym);
}
else if (sym instanceof UnrollDeclarationSymbol){
check((UnrollDeclarationSymbol) sym);
}
else if (sym instanceof ArchitectureElementSymbol){
check((ArchitectureElementSymbol) sym);
}
......@@ -72,6 +75,10 @@ public class CNNArchSymbolCoCo {
//Override if needed
}
public void check(UnrollDeclarationSymbol sym){
//Override if needed
}
public void check(ArchitectureElementSymbol sym){
//Override if needed
}
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch._ast.ASTArchArgument;
import de.monticore.lang.monticar.cnnarch._ast.ASTLayer;
import de.monticore.lang.monticar.cnnarch._ast.ASTUnroll;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.Joiners;
import de.se_rwth.commons.logging.Log;
import java.util.HashSet;
import java.util.Set;
public class CheckUnroll implements CNNArchASTUnrollCoCo{
@Override
public void check(ASTUnroll node) {
Set<String> nameSet = new HashSet<>();
for (ASTArchArgument argument : node.getArgumentsList()){
String name = argument.getName();
if (nameSet.contains(name)){
Log.error("0" + ErrorCodes.DUPLICATED_ARG + " Duplicated name: " + name +
". Multiple values assigned to the same argument."
, argument.get_SourcePositionStart());
}
else {
nameSet.add(name);
}
}
LayerDeclarationSymbol layerDeclaration = ((LayerSymbol) node.getSymbolOpt().get()).getDeclaration();
if (layerDeclaration == null){
ArchitectureSymbol architecture = node.getSymbolOpt().get().getEnclosingScope().<ArchitectureSymbol>resolve("", ArchitectureSymbol.KIND).get();
Log.error("0" + ErrorCodes.UNKNOWN_LAYER + " Unknown layer. " +
"Layer with name '" + node.getName() + "' does not exist. " +
"Existing layers: " + Joiners.COMMA.join(architecture.getLayerDeclarations()) + "."
, node.get_SourcePositionStart());
}
else {
Set<String> requiredArguments = new HashSet<>();
for (VariableSymbol param : layerDeclaration.getParameters()){
if (!param.getDefaultExpression().isPresent()){
requiredArguments.add(param.getName());
}
}
for (ASTArchArgument argument : node.getArgumentsList()){
requiredArguments.remove(argument.getName());
}
for (String missingArgumentName : requiredArguments){
Log.error("0"+ErrorCodes.MISSING_ARGUMENT + " Missing argument. " +
"The argument '" + missingArgumentName + "' is required."
, node.get_SourcePositionStart());
}
}
}
}
......@@ -242,6 +242,29 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
removeCurrentScope();
}
@Override
public void visit(ASTUnrollDeclaration ast) {
UnrollDeclarationSymbol unrollDeclaration = new UnrollDeclarationSymbol(ast.getName());
addToScopeAndLinkWithNode(unrollDeclaration, ast);
}
@Override
public void endVisit(ASTUnrollDeclaration ast) {
UnrollDeclarationSymbol unrollDeclaration = (UnrollDeclarationSymbol) ast.getSymbolOpt().get();
unrollDeclaration.setBody((CompositeElementSymbol) ast.getBody().getSymbolOpt().get());
List<VariableSymbol> parameters = new ArrayList<>(4);
for (ASTLayerParameter astParam : ast.getParametersList()){
VariableSymbol parameter = (VariableSymbol) astParam.getSymbolOpt().get();
parameters.add(parameter);
}
unrollDeclaration.setParameters(parameters);
removeCurrentScope();
}
@Override
public void visit(ASTLayerParameter ast) {
VariableSymbol variable = new VariableSymbol(ast.getName());
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.symboltable.SymbolKind;
public class UnrollDeclarationKind implements SymbolKind {
private static final String NAME = "de.monticore.lang.monticar.cnnarch._symboltable.UnrollDeclarationKind";
@Override
public String getName() {
return NAME;
}
@Override
public boolean isKindOf(SymbolKind kind) {
return NAME.equals(kind.getName()) || SymbolKind.super.isKindOf(kind);
}
}
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.symboltable.MutableScope;
import de.monticore.symboltable.Symbol;
import java.util.Optional;
public class UnrollDeclarationScope extends de.monticore.symboltable.CommonScope {
public UnrollDeclarationScope() {
super(true);
}
public UnrollDeclarationScope(Optional<MutableScope> enclosingScope) {
super(enclosingScope, true);
}
@Override
public void add(Symbol symbol) {
super.add(symbol);
if (symbol instanceof ArchitectureElementSymbol){
ArchitectureElementScope subScope = ((ArchitectureElementSymbol) symbol).getSpannedScope();
addSubScope(subScope);
subScope.setResolvingFilters(getResolvingFilters());
}
}
}
\ No newline at end of file
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
/* generated by template symboltable.ScopeSpanningSymbol*/
package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedVariables;
import de.monticore.symboltable.CommonScopeSpanningSymbol;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
public class UnrollDeclarationSymbol extends CommonScopeSpanningSymbol {
public static final UnrollDeclarationKind KIND = new UnrollDeclarationKind();
private List<VariableSymbol> parameters;
private CompositeElementSymbol body;
protected UnrollDeclarationSymbol(String name) {
super(name, KIND);
}
@Override
protected UnrollDeclarationScope createSpannedScope() {
return new UnrollDeclarationScope();
}
@Override
public UnrollDeclarationScope getSpannedScope() {
return (UnrollDeclarationScope) super.getSpannedScope();
}
public List<VariableSymbol> getParameters() {
return parameters;
}
protected void setParameters(List<VariableSymbol> parameters) {
this.parameters = parameters;
if (!getParameter(AllPredefinedVariables.CONDITIONAL_ARG_NAME).isPresent()){
VariableSymbol ifParam = AllPredefinedVariables.createConditionalParameter();
this.parameters.add(ifParam);
ifParam.putInScope(getSpannedScope());
}
if (!getParameter(AllPredefinedVariables.SERIAL_ARG_NAME).isPresent()){
VariableSymbol forParam = AllPredefinedVariables.createSerialParameter();
this.parameters.add(forParam);
forParam.putInScope(getSpannedScope());
}
if (!getParameter(AllPredefinedVariables.PARALLEL_ARG_NAME).isPresent()){
VariableSymbol forParam = AllPredefinedVariables.createParallelParameter();
this.parameters.add(forParam);
forParam.putInScope(getSpannedScope());
}
}
public CompositeElementSymbol getBody() {
return body;
}
protected void setBody(CompositeElementSymbol body) {
this.body = body;
}
public boolean isPredefined() {
//Override by PredefinedUnrollDeclaration
return false;
}
public Optional<VariableSymbol> getParameter(String name) {
Optional<VariableSymbol> res = Optional.empty();
for (VariableSymbol parameter : getParameters()){
if (parameter.getName().equals(name)){
res = Optional.of(parameter);
}
}
return res;
}
public ArchitectureElementSymbol call(UnrollSymbol layer) throws ArchResolveException{
checkForSequence(layer.getArguments());
if (isPredefined()){
return layer;
}
else {
reset();
set(layer.getArguments());
CompositeElementSymbol copy = getBody().preResolveDeepCopy();
copy.putInScope(getSpannedScope());
copy.resolveOrError();
getSpannedScope().remove(copy);
getSpannedScope().removeSubScope(copy.getSpannedScope());
reset();
return copy;
}
}
private void reset(){
for (VariableSymbol param : getParameters()){
param.reset();
}
}
private void set(List<ArgumentSymbol> arguments){
for (ArgumentSymbol arg : arguments){
arg.set();
}
}
private void checkForSequence(List<ArgumentSymbol> arguments){
boolean valid = true;
for (ArgumentSymbol arg : arguments){
if (arg.getRhs() instanceof ArchAbstractSequenceExpression){
valid = false;
}
}
if (!valid){
throw new IllegalArgumentException("Arguments with sequence expressions have to be resolved first before calling the layer method.");
}
}
public UnrollDeclarationSymbol deepCopy() {
UnrollDeclarationSymbol copy = new UnrollDeclarationSymbol(getName());
if (getAstNode().isPresent()){
copy.setAstNode(getAstNode().get());
}
List<VariableSymbol> parameterCopies = new ArrayList<>(getParameters().size());
for (VariableSymbol parameter : getParameters()){
VariableSymbol parameterCopy = parameter.deepCopy();
parameterCopies.add(parameterCopy);
parameterCopy.putInScope(copy.getSpannedScope());
}
copy.setParameters(parameterCopies);
copy.setBody(getBody().preResolveDeepCopy());
copy.getBody().putInScope(copy.getSpannedScope());
return copy;
}
/*public static class Builder{
private List<VariableSymbol> parameters = new ArrayList<>();
private CompositeElementSymbol body;
private String name = "";
public Builder parameters(List<VariableSymbol> parameters) {
this.parameters = parameters;
return this;
}
public Builder parameters(VariableSymbol... parameters) {
this.parameters = new ArrayList<>(Arrays.asList(parameters));
return this;
}
public Builder body(CompositeElementSymbol body) {
this.body = body;
return this;
}
public Builder name(String name) {
this.name = name;
return this;
}
public UnrollDeclarationSymbol build(){
if (name == null || name.equals("")){
throw new IllegalStateException("Missing or empty name for UnrollDeclarationSymbol");
}
UnrollDeclarationSymbol sym = new UnrollDeclarationSymbol(name);
sym.setBody(body);
if (body != null){
body.putInScope(sym.getSpannedScope());
}
for (VariableSymbol param : parameters){
param.putInScope(sym.getSpannedScope());
}
sym.setParameters(parameters);
return sym;
}
}*/
}
......@@ -44,6 +44,7 @@ public class AllPredefinedLayers {
public static final String ADD_NAME = "Add";
public static final String CONCATENATE_NAME = "Concatenate";
public static final String FLATTEN_NAME = "Flatten";
public static final String ONE_HOT_NAME = "OneHot";
//predefined argument names
public static final String KERNEL_NAME = "kernel";
......@@ -61,6 +62,7 @@ public class AllPredefinedLayers {
public static final String BETA_NAME = "beta";
public static final String PADDING_NAME = "padding";
public static final String POOL_TYPE_NAME = "pool_type";
public static final String ONE_HOT_SIZE_NAME = "size";
//possible String values
......@@ -89,7 +91,8 @@ public class AllPredefinedLayers {
Split.create(),
Get.create(),
Add.create(),
Concatenate.create());
Concatenate.create(),
OneHot.create());
}
}
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch.predefined;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
public class OneHot extends PredefinedLayerDeclaration {
private OneHot() {
super(AllPredefinedLayers.ONE_HOT_NAME);
}
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer) {
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(layer.getIntValue(AllPredefinedLayers.ONE_HOT_SIZE_NAME).get())
.height(1)
.width(1)
.elementType("0", "1")
.build());
}
@Override
public void checkInput(List<ArchTypeSymbol> inputTypes, LayerSymbol layer) {
errorIfInputSizeIsNotOne(inputTypes, layer);
}
public static OneHot create(){
OneHot declaration = new OneHot();
List<VariableSymbol> parameters = new ArrayList<>(Arrays.asList(
new VariableSymbol.Builder()
.name(AllPredefinedLayers.ONE_HOT_SIZE_NAME)
.constraints(Constraints.POSITIVE, Constraints.INTEGER)
.build()));
declaration.setParameters(parameters);
return declaration;
}
}
......@@ -66,6 +66,7 @@ public class AllCoCoTest extends AbstractCoCoTest {
checkValid("valid_tests", "Alexnet_alt2");
checkValid("valid_tests", "MultipleOutputs");
checkValid("valid_tests", "MultipleStreams");
checkValid("valid_tests", "Alexnet_alt_OneHotOutput");
}
@Test
......
......@@ -2,6 +2,15 @@ architecture Alexnet(img_height=224, img_width=224, img_channels=3, classes=10){
def input Z(0:255)^{img_channels, img_height, img_width} data
def output Q(0:1)^{classes} predictions
unroll<t=5> beamSearchStart (width=5, max_length=50){
FullyConnected(units=4096) ->
Relu() ->
Dropout()
}
def split1(i){
[i] ->
Convolution(kernel=(5,5), channels=128) ->
......@@ -23,6 +32,9 @@ architecture Alexnet(img_height=224, img_width=224, img_channels=3, classes=10){
Dropout()
}
data ->
Convolution(kernel=(11,11), channels=96, stride=(4,4), padding="no_loss") ->
Lrn(nsize=5, alpha=0.0001, beta=0.75) ->
......@@ -36,7 +48,6 @@ architecture Alexnet(img_height=224, img_width=224, img_channels=3, classes=10){
Split(n=2) ->
split2(i=[0|1]) ->
Concatenate() ->
fc(->=2) ->
FullyConnected(units=10) ->
Softmax() ->
predictions;
......
architecture Alexnet_alt_OneHotOutput(img_height=224, img_width=224, img_channels=3, classes=10){
def input Z(0:255)^{img_channels, img_height, img_width} image
def output Q(0:1)^{classes} predictions
image ->
Convolution(kernel=(11,11), channels=96, stride=(4,4), padding="no_loss") ->
Lrn(nsize=5, alpha=0.0001, beta=0.75) ->
Pooling(pool_type="max", kernel=(3,3), stride=(2,2), padding="no_loss") ->
Relu() ->
Split(n=2) ->
(
[0] ->
Convolution(kernel=(5,5), channels=128) ->
Lrn(nsize=5, alpha=0.0001, beta=0.75) ->
Pooling(pool_type="max", kernel=(3,3), stride=(2,2), padding="no_loss") ->
Relu()
|
[1] ->
Convolution(kernel=(5,5), channels=128) ->
Lrn(nsize=5, alpha=0.0001, beta=0.75) ->
Pooling(pool_type="max", kernel=(3,3), stride=(2,2), padding="no_loss") ->
Relu()
) ->
Concatenate() ->
Convolution(kernel=(3,3), channels=384) ->
Relu() ->
Split(n=2) ->
(
[0] ->
Convolution(kernel=(3,3), channels=192) ->
Relu() ->
Convolution(kernel=(3,3), channels=128) ->
Pooling(pool_type="max", kernel=(3,3), stride=(2,2), padding="no_loss") ->
Relu()
|
[1] ->
Convolution(kernel=(3,3), channels=192) ->
Relu() ->
Convolution(kernel=(3,3), channels=128) ->
Pooling(pool_type="max", kernel=(3,3), stride=(2,2), padding="no_loss") ->
Relu()
) ->
Concatenate() ->
FullyConnected(units=4096) ->
Relu() ->
Dropout() ->
FullyConnected(units=4096) ->
Relu() ->
Dropout() ->
FullyConnected(units=classes) ->
Softmax() ->
OneHot(size=classes) ->
predictions;
}
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment