Commit 963eae5f authored by Sebastian Nickels's avatar Sebastian Nickels
Browse files

Cleaned up UnrollSymbol, added ArgMax layer

parent 13ca88a6
......@@ -42,7 +42,7 @@ public class CheckArgument implements CNNArchASTArchArgumentCoCo {
}
}else if(argument.getEnclosingScope().getSpanningSymbol().get() instanceof UnrollSymbol){
UnrollDeclarationSymbol layerDeclaration = argument.getUnroll().getDeclaration();
if (layerDeclaration != null && argument.getUnrollParameter() == null){
if (layerDeclaration != null && argument.getParameter() == null){
Log.error("0"+ ErrorCodes.UNKNOWN_ARGUMENT + " Unknown Argument. " +
"Parameter with name '" + node.getName() + "' does not exist. " +
"Possible arguments are: " + Joiners.COMMA.join(layerDeclaration.getParameters())
......
......@@ -252,6 +252,15 @@ public class ArchTypeSymbol extends CommonSymbol {
domain.setRange(range);
return this;
}
public Builder elementType(String name, String start, String end){
domain = new ASTElementType();
domain.setName(name); //("Q(" + start + ":" + end +")");
ASTRange range = new ASTRange();
range.setStartValue(start);
range.setEndValue(end);
domain.setRange(range);
return this;
}
public ArchTypeSymbol build(){
ArchTypeSymbol sym = new ArchTypeSymbol();
......
......@@ -123,15 +123,13 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
}
for (UnrollSymbol unroll : unrolls) {
if(unroll.isResolvable());
{
try {
unroll.resolve();
unroll = unroll.createUnrollForBackend();
}
catch (ArchResolveException e) {
// Do nothing; error is already logged
}
unroll.checkIfResolvable();
try {
unroll.resolveOrError();
}
catch (ArchResolveException e) {
// Do nothing; error is already logged
}
}
}
......@@ -199,19 +197,14 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
}
}
for (UnrollDeclarationSymbol unrollDeclaration : getSpannedScope().<UnrollDeclarationSymbol>resolveLocally(UnrollDeclarationSymbol.KIND)){
if (!unrollDeclaration.isPredefined()) {
copy.getSpannedScope().getAsMutableScope().add(unrollDeclaration.deepCopy());
}
}
List<LayerVariableDeclarationSymbol> copyLayerParameterDeclarations = new ArrayList<>();
for (LayerVariableDeclarationSymbol layerParameterDeclaration : getLayerVariableDeclarations()) {
LayerVariableDeclarationSymbol copyLayerParameterDeclaration =
(LayerVariableDeclarationSymbol) layerParameterDeclaration.preResolveDeepCopy();
copyLayerParameterDeclaration.putInScope(copy.getSpannedScope());
copyLayerParameterDeclarations.add(copyLayerParameterDeclaration);
List<LayerVariableDeclarationSymbol> copyLayerVariableDeclarations = new ArrayList<>();
for (LayerVariableDeclarationSymbol layerVariableDeclaration : getLayerVariableDeclarations()) {
LayerVariableDeclarationSymbol copyLayerVariableDeclaration =
(LayerVariableDeclarationSymbol) layerVariableDeclaration.preResolveDeepCopy();
copyLayerVariableDeclaration.putInScope(copy.getSpannedScope());
copyLayerVariableDeclarations.add(copyLayerVariableDeclaration);
}
copy.setLayerVariableDeclarations(copyLayerParameterDeclarations);
copy.setLayerVariableDeclarations(copyLayerVariableDeclarations);
List<SerialCompositeElementSymbol> copyStreams = new ArrayList<>();
for (SerialCompositeElementSymbol stream : getStreams()) {
......@@ -223,7 +216,7 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
List<UnrollSymbol> copyUnrolls = new ArrayList<>();
for (UnrollSymbol unroll : getUnrolls()) {
UnrollSymbol copyUnroll = unroll.preResolveDeepCopy();
UnrollSymbol copyUnroll = (UnrollSymbol) unroll.preResolveDeepCopy();
copyUnroll.putInScope(copy.getSpannedScope());
copyUnrolls.add(copyUnroll);
}
......
......@@ -43,25 +43,28 @@ public class ArgumentSymbol extends CommonSymbol {
public ParameterSymbol getParameter() {
if (parameter == null){
if (getLayer().getDeclaration() != null){
Optional<ParameterSymbol> optParam = getLayer().getDeclaration().getParameter(getName());
optParam.ifPresent(this::setParameter);
Symbol spanningSymbol = getEnclosingScope().getSpanningSymbol().get();
if (spanningSymbol instanceof UnrollSymbol) {
UnrollSymbol unroll = (UnrollSymbol) getEnclosingScope().getSpanningSymbol().get();
if (unroll.getDeclaration() != null){
Optional<ParameterSymbol> optParam = unroll.getDeclaration().getParameter(getName());
optParam.ifPresent(this::setParameter);
}
}
}
return parameter;
}
else {
LayerSymbol layer = (LayerSymbol) getEnclosingScope().getSpanningSymbol().get();
public ParameterSymbol getUnrollParameter() {
if (parameter == null){
if (getUnroll().getDeclaration() != null){
Optional<ParameterSymbol> optParam = getUnroll().getDeclaration().getParameter(getName());
optParam.ifPresent(this::setParameter);
if (layer.getDeclaration() != null){
Optional<ParameterSymbol> optParam = layer.getDeclaration().getParameter(getName());
optParam.ifPresent(this::setParameter);
}
}
}
return parameter;
}
protected void setParameter(ParameterSymbol parameter) {
this.parameter = parameter;
}
......@@ -104,7 +107,6 @@ public class ArgumentSymbol extends CommonSymbol {
public void set(){
if (getRhs().isResolved() && getRhs().isSimpleValue()){
getParameter().setExpression((ArchSimpleExpressionSymbol) getRhs());
getUnrollParameter().setExpression((ArchSimpleExpressionSymbol) getRhs());
}
else {
throw new IllegalStateException("The value of the parameter is set to a sequence or the expression is not resolved. This should never happen.");
......@@ -119,14 +121,6 @@ public class ArgumentSymbol extends CommonSymbol {
}
}
public void resolveUnrollExpression() throws ArchResolveException {
getRhs().resolveOrError();
boolean valid = Constraints.checkUnroll(this);
if (!valid){
throw new ArchResolveException();
}
}
public void checkConstraints(){
Constraints.check(this);
}
......
......@@ -209,16 +209,6 @@ public enum Constraints {
return valid;
}
public static boolean checkUnroll(ArgumentSymbol argument){
boolean valid = true;
ParameterSymbol parameter = argument.getUnrollParameter();
for (Constraints constraint : parameter.getConstraints()) {
valid = valid &&
constraint.check(argument.getRhs(), argument.getSourcePosition(), parameter.getName());
}
return valid;
}
public boolean check(ArchExpressionSymbol exp, SourcePosition sourcePosition, String name){
if (exp instanceof ArchRangeExpressionSymbol){
ArchRangeExpressionSymbol range = (ArchRangeExpressionSymbol)exp;
......
......@@ -252,9 +252,6 @@ public class LayerSymbol extends ArchitectureElementSymbol {
@Override
public List<ArchTypeSymbol> computeOutputTypes() {
if (getResolvedThis().isPresent()) {
if (getResolvedThis().get() == this) {
return ((PredefinedLayerDeclaration) getDeclaration()).computeOutputTypes(getInputTypes(), this, VariableSymbol.Member.NONE);
......
......@@ -139,7 +139,7 @@ public class ParallelCompositeElementSymbol extends CompositeElementSymbol {
List<ArchitectureElementSymbol> elements = new ArrayList<>(getElements().size());
for (ArchitectureElementSymbol element : getElements()){
ArchitectureElementSymbol elementCopy = element.preResolveDeepCopy();
ArchitectureElementSymbol elementCopy = (ArchitectureElementSymbol) element.preResolveDeepCopy();
elements.add(elementCopy);
}
copy.setElements(elements);
......
......@@ -123,5 +123,5 @@ public abstract class ResolvableSymbol extends CommonScopeSpanningSymbol {
* Creates a deep copy in the state before the architecture resolution.
* @return returns a deep copy of this object in the pre-resolve version.
*/
protected abstract ArchitectureElementSymbol preResolveDeepCopy();
protected abstract ResolvableSymbol preResolveDeepCopy();
}
......@@ -124,7 +124,7 @@ public class SerialCompositeElementSymbol extends CompositeElementSymbol {
List<ArchitectureElementSymbol> elements = new ArrayList<>(getElements().size());
for (ArchitectureElementSymbol element : getElements()){
ArchitectureElementSymbol elementCopy = element.preResolveDeepCopy();
ArchitectureElementSymbol elementCopy = (ArchitectureElementSymbol) element.preResolveDeepCopy();
elements.add(elementCopy);
}
copy.setElements(elements);
......
......@@ -30,69 +30,55 @@ import de.monticore.symboltable.Symbol;
import java.util.*;
import java.util.function.Function;
public class UnrollSymbol extends CommonScopeSpanningSymbol {
public class UnrollSymbol extends ResolvableSymbol {
public static final UnrollKind KIND = new UnrollKind();
private UnrollDeclarationSymbol declaration = null;
private List<ArgumentSymbol> arguments;
private ParameterSymbol timeParameter;
private UnrollSymbol resolvedThis = null;
private SerialCompositeElementSymbol body;
private ArrayList<SerialCompositeElementSymbol> bodies = new ArrayList<>();
private boolean isExtendedForBackend = false;
public SerialCompositeElementSymbol getBody() {
return body;
}
protected void setBody(SerialCompositeElementSymbol body) {
this.body = body;
}
private List<SerialCompositeElementSymbol> bodies = new ArrayList<>();
public ArrayList<SerialCompositeElementSymbol> getBodiesForAllTimesteps() {
return bodies;
protected UnrollSymbol(String name) {
super(name, KIND);
}
protected void setBodiesForAllTimesteps(ArrayList<SerialCompositeElementSymbol> bodies) {
this.bodies = bodies;
}
public UnrollDeclarationSymbol getDeclaration() {
if (declaration == null) {
Collection<UnrollDeclarationSymbol> collection = getEnclosingScope().resolveMany(getName(), UnrollDeclarationSymbol.KIND);
public boolean isExtended(){
return this.isExtendedForBackend;
if (!collection.isEmpty()) {
declaration = collection.iterator().next();
}
else {
throw new IllegalStateException("No unroll declaration found");
}
}
return declaration;
}
private void setExtended(boolean extended){
this.isExtendedForBackend = extended;
public List<SerialCompositeElementSymbol> getBodiesForAllTimesteps() {
return bodies;
}
public boolean isTrainable() {
return body.isTrainable();
public SerialCompositeElementSymbol getBody() {
return body;
}
protected UnrollSymbol(String name) {
super(name, KIND);
protected void setBody(SerialCompositeElementSymbol body) {
this.body = body;
}
public UnrollDeclarationSymbol getDeclaration() {
if (declaration == null){
Collection<UnrollDeclarationSymbol> declarationCollection = getEnclosingScope().resolveMany(getName(), UnrollDeclarationSymbol.KIND);
if (!declarationCollection.isEmpty()){
setDeclaration(declarationCollection.iterator().next());
}
}
return declaration;
public boolean isTrainable() {
return getBody().isTrainable();
}
@Override
public boolean isResolvable() {
return getBody().isResolvable() && getDeclaration() != null;
}
private void setDeclaration(UnrollDeclarationSymbol declaration) {
this.declaration = declaration;
}
public List<ArgumentSymbol> getArguments() {
return arguments;
}
......@@ -109,16 +95,6 @@ public class UnrollSymbol extends CommonScopeSpanningSymbol {
this.timeParameter = timeParameter;
}
public ArchExpressionSymbol getIfExpression(){
Optional<ArgumentSymbol> argument = getArgument(AllPredefinedVariables.CONDITIONAL_ARG_NAME);
if (argument.isPresent()){
return argument.get().getRhs();
}
else {
return ArchSimpleExpressionSymbol.of(true);
}
}
protected void putInScope(Scope scope){
Collection<Symbol> symbolsInScope = scope.getLocalSymbols().get(getName());
if (symbolsInScope == null || !symbolsInScope.contains(this)){
......@@ -132,64 +108,65 @@ public class UnrollSymbol extends CommonScopeSpanningSymbol {
}
}
protected void setResolvedThis(UnrollSymbol resolvedThis) {
if (resolvedThis != null){
//resolvedThis.putInScope(getSpannedScope());
}
this.resolvedThis = resolvedThis;
}
public Set<ParameterSymbol> resolve() throws ArchResolveException {
if (true) {
if (!isResolved()) {
if (isResolvable()) {
getDeclaration();
resolveExpressions();
//resolve the unroll call
getBody().resolveOrError();
for (int timestep = this.getIntValue(AllPredefinedLayers.T_NAME).get(); timestep < this.getIntValue(AllPredefinedLayers.MAX_LENGTH_NAME).get(); timestep++) {
SerialCompositeElementSymbol newBody = new SerialCompositeElementSymbol();
List<ArchitectureElementSymbol> newBodyList = new ArrayList<>();
SerialCompositeElementSymbol body = getBody().preResolveDeepCopy();
body.putInScope(getBody().getSpannedScope());
for (ArchitectureElementSymbol element : body.getElements()) {
if (element.getEnclosingScope() == null) {
element.setEnclosingScope(getEnclosingScope().getAsMutableScope());
}
((ParameterSymbol)((ArrayList<Symbol>)getSpannedScope().getLocalSymbols().get(AllPredefinedLayers.T_NAME)).get(0)).getExpression().setValue(timestep);
try {
this.resolveExpressions();
for (ParameterSymbol p:declaration.getParameters()) {
if (p.getEnclosingScope() == null) {
p.putInScope(getSpannedScope());
}
}
element.resolve();
}
catch (ArchResolveException e) {
e.printStackTrace();
}
newBodyList.add(element);
}
newBody.putInScope(this.getBody().getSpannedScope());
newBody.setElements(newBodyList);
bodies.add(newBody);
}
UnrollSymbol resolvedUnroll = getDeclaration().call(this);
setResolvedThis(resolvedUnroll);
}
}
return new HashSet<ParameterSymbol>() ;
}
private boolean isActive(){
if (getIfExpression().isSimpleValue() && !getIfExpression().getBooleanValue().get()){
return false;
}
else {
return true;
}
return getUnresolvableParameters();
}
protected void resolveExpressions() throws ArchResolveException{
for (ArgumentSymbol argument : getArguments()){
argument.resolveUnrollExpression();
argument.resolveExpression();
}
}
private ArchitectureElementSymbol createSerialSequencePart(List<ArchitectureElementSymbol> elements){
if (elements.size() == 1){
return elements.get(0);
}
else {
SerialCompositeElementSymbol serialComposite = new SerialCompositeElementSymbol();
serialComposite.setElements(elements);
if (getAstNode().isPresent()){
serialComposite.setAstNode(getAstNode().get());
}
return serialComposite;
}
}
protected void computeUnresolvableParameters(Set<ParameterSymbol> unresolvableVariables, Set<ParameterSymbol> allVariables) {
for (ArgumentSymbol argument : getArguments()){
argument.getRhs().checkIfResolvable(allVariables);
......@@ -288,59 +265,8 @@ public class UnrollSymbol extends CommonScopeSpanningSymbol {
}
}
// creates a body for each timestep of the unroll
public UnrollSymbol createUnrollForBackend(){
if(this.isExtendedForBackend){
return this;
}else {
int timestep;
SerialCompositeElementSymbol newBody;
List<ArchitectureElementSymbol> newBodyList;
for (timestep = this.getIntValue(AllPredefinedLayers.T_NAME).get(); timestep < this.getIntValue(AllPredefinedLayers.MAX_LENGTH_NAME).get(); timestep++) {
newBody = new SerialCompositeElementSymbol();
newBodyList = new ArrayList<>();
SerialCompositeElementSymbol body = this.getBody().preResolveDeepCopy();
body.putInScope(this.getBody().getSpannedScope());
for (ArchitectureElementSymbol element : body.getElements()) {
if(element.getEnclosingScope() == null) {
element.setEnclosingScope(getEnclosingScope().getAsMutableScope());
}
((ParameterSymbol)((ArrayList<Symbol>)getSpannedScope().getLocalSymbols().get(AllPredefinedLayers.T_NAME)).get(0)).getExpression().setValue(timestep);
try {
this.resolveExpressions();
for(ParameterSymbol p:declaration.getParameters()){
if(p.getEnclosingScope() == null) {
p.putInScope(getSpannedScope());
}
}
element.resolve();
} catch (ArchResolveException e) {e.printStackTrace();}
newBodyList.add(element);
}
newBody.putInScope(this.getBody().getSpannedScope());
newBody.setElements(newBodyList);
bodies.add(newBody);
}
newBody = new SerialCompositeElementSymbol();
ArrayList elementsList = new ArrayList();
elementsList.addAll(bodies);
this.setBody(newBody);
this.setExtended(true);
return this;
}
}
protected UnrollSymbol preResolveDeepCopy() {
protected ResolvableSymbol preResolveDeepCopy() {
UnrollSymbol copy = new UnrollSymbol(getName());
if (getAstNode().isPresent()){
copy.setAstNode(getAstNode().get());
......@@ -350,16 +276,6 @@ public class UnrollSymbol extends CommonScopeSpanningSymbol {
for (ArgumentSymbol argument : getArguments()){
args.add(argument.preResolveDeepCopy());
}
copy.setDeclaration(getDeclaration());
List<ParameterSymbol> parameterCopies = new ArrayList<>(getDeclaration().getParameters().size());
for (ParameterSymbol parameter : getDeclaration().getParameters()){
ParameterSymbol parameterCopy = parameter.deepCopy();
parameterCopies.add(parameterCopy);
parameterCopy.putInScope(copy.getSpannedScope());
}
copy.getDeclaration().setParameters(parameterCopies);
copy.setArguments(args);
copy.setBody(getBody().preResolveDeepCopy());
......@@ -367,34 +283,4 @@ public class UnrollSymbol extends CommonScopeSpanningSymbol {
return copy;
}
public static class Builder{
private UnrollDeclarationSymbol declaration;
private List<ArgumentSymbol> arguments = new ArrayList<>();
private boolean isResolved = false;
public Builder declaration(UnrollDeclarationSymbol declaration){
this.declaration = declaration;
return this;
}
public Builder arguments(List<ArgumentSymbol> arguments){
this.arguments = arguments;
return this;
}
public Builder arguments(ArgumentSymbol... arguments){
this.arguments = Arrays.asList(arguments);
return this;
}
public Builder isResolved(boolean isResolved){
this.isResolved = isResolved;
return this;
}
}
}
......@@ -52,6 +52,7 @@ public class AllPredefinedLayers {
public static final String LSTM_NAME = "LSTM";
public static final String GRU_NAME = "GRU";
public static final String EMBEDDING_NAME = "Embedding";
public static final String ARG_MAX_NAME = "ArgMax";
//predefined argument names
public static final String KERNEL_NAME = "kernel";
......@@ -110,7 +111,8 @@ public class AllPredefinedLayers {
RNN.create(),
LSTM.create(),
GRU.create(),
Embedding.create());
Embedding.create(),
ArgMax.create());
}
public static List<UnrollDeclarationSymbol> createUnrollList(){
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*