Commit 6770c3e0 authored by Christian Fuß's avatar Christian Fuß
Browse files

changed Unroll to be an instruction, like Streams. Still get errors when...

changed Unroll to be an instruction, like Streams. Still get errors when trying to access Unroll body from EMADL2CPP
parent 2c6ec817
Pipeline #172962 failed with stages
......@@ -64,6 +64,10 @@ grammar CNNArch extends de.monticore.CommonExpressions, de.monticore.lang.Math,
Stream implements Instruction = elements:(ArchitectureElement || "->")+;
Unroll implements Instruction = "timed" "<" timeParameter:ArchitectureParameter ">"
Name "(" arguments:(ArchArgument || ",")* ")"
"{" body:Stream "}";
interface ArchitectureElement;
IOElement implements ArchitectureElement = Name ("[" index:ArchSimpleExpression "]")?;
......@@ -72,9 +76,6 @@ grammar CNNArch extends de.monticore.CommonExpressions, de.monticore.lang.Math,
Layer implements ArchitectureElement = Name "(" arguments:(ArchArgument || ",")* ")";
Unroll implements ArchitectureElement = "unroll" Name "(" arguments:(ArchArgument || ",")* ")" "{"
body:Stream
"}";
ParallelBlock implements ArchitectureElement = "("
groups:Stream "|"
......
......@@ -22,6 +22,7 @@ package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.UnrollSymbol;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
......@@ -29,13 +30,23 @@ public class CheckArchitectureFinished extends CNNArchSymbolCoCo {
@Override
public void check(ArchitectureSymbol architecture) {
System.err.println("Architecture in checkArchFinished: " + architecture.toString());
for (CompositeElementSymbol stream : architecture.getStreams()) {
System.err.println("Stream in checkArchFinished");
if (!stream.getOutputTypes().isEmpty()){
Log.error("0" + ErrorCodes.UNFINISHED_ARCHITECTURE + " The architecture is not finished. " +
"There are still open streams at the end of the architecture. "
, architecture.getSourcePosition());
}
}
for (UnrollSymbol unroll : architecture.getUnrolls()) {
System.err.println("UnrollSymbol in checkArchFinished");
if (!unroll.getOutputTypes().isEmpty()){
Log.error("0" + ErrorCodes.UNFINISHED_ARCHITECTURE + " The architecture is not finished. " +
"There are still open streams at the end of the architecture. "
, architecture.getSourcePosition());
}
}
if (architecture.getInputs().isEmpty()){
Log.error("0" + ErrorCodes.UNFINISHED_ARCHITECTURE + " The architecture has no inputs. "
, architecture.getSourcePosition());
......
......@@ -22,6 +22,7 @@ package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.UnrollSymbol;
public class CheckElementInputs extends CNNArchSymbolCoCo {
......@@ -30,5 +31,11 @@ public class CheckElementInputs extends CNNArchSymbolCoCo {
for (CompositeElementSymbol stream : architecture.getStreams()) {
stream.checkInput();
}
for (UnrollSymbol unroll : architecture.getUnrolls()) {
System.err.println("BEFORE check");
unroll.checkInput();
System.err.println("AFTER check");
}
}
}
......@@ -22,6 +22,7 @@ package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.UnrollSymbol;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
......@@ -35,6 +36,10 @@ public class CheckNetworkStreamMissing extends CNNArchSymbolCoCo {
hasNetworkStream |= stream.isNetwork();
}
for (UnrollSymbol unroll : architecture.getUnrolls()) {
hasNetworkStream |= unroll.isNetwork();
}
if (!hasNetworkStream) {
Log.error("0" + ErrorCodes.MISSING_NETWORK_STREAM + " The architecture has no network stream. "
, architecture.getSourcePosition());
......
......@@ -38,6 +38,7 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
public static final ArchitectureKind KIND = new ArchitectureKind();
private List<SerialCompositeElementSymbol> streams = new ArrayList<>();
private List<UnrollSymbol> unrolls = new ArrayList<>();
private List<IOSymbol> inputs = new ArrayList<>();
private List<IOSymbol> outputs = new ArrayList<>();
private Map<String, IODeclarationSymbol> ioDeclarationMap = new HashMap<>();
......@@ -56,6 +57,14 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
this.streams = streams;
}
public List<UnrollSymbol> getUnrolls() {
return unrolls;
}
public void setUnrolls(List<UnrollSymbol> unrolls) {
this.unrolls = unrolls;
}
public String getDataPath() {
return this.dataPath;
}
......@@ -73,6 +82,7 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
}
public List<IOSymbol> getInputs() {
System.err.println("THE inputs: " + inputs);
return inputs;
}
......@@ -118,6 +128,18 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
// Do nothing; error is already logged
}
}
for (UnrollSymbol unroll : unrolls) {
unroll.checkIfResolvable();
try {
unroll.resolveOrError();
//unroll.getBody().resolveOrError();
}
catch (ArchResolveException e) {
// Do nothing; error is already logged
}
}
}
/*public List<ArchitectureElementSymbol> getFirstElements() {
......@@ -134,6 +156,10 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
resolved &= stream.isResolved();
}
for (UnrollSymbol unroll: unrolls) {
resolved &= unroll.isResolved();
}
return resolved;
}
......@@ -144,6 +170,10 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
resolvable &= stream.isResolvable();
}
for (UnrollSymbol unroll: unrolls) {
resolvable &= unroll.isResolvable();
}
return resolvable;
}
......@@ -199,6 +229,14 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
}
copy.setStreams(copyStreams);
List<UnrollSymbol> copyUnrolls = new ArrayList<>();
for (UnrollSymbol unroll : unrolls) {
UnrollSymbol copyUnroll = unroll.preResolveDeepCopy();
copyUnroll.putInScope(copy.getSpannedScope());
copyUnrolls.add(copyUnroll);
}
copy.setUnrolls(copyUnrolls);
copy.putInScope(enclosingScopeOfCopy);
return copy;
}
......
......@@ -145,13 +145,21 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
public void endVisit(final ASTArchitecture node) {
List<SerialCompositeElementSymbol> streams = new ArrayList<>();
List<UnrollSymbol> unrolls = new ArrayList<>();
for (ASTInstruction astInstruction : node.getInstructionsList()){
ASTStream astStream = (ASTStream)astInstruction; // TODO: For now all instructions are streams
streams.add((SerialCompositeElementSymbol) astStream.getSymbolOpt().get());
if(astInstruction instanceof ASTStream) {
ASTStream astStream = (ASTStream) astInstruction;
streams.add((SerialCompositeElementSymbol) astStream.getSymbolOpt().get());
}else if(astInstruction instanceof ASTUnroll) {
ASTUnroll astUnroll = (ASTUnroll) astInstruction;
unrolls.add((UnrollSymbol) astUnroll.getSymbolOpt().get());
System.err.println("Table 1: " + ((UnrollSymbol) astUnroll.getSymbolOpt().get()).getName());
System.err.println("Table 1_1: " + ((UnrollSymbol) astUnroll.getSymbolOpt().get()).getBody().getElements().toString());
}
}
System.err.println("777333: streams set for architecture");
System.err.println(streams.get(0).getElements().toString());
architecture.setStreams(streams);
architecture.setUnrolls(unrolls);
removeCurrentScope();
}
......@@ -360,9 +368,9 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
}
sces.setElements(elements);
layer.setBody(sces);
layer.getDeclaration().setBody(sces);
//layer.getDeclaration().setBody(sces);
layer.getDeclaration().getBody().setElements(elements);
//layer.setElements(elements);
List<ArgumentSymbol> arguments = new ArrayList<>(6);
for (ASTArchArgument astArgument : ast.getArgumentsList()){
......@@ -371,70 +379,9 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
}
layer.setArguments(arguments);
int elementNumber = 0;
for (ASTArchitectureElement astElement : ast.getBody().getElementsList()){
//sublayer.getDeclaration().setBody((SerialCompositeElementSymbol) ast.getBody().getSymbolOpt().get());
List<ArgumentSymbol> subarguments = new ArrayList<>(6);
if(astElement.getSymbolOpt().get() instanceof LayerSymbol) {
for (ArgumentSymbol argument : ((LayerSymbol) astElement.getSymbolOpt().get()).getArguments()) {
System.err.println("arg: " + argument);
subarguments.add(argument);
}
((LayerSymbol) astElement.getSymbolOpt().get()).setArguments(subarguments);
}
if(elementNumber == 0 && astElement.getSymbolOpt().get() instanceof IOSymbol){
((IOSymbol)astElement.getSymbol()).getArchitecture().resolveIODeclaration(((IOSymbol)astElement.getSymbol()).getName());
System.err.println("Here 55566");
Iterator iterator = ((IOSymbol)astElement.getSymbol()).getArchitecture().getIODeclarations().iterator();
while (iterator.hasNext()){
System.err.println("0001111: " + iterator.next().toString());
}
ASTIODeclaration ioAST = (ASTIODeclaration) ((IOSymbol)astElement.getSymbol()).getDefinition().getAstNode().get();
IODeclarationSymbol iODeclaration = ((IOSymbol)astElement.getSymbol()).getDefinition();
if (ioAST.isPresentArrayDeclaration()){
iODeclaration.setArrayLength(ioAST.getArrayDeclaration().getIntLiteral().getNumber().get().intValue());
}
iODeclaration.setInput(ioAST.isPresentIn());
iODeclaration.setType((ArchTypeSymbol) ioAST.getType().getSymbolOpt().get());
try {
((IOSymbol) astElement.getSymbolOpt().get()).resolve();
((IOSymbol) ast.getBody().getElements(0).getSymbolOpt().get()).resolve();
} catch (ArchResolveException e) {
e.printStackTrace();
}
layer.getBody().setInputElement((ArchitectureElementSymbol) astElement.getSymbolOpt().get());
layer.setInputElement((ArchitectureElementSymbol) astElement.getSymbolOpt().get());
}else if(elementNumber == (ast.getBody().getElementsList().size() - 1) && astElement.getSymbolOpt().get() instanceof IOSymbol){
Iterator iterator2 = ((IOSymbol)astElement.getSymbol()).getArchitecture().getInputs().iterator();
while (iterator2.hasNext()){
System.err.println("0002222: " + iterator2.next().toString());
}
try {
((IOSymbol) astElement.getSymbolOpt().get()).resolve();
((IOSymbol) ast.getBody().getElements(ast.getBody().getElementsList().size() - 1).getSymbolOpt().get()).resolve();
} catch (ArchResolveException e) {
e.printStackTrace();
}
layer.getBody().setOutputElement((ArchitectureElementSymbol) astElement.getSymbolOpt().get());
layer.setOutputElement((ArchitectureElementSymbol) astElement.getSymbolOpt().get());
}
elementNumber++;
}
/*List<ArchitectureElementSymbol> elements = new ArrayList<>();
for (ASTStream astStream : ast.getGroupsList()){
......
......@@ -51,7 +51,7 @@ public abstract class CompositeElementSymbol extends ArchitectureElementSymbol {
isNetwork |= ((LayerSymbol) element).getDeclaration().isNetworkLayer();
}
else if (element instanceof UnrollSymbol) {
isNetwork |= ((UnrollSymbol) element).isNetworkLayer();
isNetwork |= ((UnrollSymbol) element).isNetwork();
}
}
......
......@@ -72,7 +72,7 @@ public class UnrollSymbol extends ArchitectureElementSymbol {
this.body = body;
}
public boolean isNetworkLayer() {
public boolean isNetwork() {
return body.isNetwork();
}
......@@ -154,12 +154,12 @@ public class UnrollSymbol extends ArchitectureElementSymbol {
@Override
public List<ArchitectureElementSymbol> getFirstAtomicElements() {
return getDeclaration().getBody().getElements().get(0).getFirstAtomicElements();
return this.getBody().getElements().get(0).getFirstAtomicElements();
}
@Override
public List<ArchitectureElementSymbol> getLastAtomicElements() {
return getDeclaration().getBody().getElements().get(getDeclaration().getBody().getElements().size()-1).getLastAtomicElements();
return this.getBody().getElements().get(this.getBody().getElements().size()-1).getLastAtomicElements();
}
@Override
......@@ -285,6 +285,7 @@ public class UnrollSymbol extends ArchitectureElementSymbol {
@Override
public List<ArchTypeSymbol> computeOutputTypes() {
System.err.println("##33333333");
if (getResolvedThis().isPresent()) {
if (getResolvedThis().get() == this) {
List<ArchTypeSymbol> inputTypes = getInputTypes();
......@@ -470,7 +471,7 @@ public class UnrollSymbol extends ArchitectureElementSymbol {
}
@Override
protected ArchitectureElementSymbol preResolveDeepCopy() {
protected UnrollSymbol preResolveDeepCopy() {
UnrollSymbol copy = new UnrollSymbol(getName());
if (getAstNode().isPresent()){
copy.setAstNode(getAstNode().get());
......
......@@ -42,54 +42,15 @@ public class BeamSearchStart extends PredefinedUnrollDeclaration {
List<ArchTypeSymbol> output = new ArrayList<ArchTypeSymbol>();
for(ASTArchitectureElement item: ((ASTUnroll) layer.getAstNode().get()).getBody().getElementsList()){
if(item instanceof ASTLayer) {
try {
//ArchitectureElementSymbol item = (ArchitectureElementSymbol) ASTitem;
LayerSymbol sublayer = (LayerSymbol) item.getSymbol();
sublayer.resolve();
System.err.println("inputElement: " + sublayer.getInputElement().get().toString());
System.err.println("resolved: " + sublayer.getResolvedThis().get().getInputElement().get().getOutputTypes().get(0).getChannels().toString());
//output = sublayer.getResolvedThis().get().getInputElement().get().getOutputTypes();
System.err.println("inputTypes: " + sublayer.getInputTypes().toString());
layers.add((LayerSymbol) sublayer);
System.err.println("outputTypes before: " + ((ArchitectureElementSymbol) sublayer).getOutputTypes().get(0).getChannels());
//System.err.println("arg0_NAME: " + ((LayerSymbol) sublayer).getArguments().get(0).getName());
//System.err.println("arg0_VALUE: " + ((LayerSymbol) sublayer).getIntValue((((LayerSymbol) sublayer).getArguments().get(0).getName())));
//item.setOutputTypes(item.getOutputTypes());
//System.err.println("outputTypes after: " + ((ArchitectureElementSymbol) astElement).getOutputTypes());
} catch (Exception e) {
LayerSymbol sublayer = (LayerSymbol) item.getSymbol();
//System.err.println("The following names could not be resolved: " + Joiners.COMMA.join(sublayer.getUnresolvableVariables()));
e.printStackTrace();
}
}else if(item instanceof ASTIOElement){
try {
//ArchitectureElementSymbol item = (ArchitectureElementSymbol) ASTitem;
IOSymbol sublayer = (IOSymbol) item.getSymbol();
sublayer.resolve();
//TODO setinputElement !!!!!
System.err.println("resolved2: " + sublayer.getResolvedThis().get().getInputElement().get().toString());
System.err.println("isOutput?: " + sublayer.isOutput());
System.err.println("Definition: " + sublayer.getDefinition().toString());
System.err.println("Domain: " + sublayer.getDefinition().getType().getDomain().toString());
System.err.println("Type: " + sublayer.getDefinition().getType().getChannels().toString());
//output = sublayer.getResolvedThis().get().getInputElement().get().getOutputTypes();
if(sublayer.isOutput()){
output = new ArrayList<ArchTypeSymbol>();
}
//item.setOutputTypes(item.getOutputTypes());
//System.err.println("outputTypes after: " + ((ArchitectureElementSymbol) astElement).getOutputTypes());
} catch (Exception e) {
IOSymbol sublayer = (IOSymbol) item.getSymbol();
//System.err.println("The following names could not be resolved2: " + Joiners.COMMA.join(sublayer.getUnresolvableVariables()));
e.printStackTrace();
}
}
try {
// TODO: why is body null in EMADL2CPP?
System.err.println("BODY: " + layer.getBody().getElements().toString());
layer.getBody().resolveOrError();
//layer.getDeclaration().getBody().resolveOrError();
} catch (ArchResolveException e) {
e.printStackTrace();
}
System.err.println("output: " + output);
return output;
}
......
......@@ -2,11 +2,10 @@ architecture RNNencdec(max_length=50, vocabulary_size=30000, hidden_size=1000){
def input Q(0:1)^{vocabulary_size} source
def output Q(0:1)^{vocabulary_size} target
unroll BeamSearchStart(max_length=max_length) {
timed<t> BeamSearchStart(max_length=max_length) {
source ->
FullyConnected(units=17) ->
Softmax() ->
FullyConnected(units=vocabulary_size) ->
target
};
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment