Commit b96760c0 authored by Christian Fuß's avatar Christian Fuß
Browse files

allowed Unrolls to contain input & output elements

parent 95e4b087
Pipeline #158368 passed with stages
in 17 minutes and 10 seconds
......@@ -73,7 +73,7 @@ grammar CNNArch extends de.monticore.CommonExpressions, de.monticore.lang.Math,
Unroll implements ArchitectureElement = "unroll" "<" timeParameter:LayerParameter ">"
Name "(" arguments:(ArchArgument || ",")* ")"
"{" input:IOElement? body:Stream output:IOElement? "}";
"{" body:Stream "}";
ParallelBlock implements ArchitectureElement = "("
groups:Stream "|"
......
......@@ -361,6 +361,20 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
}
layer.setArguments(arguments);
List<ArchitectureElementSymbol> elements = new ArrayList<>();
int elementNumber = 0;
for (ASTArchitectureElement astElement : ast.getBody().getElementsList()){
elements.add((ArchitectureElementSymbol) astElement.getSymbolOpt().get());
if(elementNumber == 0){
layer.getDeclaration().getBody().setInputElement((ArchitectureElementSymbol) astElement.getSymbolOpt().get());
} else if(elementNumber == (ast.getBody().getElementsList().size() - 1)){
layer.getDeclaration().getBody().setOutputElement((ArchitectureElementSymbol) astElement.getSymbolOpt().get());
}
elementNumber++;
}
layer.getDeclaration().getBody().setElements(elements);
/*List<ArchitectureElementSymbol> elements = new ArrayList<>();
for (ASTStream astStream : ast.getGroupsList()){
elements.add((SerialCompositeElementSymbol) astStream.getSymbolOpt().get());
......
......@@ -50,6 +50,9 @@ public abstract class CompositeElementSymbol extends ArchitectureElementSymbol {
else if (element instanceof LayerSymbol) {
isNetwork |= ((LayerSymbol) element).getDeclaration().isNetworkLayer();
}
else if (element instanceof UnrollSymbol) {
isNetwork |= ((UnrollSymbol) element).getDeclaration().isNetworkLayer();
}
}
return isNetwork;
......
......@@ -152,7 +152,6 @@ public class LayerSymbol extends ArchitectureElementSymbol {
}
else {
//split the layer if it contains an argument sequence
System.err.println("Resolve() called in " + this.getName());
ArchitectureElementSymbol splitComposite = resolveSequences(parallelLength, getSerialLengths().get());
setResolvedThis(splitComposite);
splitComposite.resolveOrError();
......
......@@ -76,7 +76,6 @@ public class UnrollDeclarationSymbol extends CommonScopeSpanningSymbol {
}
public SerialCompositeElementSymbol getBody() {
System.err.println("Body_elements in UNROLLDECLARATIONSYMBOL: " + body.getElements().toString());
return body;
}
......@@ -89,6 +88,10 @@ public class UnrollDeclarationSymbol extends CommonScopeSpanningSymbol {
return false;
}
public boolean isNetworkLayer() {
return body.isNetwork();
}
public Optional<VariableSymbol> getParameter(String name) {
Optional<VariableSymbol> res = Optional.empty();
for (VariableSymbol parameter : getParameters()){
......
......@@ -21,6 +21,7 @@
package de.monticore.lang.monticar.cnnarch.predefined;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.se_rwth.commons.Joiners;
import java.util.ArrayList;
import java.util.Arrays;
......@@ -37,30 +38,15 @@ public class BeamSearchStart extends PredefinedUnrollDeclaration {
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, UnrollSymbol layer) {
try {
System.err.println("allElements: " + layer.getDeclaration().getBody().getElements().toString());
List<ArchitectureElementSymbol> elements = new ArrayList<ArchitectureElementSymbol>();
elements = layer.getDeclaration().getBody().getElements();
System.err.println("LastElement: " + elements.get(elements.size()-1));
//System.err.println("LastElement_Channels: " + elements.get(elements.size()-1).getOutputTypes().get(0).getChannels());
for(ArchitectureElementSymbol item:elements){
System.err.println("Resolved?1: " + item.isResolved());
for(ArchitectureElementSymbol item:layer.getDeclaration().getBody().getElements()){
try {
item.resolve();
System.err.println("Resolved?2: " + item.isResolved());
//System.err.println("name2" + item.getOutputElement().get().toString());
//System.err.println("channels: " + item.getOutputTypes().get(0).getChannels().toString());
System.err.println("name3" + item.getName());
} catch (ArchResolveException e) {
System.err.println("The following names could not be resolved: " + Joiners.COMMA.join(item.getUnresolvableVariables()));
}
}catch(Exception e){
e.printStackTrace();
}
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(100) // TODO
.height(1)
.width(1)
.elementType("0", "1")
.build());
return layer.getDeclaration().getBody().computeOutputTypes();
}
@Override
......
......@@ -2,15 +2,10 @@ architecture RNNencdec(max_length=50, vocabulary_size=30000, hidden_size=1000){
def input Q(0:1)^{vocabulary_size} source
def output Q(0:1)^{vocabulary_size} target
source ->
FullyConnected(units=vocabulary_size) ->
Softmax() ->
target;
unroll<t> BeamSearchStart(max_length=max_length) {
source ->
FullyConnected(units=vocabulary_size) ->
Softmax()
} ->
target;
Softmax() ->
target
};
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment