Commit b96760c0 authored by Christian Fuß's avatar Christian Fuß
Browse files

allowed Unrolls to contain input & output elements

parent 95e4b087
Pipeline #158368 passed with stages
in 17 minutes and 10 seconds
...@@ -73,7 +73,7 @@ grammar CNNArch extends de.monticore.CommonExpressions, de.monticore.lang.Math, ...@@ -73,7 +73,7 @@ grammar CNNArch extends de.monticore.CommonExpressions, de.monticore.lang.Math,
Unroll implements ArchitectureElement = "unroll" "<" timeParameter:LayerParameter ">" Unroll implements ArchitectureElement = "unroll" "<" timeParameter:LayerParameter ">"
Name "(" arguments:(ArchArgument || ",")* ")" Name "(" arguments:(ArchArgument || ",")* ")"
"{" input:IOElement? body:Stream output:IOElement? "}"; "{" body:Stream "}";
ParallelBlock implements ArchitectureElement = "(" ParallelBlock implements ArchitectureElement = "("
groups:Stream "|" groups:Stream "|"
......
...@@ -361,6 +361,20 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy ...@@ -361,6 +361,20 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
} }
layer.setArguments(arguments); layer.setArguments(arguments);
List<ArchitectureElementSymbol> elements = new ArrayList<>();
int elementNumber = 0;
for (ASTArchitectureElement astElement : ast.getBody().getElementsList()){
elements.add((ArchitectureElementSymbol) astElement.getSymbolOpt().get());
if(elementNumber == 0){
layer.getDeclaration().getBody().setInputElement((ArchitectureElementSymbol) astElement.getSymbolOpt().get());
} else if(elementNumber == (ast.getBody().getElementsList().size() - 1)){
layer.getDeclaration().getBody().setOutputElement((ArchitectureElementSymbol) astElement.getSymbolOpt().get());
}
elementNumber++;
}
layer.getDeclaration().getBody().setElements(elements);
/*List<ArchitectureElementSymbol> elements = new ArrayList<>(); /*List<ArchitectureElementSymbol> elements = new ArrayList<>();
for (ASTStream astStream : ast.getGroupsList()){ for (ASTStream astStream : ast.getGroupsList()){
elements.add((SerialCompositeElementSymbol) astStream.getSymbolOpt().get()); elements.add((SerialCompositeElementSymbol) astStream.getSymbolOpt().get());
......
...@@ -50,6 +50,9 @@ public abstract class CompositeElementSymbol extends ArchitectureElementSymbol { ...@@ -50,6 +50,9 @@ public abstract class CompositeElementSymbol extends ArchitectureElementSymbol {
else if (element instanceof LayerSymbol) { else if (element instanceof LayerSymbol) {
isNetwork |= ((LayerSymbol) element).getDeclaration().isNetworkLayer(); isNetwork |= ((LayerSymbol) element).getDeclaration().isNetworkLayer();
} }
else if (element instanceof UnrollSymbol) {
isNetwork |= ((UnrollSymbol) element).getDeclaration().isNetworkLayer();
}
} }
return isNetwork; return isNetwork;
......
...@@ -152,7 +152,6 @@ public class LayerSymbol extends ArchitectureElementSymbol { ...@@ -152,7 +152,6 @@ public class LayerSymbol extends ArchitectureElementSymbol {
} }
else { else {
//split the layer if it contains an argument sequence //split the layer if it contains an argument sequence
System.err.println("Resolve() called in " + this.getName());
ArchitectureElementSymbol splitComposite = resolveSequences(parallelLength, getSerialLengths().get()); ArchitectureElementSymbol splitComposite = resolveSequences(parallelLength, getSerialLengths().get());
setResolvedThis(splitComposite); setResolvedThis(splitComposite);
splitComposite.resolveOrError(); splitComposite.resolveOrError();
......
...@@ -76,7 +76,6 @@ public class UnrollDeclarationSymbol extends CommonScopeSpanningSymbol { ...@@ -76,7 +76,6 @@ public class UnrollDeclarationSymbol extends CommonScopeSpanningSymbol {
} }
public SerialCompositeElementSymbol getBody() { public SerialCompositeElementSymbol getBody() {
System.err.println("Body_elements in UNROLLDECLARATIONSYMBOL: " + body.getElements().toString());
return body; return body;
} }
...@@ -89,6 +88,10 @@ public class UnrollDeclarationSymbol extends CommonScopeSpanningSymbol { ...@@ -89,6 +88,10 @@ public class UnrollDeclarationSymbol extends CommonScopeSpanningSymbol {
return false; return false;
} }
public boolean isNetworkLayer() {
return body.isNetwork();
}
public Optional<VariableSymbol> getParameter(String name) { public Optional<VariableSymbol> getParameter(String name) {
Optional<VariableSymbol> res = Optional.empty(); Optional<VariableSymbol> res = Optional.empty();
for (VariableSymbol parameter : getParameters()){ for (VariableSymbol parameter : getParameters()){
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
package de.monticore.lang.monticar.cnnarch.predefined; package de.monticore.lang.monticar.cnnarch.predefined;
import de.monticore.lang.monticar.cnnarch._symboltable.*; import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.se_rwth.commons.Joiners;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
...@@ -37,30 +38,15 @@ public class BeamSearchStart extends PredefinedUnrollDeclaration { ...@@ -37,30 +38,15 @@ public class BeamSearchStart extends PredefinedUnrollDeclaration {
@Override @Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, UnrollSymbol layer) { public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, UnrollSymbol layer) {
try { for(ArchitectureElementSymbol item:layer.getDeclaration().getBody().getElements()){
System.err.println("allElements: " + layer.getDeclaration().getBody().getElements().toString()); try {
List<ArchitectureElementSymbol> elements = new ArrayList<ArchitectureElementSymbol>();
elements = layer.getDeclaration().getBody().getElements();
System.err.println("LastElement: " + elements.get(elements.size()-1));
//System.err.println("LastElement_Channels: " + elements.get(elements.size()-1).getOutputTypes().get(0).getChannels());
for(ArchitectureElementSymbol item:elements){
System.err.println("Resolved?1: " + item.isResolved());
item.resolve(); item.resolve();
System.err.println("Resolved?2: " + item.isResolved()); } catch (ArchResolveException e) {
//System.err.println("name2" + item.getOutputElement().get().toString()); System.err.println("The following names could not be resolved: " + Joiners.COMMA.join(item.getUnresolvableVariables()));
//System.err.println("channels: " + item.getOutputTypes().get(0).getChannels().toString());
System.err.println("name3" + item.getName());
} }
}catch(Exception e){
e.printStackTrace();
} }
return Collections.singletonList(new ArchTypeSymbol.Builder() return layer.getDeclaration().getBody().computeOutputTypes();
.channels(100) // TODO
.height(1)
.width(1)
.elementType("0", "1")
.build());
} }
@Override @Override
......
...@@ -2,15 +2,10 @@ architecture RNNencdec(max_length=50, vocabulary_size=30000, hidden_size=1000){ ...@@ -2,15 +2,10 @@ architecture RNNencdec(max_length=50, vocabulary_size=30000, hidden_size=1000){
def input Q(0:1)^{vocabulary_size} source def input Q(0:1)^{vocabulary_size} source
def output Q(0:1)^{vocabulary_size} target def output Q(0:1)^{vocabulary_size} target
source ->
FullyConnected(units=vocabulary_size) ->
Softmax() ->
target;
unroll<t> BeamSearchStart(max_length=max_length) { unroll<t> BeamSearchStart(max_length=max_length) {
source -> source ->
FullyConnected(units=vocabulary_size) -> FullyConnected(units=vocabulary_size) ->
Softmax() Softmax() ->
} -> target
target; };
} }
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment