Commit e25a213e authored by Sebastian Nickels's avatar Sebastian Nickels

Merge branch 'develop' of...

Merge branch 'develop' of git.rwth-aachen.de:monticore/EmbeddedMontiArc/languages/CNNArchLang into develop
parents 45906f7b 57ffd104
Pipeline #151204 failed with stages
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
......@@ -25,6 +25,10 @@ grammar CNNArch extends de.monticore.CommonExpressions, de.monticore.lang.Math,
parameters:(LayerParameter || ",")* ")" "{"
body:Stream "}";
UnrollDeclaration = "unroll" "<" timeParameter:LayerParameter ">"
Name "(" parameters:(LayerParameter || ",")* ")"
"{" body:Stream "}";
IODeclaration = "def"
(in:"input" | out:"output")
type:ArchType
......@@ -53,7 +57,8 @@ grammar CNNArch extends de.monticore.CommonExpressions, de.monticore.lang.Math,
@attribute body
The architecture of the neural network.
*/
Architecture = methodDeclaration:LayerDeclaration*
Architecture = unrollDeclarations:UnrollDeclaration*
methodDeclaration:LayerDeclaration*
instructions:(Instruction || ";")+ ";";
interface Instruction;
......@@ -66,6 +71,8 @@ grammar CNNArch extends de.monticore.CommonExpressions, de.monticore.lang.Math,
Layer implements ArchitectureElement = Name "(" arguments:(ArchArgument || ",")* ")";
Unroll implements ArchitectureElement = Name "(" arguments:(ArchArgument || ",")* ")";
ParallelBlock implements ArchitectureElement = "("
groups:Stream "|"
groups:(Stream || "|")+ ")";
......
......@@ -33,6 +33,9 @@ public class CNNArchSymbolCoCo {
else if (sym instanceof LayerDeclarationSymbol){
check((LayerDeclarationSymbol) sym);
}
else if (sym instanceof UnrollDeclarationSymbol){
check((UnrollDeclarationSymbol) sym);
}
else if (sym instanceof ArchitectureElementSymbol){
check((ArchitectureElementSymbol) sym);
}
......@@ -72,6 +75,10 @@ public class CNNArchSymbolCoCo {
//Override if needed
}
public void check(UnrollDeclarationSymbol sym){
//Override if needed
}
public void check(ArchitectureElementSymbol sym){
//Override if needed
}
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch._ast.ASTArchArgument;
import de.monticore.lang.monticar.cnnarch._ast.ASTLayer;
import de.monticore.lang.monticar.cnnarch._ast.ASTUnroll;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.Joiners;
import de.se_rwth.commons.logging.Log;
import java.util.HashSet;
import java.util.Set;
public class CheckUnroll implements CNNArchASTUnrollCoCo{
@Override
public void check(ASTUnroll node) {
Set<String> nameSet = new HashSet<>();
for (ASTArchArgument argument : node.getArgumentsList()){
String name = argument.getName();
if (nameSet.contains(name)){
Log.error("0" + ErrorCodes.DUPLICATED_ARG + " Duplicated name: " + name +
". Multiple values assigned to the same argument."
, argument.get_SourcePositionStart());
}
else {
nameSet.add(name);
}
}
LayerDeclarationSymbol layerDeclaration = ((LayerSymbol) node.getSymbolOpt().get()).getDeclaration();
if (layerDeclaration == null){
ArchitectureSymbol architecture = node.getSymbolOpt().get().getEnclosingScope().<ArchitectureSymbol>resolve("", ArchitectureSymbol.KIND).get();
Log.error("0" + ErrorCodes.UNKNOWN_LAYER + " Unknown layer. " +
"Layer with name '" + node.getName() + "' does not exist. " +
"Existing layers: " + Joiners.COMMA.join(architecture.getLayerDeclarations()) + "."
, node.get_SourcePositionStart());
}
else {
Set<String> requiredArguments = new HashSet<>();
for (VariableSymbol param : layerDeclaration.getParameters()){
if (!param.getDefaultExpression().isPresent()){
requiredArguments.add(param.getName());
}
}
for (ASTArchArgument argument : node.getArgumentsList()){
requiredArguments.remove(argument.getName());
}
for (String missingArgumentName : requiredArguments){
Log.error("0"+ErrorCodes.MISSING_ARGUMENT + " Missing argument. " +
"The argument '" + missingArgumentName + "' is required."
, node.get_SourcePositionStart());
}
}
}
}
......@@ -242,6 +242,29 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
removeCurrentScope();
}
@Override
public void visit(ASTUnrollDeclaration ast) {
UnrollDeclarationSymbol unrollDeclaration = new UnrollDeclarationSymbol(ast.getName());
addToScopeAndLinkWithNode(unrollDeclaration, ast);
}
@Override
public void endVisit(ASTUnrollDeclaration ast) {
UnrollDeclarationSymbol unrollDeclaration = (UnrollDeclarationSymbol) ast.getSymbolOpt().get();
unrollDeclaration.setBody((CompositeElementSymbol) ast.getBody().getSymbolOpt().get());
List<VariableSymbol> parameters = new ArrayList<>(4);
for (ASTLayerParameter astParam : ast.getParametersList()){
VariableSymbol parameter = (VariableSymbol) astParam.getSymbolOpt().get();
parameters.add(parameter);
}
unrollDeclaration.setParameters(parameters);
removeCurrentScope();
}
@Override
public void visit(ASTLayerParameter ast) {
VariableSymbol variable = new VariableSymbol(ast.getName());
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.symboltable.SymbolKind;
public class UnrollDeclarationKind implements SymbolKind {
private static final String NAME = "de.monticore.lang.monticar.cnnarch._symboltable.UnrollDeclarationKind";
@Override
public String getName() {
return NAME;
}
@Override
public boolean isKindOf(SymbolKind kind) {
return NAME.equals(kind.getName()) || SymbolKind.super.isKindOf(kind);
}
}
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.symboltable.MutableScope;
import de.monticore.symboltable.Symbol;
import java.util.Optional;
public class UnrollDeclarationScope extends de.monticore.symboltable.CommonScope {
public UnrollDeclarationScope() {
super(true);
}
public UnrollDeclarationScope(Optional<MutableScope> enclosingScope) {
super(enclosingScope, true);
}
@Override
public void add(Symbol symbol) {
super.add(symbol);
if (symbol instanceof ArchitectureElementSymbol){
ArchitectureElementScope subScope = ((ArchitectureElementSymbol) symbol).getSpannedScope();
addSubScope(subScope);
subScope.setResolvingFilters(getResolvingFilters());
}
}
}
\ No newline at end of file
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
/* generated by template symboltable.ScopeSpanningSymbol*/
package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedVariables;
import de.monticore.symboltable.CommonScopeSpanningSymbol;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
public class UnrollDeclarationSymbol extends CommonScopeSpanningSymbol {
public static final UnrollDeclarationKind KIND = new UnrollDeclarationKind();
private List<VariableSymbol> parameters;
private CompositeElementSymbol body;
protected UnrollDeclarationSymbol(String name) {
super(name, KIND);
}
@Override
protected UnrollDeclarationScope createSpannedScope() {
return new UnrollDeclarationScope();
}
@Override
public UnrollDeclarationScope getSpannedScope() {
return (UnrollDeclarationScope) super.getSpannedScope();
}
public List<VariableSymbol> getParameters() {
return parameters;
}
protected void setParameters(List<VariableSymbol> parameters) {
this.parameters = parameters;
if (!getParameter(AllPredefinedVariables.CONDITIONAL_ARG_NAME).isPresent()){
VariableSymbol ifParam = AllPredefinedVariables.createConditionalParameter();
this.parameters.add(ifParam);
ifParam.putInScope(getSpannedScope());
}
if (!getParameter(AllPredefinedVariables.SERIAL_ARG_NAME).isPresent()){
VariableSymbol forParam = AllPredefinedVariables.createSerialParameter();
this.parameters.add(forParam);
forParam.putInScope(getSpannedScope());
}
if (!getParameter(AllPredefinedVariables.PARALLEL_ARG_NAME).isPresent()){
VariableSymbol forParam = AllPredefinedVariables.createParallelParameter();
this.parameters.add(forParam);
forParam.putInScope(getSpannedScope());
}
}
public CompositeElementSymbol getBody() {
return body;
}
protected void setBody(CompositeElementSymbol body) {
this.body = body;
}
public boolean isPredefined() {
//Override by PredefinedUnrollDeclaration
return false;
}
public Optional<VariableSymbol> getParameter(String name) {
Optional<VariableSymbol> res = Optional.empty();
for (VariableSymbol parameter : getParameters()){
if (parameter.getName().equals(name)){
res = Optional.of(parameter);
}
}
return res;
}
public ArchitectureElementSymbol call(UnrollSymbol layer) throws ArchResolveException{
checkForSequence(layer.getArguments());
if (isPredefined()){
return layer;
}
else {
reset();
set(layer.getArguments());
CompositeElementSymbol copy = getBody().preResolveDeepCopy();
copy.putInScope(getSpannedScope());
copy.resolveOrError();
getSpannedScope().remove(copy);
getSpannedScope().removeSubScope(copy.getSpannedScope());
reset();
return copy;
}
}
private void reset(){
for (VariableSymbol param : getParameters()){
param.reset();
}
}
private void set(List<ArgumentSymbol> arguments){
for (ArgumentSymbol arg : arguments){
arg.set();
}
}
private void checkForSequence(List<ArgumentSymbol> arguments){
boolean valid = true;
for (ArgumentSymbol arg : arguments){
if (arg.getRhs() instanceof ArchAbstractSequenceExpression){
valid = false;
}
}
if (!valid){
throw new IllegalArgumentException("Arguments with sequence expressions have to be resolved first before calling the layer method.");
}
}
public UnrollDeclarationSymbol deepCopy() {
UnrollDeclarationSymbol copy = new UnrollDeclarationSymbol(getName());
if (getAstNode().isPresent()){
copy.setAstNode(getAstNode().get());
}
List<VariableSymbol> parameterCopies = new ArrayList<>(getParameters().size());
for (VariableSymbol parameter : getParameters()){
VariableSymbol parameterCopy = parameter.deepCopy();
parameterCopies.add(parameterCopy);
parameterCopy.putInScope(copy.getSpannedScope());
}
copy.setParameters(parameterCopies);
copy.setBody(getBody().preResolveDeepCopy());
copy.getBody().putInScope(copy.getSpannedScope());
return copy;
}
/*public static class Builder{
private List<VariableSymbol> parameters = new ArrayList<>();
private CompositeElementSymbol body;
private String name = "";
public Builder parameters(List<VariableSymbol> parameters) {
this.parameters = parameters;
return this;
}
public Builder parameters(VariableSymbol... parameters) {
this.parameters = new ArrayList<>(Arrays.asList(parameters));
return this;
}
public Builder body(CompositeElementSymbol body) {
this.body = body;
return this;
}
public Builder name(String name) {
this.name = name;
return this;
}
public UnrollDeclarationSymbol build(){
if (name == null || name.equals("")){
throw new IllegalStateException("Missing or empty name for UnrollDeclarationSymbol");
}
UnrollDeclarationSymbol sym = new UnrollDeclarationSymbol(name);
sym.setBody(body);
if (body != null){
body.putInScope(sym.getSpannedScope());
}
for (VariableSymbol param : parameters){
param.putInScope(sym.getSpannedScope());
}
sym.setParameters(parameters);
return sym;
}
}*/
}
......@@ -62,7 +62,8 @@ public class AllPredefinedLayers {
public static final String BETA_NAME = "beta";
public static final String PADDING_NAME = "padding";
public static final String POOL_TYPE_NAME = "pool_type";
public static final String SIZE_NAME = "size";
public static final String ONE_HOT_SIZE_NAME = "size";
//possible String values
public static final String PADDING_VALID = "valid";
......@@ -90,7 +91,8 @@ public class AllPredefinedLayers {
Split.create(),
Get.create(),
Add.create(),
Concatenate.create());
Concatenate.create(),
OneHot.create());
}
}
......@@ -38,7 +38,7 @@ public class OneHot extends PredefinedLayerDeclaration {
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer) {
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(layer.getIntValue(AllPredefinedLayers.SIZE_NAME).get())
.channels(layer.getIntValue(AllPredefinedLayers.ONE_HOT_SIZE_NAME).get())
.height(1)
.width(1)
.elementType("0", "1")
......@@ -54,7 +54,7 @@ public class OneHot extends PredefinedLayerDeclaration {
OneHot declaration = new OneHot();
List<VariableSymbol> parameters = new ArrayList<>(Arrays.asList(
new VariableSymbol.Builder()
.name(AllPredefinedLayers.SIZE_NAME)
.name(AllPredefinedLayers.ONE_HOT_SIZE_NAME)
.constraints(Constraints.POSITIVE, Constraints.INTEGER)
.build()));
declaration.setParameters(parameters);
......
......@@ -66,6 +66,7 @@ public class AllCoCoTest extends AbstractCoCoTest {
checkValid("valid_tests", "Alexnet_alt2");
checkValid("valid_tests", "MultipleOutputs");
checkValid("valid_tests", "MultipleStreams");
checkValid("valid_tests", "Alexnet_alt_OneHotOutput");
}
@Test
......
......@@ -2,6 +2,15 @@ architecture Alexnet(img_height=224, img_width=224, img_channels=3, classes=10){
def input Z(0:255)^{img_channels, img_height, img_width} data
def output Q(0:1)^{classes} predictions
unroll<t=5> beamSearchStart (width=5, max_length=50){
FullyConnected(units=4096) ->
Relu() ->
Dropout()
}
def split1(i){
[i] ->
Convolution(kernel=(5,5), channels=128) ->
......@@ -23,6 +32,9 @@ architecture Alexnet(img_height=224, img_width=224, img_channels=3, classes=10){
Dropout()
}
data ->
Convolution(kernel=(11,11), channels=96, stride=(4,4), padding="no_loss") ->
Lrn(nsize=5, alpha=0.0001, beta=0.75) ->
......@@ -36,7 +48,6 @@ architecture Alexnet(img_height=224, img_width=224, img_channels=3, classes=10){
Split(n=2) ->
split2(i=[0|1]) ->
Concatenate() ->
fc(->=2) ->
FullyConnected(units=10) ->
Softmax() ->
predictions;
......
architecture Alexnet_alt_OneHotOutput(img_height=224, img_width=224, img_channels=3, classes=10){
def input Z(0:255)^{img_channels, img_height, img_width} image
def output Q(0:1)^{classes} predictions
image ->
Convolution(kernel=(11,11), channels=96, stride=(4,4), padding="no_loss") ->
Lrn(nsize=5, alpha=0.0001, beta=0.75) ->
Pooling(pool_type="max", kernel=(3,3), stride=(2,2), padding="no_loss") ->
Relu() ->
Split(n=2) ->
(
[0] ->
Convolution(kernel=(5,5), channels=128) ->
Lrn(nsize=5, alpha=0.0001, beta=0.75) ->
Pooling(pool_type="max", kernel=(3,3), stride=(2,2), padding="no_loss") ->
Relu()
|
[1] ->
Convolution(kernel=(5,5), channels=128) ->
Lrn(nsize=5, alpha=0.0001, beta=0.75) ->
Pooling(pool_type="max", kernel=(3,3), stride=(2,2), padding="no_loss") ->
Relu()
) ->
Concatenate() ->
Convolution(kernel=(3,3), channels=384) ->
Relu() ->
Split(n=2) ->
(
[0] ->
Convolution(kernel=(3,3), channels=192) ->
Relu() ->
Convolution(kernel=(3,3), channels=128) ->
Pooling(pool_type="max", kernel=(3,3), stride=(2,2), padding="no_loss") ->
Relu()
|
[1] ->
Convolution(kernel=(3,3), channels=192) ->
Relu() ->
Convolution(kernel=(3,3), channels=128) ->
Pooling(pool_type="max", kernel=(3,3), stride=(2,2), padding="no_loss") ->
Relu()
) ->
Concatenate() ->
FullyConnected(units=4096) ->
Relu() ->
Dropout() ->
FullyConnected(units=4096) ->
Relu() ->
Dropout() ->
FullyConnected(units=classes) ->
Softmax() ->
OneHot(size=classes) ->
predictions;
}
\ No newline at end of file
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment