Commit 31aa225b authored by Christian Fuß
Browse files

Create an UnrollBody for each timestep instead of a single large body.

parent e333f3fd
Pipeline #177574 passed with stages
in 22 minutes and 24 seconds
......@@ -21,15 +21,11 @@
package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.lang.monticar.cnnarch._ast.ASTArchitectureParameter;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers;
import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedVariables;
import de.monticore.symboltable.CommonScopeSpanningSymbol;
import de.monticore.symboltable.Scope;
import de.monticore.symboltable.Symbol;
import de.monticore.symboltable.SymbolKind;
import de.se_rwth.commons.logging.Log;
import java.util.*;
import java.util.function.Function;
......@@ -43,6 +39,7 @@ public class UnrollSymbol extends CommonScopeSpanningSymbol {
// Symbol of the time parameter of this unroll.
private ParameterSymbol timeParameter;
// Presumably the resolved counterpart of this symbol; starts out as null. TODO confirm against the resolve logic.
private UnrollSymbol resolvedThis = null;
// The body of the unroll as a serial composite of architecture elements.
private SerialCompositeElementSymbol body;
// Bodies for the individual timesteps; exposed via getBodiesForAllTimesteps().
private ArrayList<SerialCompositeElementSymbol> bodies = new ArrayList<>();
// Set once createUnrollForBackend() has expanded this unroll, so that a second call returns immediately.
private boolean isExtendedForBackend = false;
public SerialCompositeElementSymbol getBody() {
......@@ -53,6 +50,14 @@ public class UnrollSymbol extends CommonScopeSpanningSymbol {
this.body = body;
}
/**
 * Returns the bodies that were created for the individual timesteps.
 *
 * @return the internal list of per-timestep bodies (not a copy)
 */
public ArrayList<SerialCompositeElementSymbol> getBodiesForAllTimesteps() {
    return this.bodies;
}
/**
 * Replaces the list of per-timestep bodies.
 *
 * @param bodies the new list; it is stored by reference, not copied
 */
protected void setBodiesForAllTimesteps(ArrayList<SerialCompositeElementSymbol> bodies) {
this.bodies = bodies;
}
/**
 * Tells whether this unroll has already been expanded for the backend.
 *
 * @return the current value of the {@code isExtendedForBackend} flag
 */
public boolean isExtended() {
    return isExtendedForBackend;
}
......@@ -61,55 +66,6 @@ public class UnrollSymbol extends CommonScopeSpanningSymbol {
this.isExtendedForBackend = extended;
}
/**
 * Expands this unroll for the backend into one single, flat body.
 *
 * For every timestep a pre-resolve deep copy of the current body is made, its
 * elements are resolved with the time parameter set to that timestep, and the
 * elements are appended to one common list which becomes the new body.
 * If the unroll has already been extended, nothing is done.
 *
 * @return this symbol
 */
public UnrollSymbol createUnrollForBackend(){
if(this.isExtendedForBackend){
return this;
}else {
int i;
boolean skipFirst;
List<ArchitectureElementSymbol> newBodyList = new ArrayList<>();
// Iterate from the value of BEAMSEARCH_T_NAME up to (excluding) the value of BEAMSEARCH_MAX_LENGTH.
// Both values are unwrapped with get() without checking that they are present.
for (i = this.getIntValue(AllPredefinedLayers.BEAMSEARCH_T_NAME).get(); i < this.getIntValue(AllPredefinedLayers.BEAMSEARCH_MAX_LENGTH).get(); i++) {
skipFirst = true;
// Work on a fresh, unresolved copy of the body for this timestep.
SerialCompositeElementSymbol body = this.getBody().preResolveDeepCopy();
body.putInScope(this.getBody().getSpannedScope());
for (ArchitectureElementSymbol element : body.getElements()) {
// For i > 0 the first element of the copy is skipped: it is neither resolved nor added.
if (!(i > 0 && skipFirst)) {
if(element.getEnclosingScope() == null) {
element.setEnclosingScope(getEnclosingScope().getAsMutableScope());
}
// Write the current timestep into the expression of the time parameter found in the spanned scope.
((ParameterSymbol)((ArrayList<Symbol>)getSpannedScope().getLocalSymbols().get(AllPredefinedLayers.BEAMSEARCH_T_NAME)).get(0)).getExpression().setValue(i);
try {
this.resolveExpressions();
} catch (ArchResolveException e) {
// Failure is only printed; processing continues with this element.
e.printStackTrace();
}
// Make sure every declaration parameter lives in a scope before the element is resolved.
for(ParameterSymbol p:declaration.getParameters()){
if(p.getEnclosingScope() == null) {
p.putInScope(getSpannedScope());
}
}
try {
element.resolve();
} catch (ArchResolveException e) {
// Failure is only printed; the element is still added below.
e.printStackTrace();
}
newBodyList.add(element);
}
skipFirst=false;
}
}
// Replace the body by a single composite holding the elements of all timesteps.
SerialCompositeElementSymbol newBody = new SerialCompositeElementSymbol();
newBody.putInScope(this.getBody().getSpannedScope());
newBody.setElements(newBodyList);
this.setBody(newBody);
this.setExtended(true);
// The architecture's unroll list is replaced by a list that contains only this unroll.
List<UnrollSymbol> newUnrolls = new ArrayList<>();
newUnrolls.add(this);
getBody().getArchitecture().setUnrolls(newUnrolls);
return this;
}
}
/**
 * Reports whether this unroll is trainable; the answer is taken from its body.
 *
 * @return the result of {@code isTrainable()} on the body
 */
public boolean isTrainable() {
    final SerialCompositeElementSymbol currentBody = this.body;
    return currentBody.isTrainable();
}
......@@ -332,7 +288,61 @@ public class UnrollSymbol extends CommonScopeSpanningSymbol {
}
}
/**
 * Creates a body for each timestep of the unroll.
 *
 * For every timestep from the value of {@code BEAMSEARCH_T_NAME} up to (excluding)
 * the value of {@code BEAMSEARCH_MAX_LENGTH}, a pre-resolve deep copy of the current
 * body is made, its elements are resolved with the time parameter set to that
 * timestep, and the resulting body is appended to the list returned by
 * {@link #getBodiesForAllTimesteps()}. Afterwards the body of this unroll is replaced
 * by a new, empty {@code SerialCompositeElementSymbol}, the unroll is marked as
 * extended and the architecture's unroll list is replaced by a list that contains
 * only this unroll.
 *
 * If the unroll has already been extended, nothing is done.
 *
 * @return this symbol
 */
public UnrollSymbol createUnrollForBackend() {
    if (this.isExtendedForBackend) {
        return this;
    }

    // Fetched before the body is replaced further down.
    ArchitectureSymbol architecture = getBody().getArchitecture();

    for (int timestep = this.getIntValue(AllPredefinedLayers.BEAMSEARCH_T_NAME).get();
            timestep < this.getIntValue(AllPredefinedLayers.BEAMSEARCH_MAX_LENGTH).get();
            timestep++) {
        SerialCompositeElementSymbol timestepBody = new SerialCompositeElementSymbol();
        List<ArchitectureElementSymbol> timestepElements = new ArrayList<>();

        // Work on a fresh, unresolved copy of the body for this timestep.
        SerialCompositeElementSymbol bodyCopy = this.getBody().preResolveDeepCopy();
        bodyCopy.putInScope(this.getBody().getSpannedScope());

        for (ArchitectureElementSymbol element : bodyCopy.getElements()) {
            if (element.getEnclosingScope() == null) {
                element.setEnclosingScope(getEnclosingScope().getAsMutableScope());
            }
            // Write the current timestep into the expression of the time parameter in the spanned scope.
            ((ParameterSymbol) ((ArrayList<Symbol>) getSpannedScope().getLocalSymbols().get(AllPredefinedLayers.BEAMSEARCH_T_NAME)).get(0)).getExpression().setValue(timestep);
            try {
                this.resolveExpressions();
                // Make sure every declaration parameter lives in a scope before the element is resolved.
                for (ParameterSymbol p : declaration.getParameters()) {
                    if (p.getEnclosingScope() == null) {
                        p.putInScope(getSpannedScope());
                    }
                }
                element.resolve();
            } catch (ArchResolveException e) {
                // Best effort: the failure is printed and the element is still added.
                e.printStackTrace();
            }
            timestepElements.add(element);
        }

        timestepBody.putInScope(this.getBody().getSpannedScope());
        timestepBody.setElements(timestepElements);
        bodies.add(timestepBody);
    }

    // NOTE(review): the replacement body is left without elements and is not put into a scope;
    // the per-timestep bodies are only reachable through getBodiesForAllTimesteps().
    // An unused raw-typed list that merely collected `bodies` was removed here - confirm that
    // the aggregate body is not meant to receive those bodies as its elements.
    this.setBody(new SerialCompositeElementSymbol());
    this.setExtended(true);

    List<UnrollSymbol> newUnrolls = new ArrayList<>();
    newUnrolls.add(this);
    architecture.setUnrolls(newUnrolls);
    return this;
}
protected UnrollSymbol preResolveDeepCopy() {
......@@ -358,10 +368,8 @@ public class UnrollSymbol extends CommonScopeSpanningSymbol {
copy.setArguments(args);
copy.setTimeParameter(getTimeParameter().deepCopy());
//System.err.println("t2: " + this.getIntValue(AllPredefinedLayers.BEAMSEARCH_T_NAME));
// TODO: currently only the defaultValue for t is used, make it use the assigned value instead
((ParameterSymbol)((ArrayList<Symbol>)getSpannedScope().getLocalSymbols().get("t")).get(0)).getExpression().setValue(this.getIntValue(AllPredefinedLayers.BEAMSEARCH_T_NAME));
System.err.println("t_value: " + ((ParameterSymbol)((ArrayList<Symbol>)getSpannedScope().getLocalSymbols().get("t")).get(0)).getValue().get().toString());
copy.getTimeParameter().putInScope(getSpannedScope());
......
......@@ -20,13 +20,10 @@
*/
package de.monticore.lang.monticar.cnnarch.predefined;
import de.monticore.lang.monticar.cnnarch._ast.*;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.se_rwth.commons.Joiners;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
public class BeamSearchStart extends PredefinedUnrollDeclaration {
......@@ -35,66 +32,14 @@ public class BeamSearchStart extends PredefinedUnrollDeclaration {
super(AllPredefinedLayers.BEAMSEARCH_NAME);
}
List<LayerSymbol> layers = new ArrayList<>(Arrays.asList());
/**
 * Computes the output types of a BeamSearchStart unroll.
 *
 * The result is always a new, empty, mutable list. The previous body walked the
 * elements of the unroll, resolved them and printed debug information to
 * {@code System.err}, but its result list was only ever (re)initialised to an empty
 * list, and it was followed by a second, unreachable {@code return} statement.
 *
 * @param inputTypes the input types of the unroll (not used)
 * @param layer      the unroll symbol (not used)
 * @param member     the member of the variable (not used)
 * @return a new empty list
 */
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, UnrollSymbol layer, VariableSymbol.Member member) {
    return new ArrayList<ArchTypeSymbol>();
}
/**
 * Checks the inputs of a BeamSearchStart unroll by delegating to
 * errorIfInputSizeIsNotOne (presumably reports an error unless exactly one
 * input type is given - confirm in the declaring class).
 *
 * @param inputTypes the input types to check
 * @param layer      the unroll symbol the check is performed for
 * @param member     not used by this check
 */
@Override
public void checkInput(List<ArchTypeSymbol> inputTypes, UnrollSymbol layer, VariableSymbol.Member member) {
errorIfInputSizeIsNotOne(inputTypes, layer);
}
public static BeamSearchStart create(){
......
......@@ -5,7 +5,7 @@ architecture RNNencdec(max_length=50, vocabulary_size=30000, hidden_size=1000){
source -> Softmax() -> target[0];
timed <t=0> BeamSearchStart(max_length=4){
target[t] ->
source ->
FullyConnected(units=30000) ->
Softmax() ->
target[t+1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment