Commit 09301be7 authored by Sebastian Nickels's avatar Sebastian Nickels
Browse files

Implemented NetworkInstructionSymbol which is a parent of both...

Implemented NetworkInstructionSymbol which is a parent of both UnrollInstructionSymbol and StreamInstructionSymbol. This is done so that we can keep track of the order of the individual instructions
parent bc8e037c
Pipeline #180398 passed with stages
in 19 minutes and 54 seconds
......@@ -56,16 +56,20 @@ grammar CNNArch extends de.monticore.CommonExpressions, de.monticore.lang.Math,
Architecture = methodDeclaration:LayerDeclaration*
instructions:(Instruction || ";")+ ";";
Instruction = (LayerVariableDeclaration | Stream | Unroll);
Instruction = (LayerVariableDeclaration | NetworkInstruction);
LayerVariableDeclaration = "layer" Layer Name;
Stream = elements:(ArchitectureElement || "->")+;
interface NetworkInstruction;
StreamInstruction implements NetworkInstruction = body:Stream;
Unroll = "timed" "<" timeParameter:TimeParameter ">"
UnrollInstruction implements NetworkInstruction = "timed" "<" timeParameter:TimeParameter ">"
Name "(" arguments:(ArchArgument || ",")* ")"
"{" body:Stream "}";
Stream = elements:(ArchitectureElement || "->")+;
interface ArchitectureElement;
Variable implements ArchitectureElement = Name ("." (member:"output" | member:Name))? ("[" index:ArchSimpleExpression "]")?;
......
......@@ -36,8 +36,11 @@ public class CNNArchSymbolCoCo {
else if (sym instanceof UnrollDeclarationSymbol){
check((UnrollDeclarationSymbol) sym);
}
else if (sym instanceof UnrollSymbol){
check((UnrollSymbol) sym);
else if (sym instanceof UnrollInstructionSymbol){
check((UnrollInstructionSymbol) sym);
}
else if (sym instanceof StreamInstructionSymbol){
check((StreamInstructionSymbol) sym);
}
else if (sym instanceof ArchitectureElementSymbol){
check((ArchitectureElementSymbol) sym);
......@@ -114,7 +117,11 @@ public class CNNArchSymbolCoCo {
//Override if needed
}
public void check(UnrollSymbol sym){
public void check(UnrollInstructionSymbol sym){
//Override if needed
}
public void check(StreamInstructionSymbol sym){
//Override if needed
}
}
......@@ -21,29 +21,23 @@
package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.UnrollSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.NetworkInstructionSymbol;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.SourcePosition;
import de.se_rwth.commons.logging.Log;
public class CheckArchitectureFinished extends CNNArchSymbolCoCo {
@Override
public void check(ArchitectureSymbol architecture) {
for (CompositeElementSymbol stream : architecture.getStreams()) {
if (!stream.getOutputTypes().isEmpty()){
for (NetworkInstructionSymbol networkInstruction : architecture.getNetworkInstructions()) {
if (!networkInstruction.getBody().getOutputTypes().isEmpty()) {
Log.error("0" + ErrorCodes.UNFINISHED_ARCHITECTURE + " The architecture is not finished. " +
"There are still open streams at the end of the architecture. "
, stream.getSourcePosition());
}
}
for (UnrollSymbol unroll : architecture.getUnrolls()) {
if (!unroll.getBody().getOutputTypes().isEmpty()){
Log.error("0" + ErrorCodes.UNFINISHED_ARCHITECTURE + " The architecture is not finished. " +
"There are still open streams at the end of the architecture. "
, architecture.getSourcePosition());
, networkInstruction.getSourcePosition());
}
}
if (architecture.getInputs().isEmpty()){
Log.error("0" + ErrorCodes.UNFINISHED_ARCHITECTURE + " The architecture has no inputs. "
, architecture.getSourcePosition());
......
......@@ -40,7 +40,7 @@ public class CheckArgument implements CNNArchASTArchArgumentCoCo {
"Possible arguments are: " + Joiners.COMMA.join(layerDeclaration.getParameters())
, node.get_SourcePositionStart());
}
}else if(argument.getEnclosingScope().getSpanningSymbol().get() instanceof UnrollSymbol){
}else if(argument.getEnclosingScope().getSpanningSymbol().get() instanceof UnrollInstructionSymbol){
UnrollDeclarationSymbol layerDeclaration = argument.getUnroll().getDeclaration();
if (layerDeclaration != null && argument.getParameter() == null){
Log.error("0"+ ErrorCodes.UNKNOWN_ARGUMENT + " Unknown Argument. " +
......
......@@ -21,19 +21,14 @@
package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.UnrollSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.NetworkInstructionSymbol;
public class CheckElementInputs extends CNNArchSymbolCoCo {
@Override
public void check(ArchitectureSymbol architecture) {
for (CompositeElementSymbol stream : architecture.getStreams()) {
stream.checkInput();
}
for (UnrollSymbol unroll : architecture.getUnrolls()) {
unroll.getBody().checkInput();
for (NetworkInstructionSymbol networkInstruction : architecture.getNetworkInstructions()) {
networkInstruction.getBody().checkInput();
}
}
}
......@@ -75,8 +75,10 @@ public class CheckIOAccessAndIOMissing extends CNNArchSymbolCoCo {
ArchitectureSymbol architecture = ioDeclaration.getArchitecture();
boolean isUnroll = false;
for (UnrollSymbol unroll : architecture.getUnrolls()) {
isUnroll = contains(unroll.getBody(), ioElement);
for (NetworkInstructionSymbol networkInstruction : architecture.getNetworkInstructions()) {
if (networkInstruction.isUnroll()) {
isUnroll = contains(networkInstruction.getBody(), ioElement);
}
}
// Allow invalid indices in UnrollSymbols
......
......@@ -24,12 +24,9 @@ import de.monticore.lang.monticar.cnnarch._ast.ASTArchArgument;
import de.monticore.lang.monticar.cnnarch._ast.ASTLayer;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerDeclarationSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.UnrollDeclarationSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.UnrollSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ParameterSymbol;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.monticore.symboltable.Symbol;
import de.se_rwth.commons.Joiners;
import de.se_rwth.commons.logging.Log;
......
......@@ -24,6 +24,7 @@ import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import java.util.ArrayList;
import java.util.Collection;
public class CheckLayerVariableDeclarationIsUsed extends CNNArchSymbolCoCo {
......@@ -35,25 +36,9 @@ public class CheckLayerVariableDeclarationIsUsed extends CNNArchSymbolCoCo {
boolean isUsed = false;
for (SerialCompositeElementSymbol stream : layerVariableDeclaration.getLayer().getArchitecture().getStreams()) {
Collection<ArchitectureElementSymbol> elements =
stream.getSpannedScope().resolveMany(layerVariableDeclaration.getName(), ArchitectureElementSymbol.KIND);
for (ArchitectureElementSymbol element : elements) {
if (element instanceof VariableSymbol && ((VariableSymbol) element).getMember() == VariableSymbol.Member.NONE) {
isUsed = true;
break;
}
}
if (isUsed) {
break;
}
}
for (UnrollSymbol unroll : layerVariableDeclaration.getLayer().getArchitecture().getUnrolls()) {
Collection<ArchitectureElementSymbol> elements =
unroll.getBody().getSpannedScope().resolveMany(layerVariableDeclaration.getName(), ArchitectureElementSymbol.KIND);
for (NetworkInstructionSymbol networkInstruction : layerVariableDeclaration.getLayer().getArchitecture().getNetworkInstructions()) {
Collection<ArchitectureElementSymbol> elements
= networkInstruction.getBody().getSpannedScope().resolveMany(layerVariableDeclaration.getName(), ArchitectureElementSymbol.KIND);
for (ArchitectureElementSymbol element : elements) {
if (element instanceof VariableSymbol && ((VariableSymbol) element).getMember() == VariableSymbol.Member.NONE) {
......
......@@ -21,8 +21,7 @@
package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.UnrollSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.NetworkInstructionSymbol;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
......@@ -32,12 +31,8 @@ public class CheckNetworkStreamMissing extends CNNArchSymbolCoCo {
public void check(ArchitectureSymbol architecture) {
boolean hasTrainableStream = false;
for (CompositeElementSymbol stream : architecture.getStreams()) {
hasTrainableStream |= stream.isTrainable();
}
for (UnrollSymbol unroll : architecture.getUnrolls()) {
hasTrainableStream |= unroll.isTrainable();
for (NetworkInstructionSymbol networkInstruction : architecture.getNetworkInstructions()) {
hasTrainableStream |= networkInstruction.getBody().isTrainable();
}
if (!hasTrainableStream) {
......
......@@ -20,11 +20,8 @@
*/
package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.commonexpressions._ast.ASTArguments;
import de.monticore.lang.monticar.cnnarch._ast.ASTArchArgument;
import de.monticore.lang.monticar.cnnarch._ast.ASTLayer;
import de.monticore.lang.monticar.cnnarch._ast.ASTLayerParameter;
import de.monticore.lang.monticar.cnnarch._ast.ASTUnroll;
import de.monticore.lang.monticar.cnnarch._ast.ASTUnrollInstruction;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.Joiners;
......@@ -33,10 +30,10 @@ import de.se_rwth.commons.logging.Log;
import java.util.HashSet;
import java.util.Set;
public class CheckUnroll implements CNNArchASTUnrollCoCo{
public class CheckUnroll implements CNNArchASTUnrollInstructionCoCo{
@Override
public void check(ASTUnroll node) {
public void check(ASTUnrollInstruction node) {
Set<String> nameSet = new HashSet<>();
for (ASTArchArgument argument : node.getArgumentsList()){
......@@ -52,7 +49,7 @@ public class CheckUnroll implements CNNArchASTUnrollCoCo{
}
UnrollDeclarationSymbol layerDeclaration = ((UnrollSymbol) node.getSymbolOpt().get()).getDeclaration();
UnrollDeclarationSymbol layerDeclaration = ((UnrollInstructionSymbol) node.getSymbolOpt().get()).getDeclaration();
if (layerDeclaration == null){
ArchitectureSymbol architecture = node.getSymbolOpt().get().getEnclosingScope().<ArchitectureSymbol>resolve("", ArchitectureSymbol.KIND).get();
Log.error("0" + ErrorCodes.UNKNOWN_LAYER + " Unknown layer. " +
......
......@@ -29,6 +29,7 @@ import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedVariables;
import de.monticore.symboltable.CommonScopeSpanningSymbol;
import de.monticore.symboltable.Scope;
import de.monticore.symboltable.Symbol;
import org.apache.commons.math3.ml.neuralnet.Network;
import java.util.*;
......@@ -37,9 +38,7 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
public static final ArchitectureKind KIND = new ArchitectureKind();
private List<LayerVariableDeclarationSymbol> layerVariableDeclarations = new ArrayList<>();
private List<SerialCompositeElementSymbol> streams = new ArrayList<>();
private List<UnrollSymbol> unrolls = new ArrayList<>();
private Map<String, IODeclarationSymbol> ioDeclarationMap = new HashMap<>();
private List<NetworkInstructionSymbol> networkInstructions = new ArrayList<>();
private List<VariableSymbol> inputs = new ArrayList<>();
private List<VariableSymbol> outputs = new ArrayList<>();
private String dataPath;
......@@ -57,20 +56,24 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
this.layerVariableDeclarations = layerVariableDeclarations;
}
public List<SerialCompositeElementSymbol> getStreams() {
return streams;
public List<NetworkInstructionSymbol> getNetworkInstructions() {
return networkInstructions;
}
public void setStreams(List<SerialCompositeElementSymbol> streams) {
this.streams = streams;
public void setNetworkInstructions(List<NetworkInstructionSymbol> networkInstructions) {
this.networkInstructions = networkInstructions;
}
public List<UnrollSymbol> getUnrolls() {
return unrolls;
}
public List<SerialCompositeElementSymbol> getStreams() {
List<SerialCompositeElementSymbol> streams = new ArrayList<>();
public void setUnrolls(List<UnrollSymbol> unrolls) {
this.unrolls = unrolls;
for (NetworkInstructionSymbol networkInstruction : getNetworkInstructions()) {
if (networkInstruction.isStream()) {
streams.add(networkInstruction.getBody());
}
}
return streams;
}
public String getDataPath() {
......@@ -111,22 +114,11 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
}
public void resolve() {
for (CompositeElementSymbol stream : streams) {
stream.checkIfResolvable();
for (NetworkInstructionSymbol networkInstruction : getNetworkInstructions()) {
networkInstruction.checkIfResolvable();
try {
stream.resolveOrError();
}
catch (ArchResolveException e) {
// Do nothing; error is already logged
}
}
for (UnrollSymbol unroll : unrolls) {
unroll.checkIfResolvable();
try {
unroll.resolveOrError();
networkInstruction.resolveOrError();
}
catch (ArchResolveException e) {
// Do nothing; error is already logged
......@@ -137,23 +129,18 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
public boolean isResolved(){
boolean resolved = true;
for (CompositeElementSymbol stream : streams) {
resolved &= stream.isResolved();
for (NetworkInstructionSymbol networkInstruction : getNetworkInstructions()) {
resolved &= networkInstruction.isResolved();
}
return resolved;
}
public boolean isResolvable(){
boolean resolvable = true;
for (CompositeElementSymbol stream : streams) {
resolvable &= stream.isResolvable();
}
for (UnrollSymbol unroll: unrolls) {
resolvable &= unroll.isResolvable();
for (NetworkInstructionSymbol networkInstruction : getNetworkInstructions()) {
resolvable &= networkInstruction.isResolvable();
}
return resolvable;
......@@ -206,21 +193,13 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
}
copy.setLayerVariableDeclarations(copyLayerVariableDeclarations);
List<SerialCompositeElementSymbol> copyStreams = new ArrayList<>();
for (SerialCompositeElementSymbol stream : getStreams()) {
SerialCompositeElementSymbol copyStream = stream.preResolveDeepCopy();
copyStream.putInScope(copy.getSpannedScope());
copyStreams.add(copyStream);
}
copy.setStreams(copyStreams);
List<UnrollSymbol> copyUnrolls = new ArrayList<>();
for (UnrollSymbol unroll : getUnrolls()) {
UnrollSymbol copyUnroll = (UnrollSymbol) unroll.preResolveDeepCopy();
copyUnroll.putInScope(copy.getSpannedScope());
copyUnrolls.add(copyUnroll);
List<NetworkInstructionSymbol> copyNetworkInstructions = new ArrayList<>();
for (NetworkInstructionSymbol networkInstruction : getNetworkInstructions()) {
NetworkInstructionSymbol copyNetworkInstruction = (NetworkInstructionSymbol) networkInstruction.preResolveDeepCopy();
copyNetworkInstruction.putInScope(copy.getSpannedScope());
copyNetworkInstructions.add(copyNetworkInstruction);
}
copy.setUnrolls(copyUnrolls);
copy.setNetworkInstructions(copyNetworkInstructions);
copy.putInScope(enclosingScopeOfCopy);
return copy;
......
......@@ -45,8 +45,8 @@ public class ArgumentSymbol extends CommonSymbol {
if (parameter == null){
Symbol spanningSymbol = getEnclosingScope().getSpanningSymbol().get();
if (spanningSymbol instanceof UnrollSymbol) {
UnrollSymbol unroll = (UnrollSymbol) getEnclosingScope().getSpanningSymbol().get();
if (spanningSymbol instanceof UnrollInstructionSymbol) {
UnrollInstructionSymbol unroll = (UnrollInstructionSymbol) getEnclosingScope().getSpanningSymbol().get();
if (unroll.getDeclaration() != null){
Optional<ParameterSymbol> optParam = unroll.getDeclaration().getParameter(getName());
......@@ -73,8 +73,8 @@ public class ArgumentSymbol extends CommonSymbol {
return (LayerSymbol) getEnclosingScope().getSpanningSymbol().get();
}
public UnrollSymbol getUnroll() {
return (UnrollSymbol) getEnclosingScope().getSpanningSymbol().get();
public UnrollInstructionSymbol getUnroll() {
return (UnrollInstructionSymbol) getEnclosingScope().getSpanningSymbol().get();
}
public ArchExpressionSymbol getRhs() {
......
......@@ -31,7 +31,6 @@ import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedVariables;
import de.monticore.symboltable.*;
import de.se_rwth.commons.logging.Log;
import java.lang.reflect.Array;
import java.util.*;
public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSymbolTableCreator
......@@ -144,22 +143,19 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
public void endVisit(final ASTArchitecture node) {
List<LayerVariableDeclarationSymbol> layerVariableDeclarations = new ArrayList<>();
List<SerialCompositeElementSymbol> streams = new ArrayList<>();
List<UnrollSymbol> unrolls = new ArrayList<>();
for (ASTInstruction astInstruction : node.getInstructionsList()){
List<NetworkInstructionSymbol> networkInstructions = new ArrayList<>();
for (ASTInstruction astInstruction : node.getInstructionsList()) {
if (astInstruction.isPresentLayerVariableDeclaration()) {
layerVariableDeclarations.add((LayerVariableDeclarationSymbol) astInstruction.getLayerVariableDeclaration().getSymbolOpt().get());
}
else if (astInstruction.isPresentStream()) {
streams.add((SerialCompositeElementSymbol) astInstruction.getStream().getSymbolOpt().get());
}else if(astInstruction.isPresentUnroll()) {
unrolls.add((UnrollSymbol) astInstruction.getUnroll().getSymbolOpt().get());
else if (astInstruction.isPresentNetworkInstruction()) {
networkInstructions.add((NetworkInstructionSymbol) astInstruction.getNetworkInstruction().getSymbolOpt().get());
}
}
architecture.setLayerVariableDeclarations(layerVariableDeclarations);
architecture.setStreams(streams);
architecture.setUnrolls(unrolls);
architecture.setNetworkInstructions(networkInstructions);
removeCurrentScope();
}
......@@ -351,24 +347,38 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
}
@Override
public void visit(ASTUnroll ast) {
UnrollSymbol layer = new UnrollSymbol(ast.getName());
addToScopeAndLinkWithNode(layer, ast);
public void visit(ASTUnrollInstruction ast) {
UnrollInstructionSymbol unrollInstruction = new UnrollInstructionSymbol(ast.getName());
addToScopeAndLinkWithNode(unrollInstruction, ast);
}
@Override
public void endVisit(ASTUnroll ast) {
UnrollSymbol layer = (UnrollSymbol) ast.getSymbolOpt().get();
layer.setBody((SerialCompositeElementSymbol) ast.getBody().getSymbolOpt().get());
public void endVisit(ASTUnrollInstruction ast) {
UnrollInstructionSymbol unrollInstruction = (UnrollInstructionSymbol) ast.getSymbolOpt().get();
unrollInstruction.setBody((SerialCompositeElementSymbol) ast.getBody().getSymbolOpt().get());
List<ArgumentSymbol> arguments = new ArrayList<>(6);
for (ASTArchArgument astArgument : ast.getArgumentsList()){
Optional<ArgumentSymbol> optArgument = astArgument.getSymbolOpt().map(e -> (ArgumentSymbol)e);
optArgument.ifPresent(arguments::add);
}
layer.setArguments(arguments);
unrollInstruction.setArguments(arguments);
layer.setTimeParameter((ParameterSymbol) ast.getTimeParameter().getSymbolOpt().get());
unrollInstruction.setTimeParameter((ParameterSymbol) ast.getTimeParameter().getSymbolOpt().get());
removeCurrentScope();
}
@Override
public void visit(ASTStreamInstruction ast) {
StreamInstructionSymbol streamInstruction = new StreamInstructionSymbol();
addToScopeAndLinkWithNode(streamInstruction, ast);
}
@Override
public void endVisit(ASTStreamInstruction ast) {
StreamInstructionSymbol streamInstruction = (StreamInstructionSymbol) ast.getSymbolOpt().get();
streamInstruction.setBody((SerialCompositeElementSymbol) ast.getBody().getSymbolOpt().get());
removeCurrentScope();
}
......
......@@ -78,7 +78,8 @@ public abstract class CompositeElementSymbol extends ArchitectureElementSymbol {
public Set<ParameterSymbol> resolve() throws ArchResolveException {
if (!isResolved()) {
if (isResolvable()) {
List<ArchitectureElementSymbol> resolvedElements = new ArrayList<>();
resolveExpressions();
for (ArchitectureElementSymbol element : getElements()) {
element.resolve();
}
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.symboltable.SymbolKind;
public abstract class NetworkInstructionSymbol extends ResolvableSymbol {
private SerialCompositeElementSymbol body;
protected NetworkInstructionSymbol(String name, SymbolKind kind) {
super(name, kind);
}
public SerialCompositeElementSymbol getBody() {
return body;
}
protected void setBody(SerialCompositeElementSymbol body) {
this.body = body;
}
public boolean isStream() {
return false;
}
public boolean isUnroll() {
return false;
}
}
......@@ -42,7 +42,7 @@ public abstract class ResolvableSymbol extends CommonScopeSpanningSymbol {
if (sym instanceof ArchitectureSymbol){
return (ArchitectureSymbol) sym;
}