Commit d9c07007 authored by Sebastian Nickels

Cleaned up code, renamed BeamSearchStart to BeamSearch

parent 6793aa59
Pipeline #180213 failed in 18 minutes and 31 seconds
......@@ -359,6 +359,7 @@
<version>2.19.1</version>
<configuration>
<argLine>-Xmx1024m -Xms1024m -XX:MaxPermSize=512m -Djdk.net.URLClassPath.disableClassPathURLCheck=true</argLine>
<trimStackTrace>false</trimStackTrace>
</configuration>
</plugin>
<plugin>
......
......@@ -62,7 +62,7 @@ grammar CNNArch extends de.monticore.CommonExpressions, de.monticore.lang.Math,
Stream = elements:(ArchitectureElement || "->")+;
Unroll = "timed" "<" timeParameter:LayerParameter ">"
Unroll = "timed" "<" timeParameter:TimeParameter ">"
Name "(" arguments:(ArchArgument || ",")* ")"
"{" body:Stream "}";
......@@ -90,6 +90,8 @@ grammar CNNArch extends de.monticore.CommonExpressions, de.monticore.lang.Math,
LayerParameter implements ArchParameter = Name ("=" default:ArchSimpleExpression)? ;
TimeParameter implements ArchParameter = Name ("=" default:ArchSimpleExpression)? ;
interface ArchArgument;
ArchParameterArgument implements ArchArgument = Name "=" rhs:ArchExpression ;
......
......@@ -269,6 +269,26 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
}
}
/**
 * Creates the symbol for a time parameter node.
 *
 * <p>A {@code ParameterSymbol} named after the node is created, typed as
 * {@code ParameterType.TIME_PARAMETER}, then added to the scope and linked with the node.
 *
 * @param ast the time parameter node being visited
 */
@Override
public void visit(ASTTimeParameter ast) {
    ParameterSymbol timeParameter = new ParameterSymbol(ast.getName());
    timeParameter.setType(ParameterType.TIME_PARAMETER);
    addToScopeAndLinkWithNode(timeParameter, ast);
}
/**
 * Assigns the default expression of the symbol linked to a time parameter node.
 *
 * <p>When the node declares a default, the symbol of that default expression is used;
 * otherwise a new expression symbol with the value 1 is created.
 *
 * @param ast the time parameter node whose visit is ending
 */
@Override
public void endVisit(ASTTimeParameter ast) {
    ParameterSymbol timeParameter = (ParameterSymbol) ast.getSymbolOpt().get();

    ArchSimpleExpressionSymbol defaultExpression;
    if (!ast.isPresentDefault()) {
        // No explicit default in the model: fall back to the constant 1.
        defaultExpression = new ArchSimpleExpressionSymbol();
        defaultExpression.setValue(1);
    } else {
        defaultExpression = (ArchSimpleExpressionSymbol) ast.getDefault().getSymbolOpt().get();
    }

    timeParameter.setDefaultExpression(defaultExpression);
}
@Override
public void endVisit(ASTArchSimpleExpression ast) {
ArchSimpleExpressionSymbol sym = new ArchSimpleExpressionSymbol();
......@@ -348,10 +368,7 @@ public class CNNArchSymbolTableCreator extends de.monticore.symboltable.CommonSy
}
layer.setArguments(arguments);
ArchSimpleExpressionSymbol t_value = new ArchSimpleExpressionSymbol();
t_value.setValue(1);
((ParameterSymbol)ast.getTimeParameter().getSymbolOpt().get()).setDefaultExpression(t_value);
layer.setTimeParameter((ParameterSymbol)ast.getTimeParameter().getSymbolOpt().get());
layer.setTimeParameter((ParameterSymbol) ast.getTimeParameter().getSymbolOpt().get());
removeCurrentScope();
}
......
......@@ -38,7 +38,6 @@ public class ParameterSymbol extends CommonSymbol {
private ArchSimpleExpressionSymbol currentExpression = null; //Optional
private Set<Constraints> constraints = new HashSet<>();
/**
 * Creates a parameter symbol with the given name and this class's symbol kind.
 *
 * @param name the name passed to the superclass constructor
 */
protected ParameterSymbol(String name) {
super(name, KIND);
}
......@@ -94,6 +93,9 @@ public class ParameterSymbol extends CommonSymbol {
return type == ParameterType.LAYER_PARAMETER;
}
/**
 * Reports whether this parameter's type is {@code ParameterType.TIME_PARAMETER}.
 *
 * @return {@code true} if the type is {@code TIME_PARAMETER}, {@code false} otherwise
 */
public boolean isTimeParameter() {
    return ParameterType.TIME_PARAMETER == type;
}
public boolean hasExpression(){
return getCurrentExpression().isPresent() || getDefaultExpression().isPresent();
......
......@@ -39,6 +39,12 @@ public enum ParameterType {
return "constant";
}
},
// Parameter type for time parameters; its string form is "time parameter".
TIME_PARAMETER {
@Override
public String toString(){
return "time parameter";
}
},
UNKNOWN {
//describes a parameter which does not exist. Only used to avoid exceptions while checking Cocos.
@Override
......
......@@ -35,9 +35,6 @@ public class UnrollDeclarationSymbol extends CommonScopeSpanningSymbol {
public static final UnrollDeclarationKind KIND = new UnrollDeclarationKind();
private List<ParameterSymbol> parameters;
private List<LayerSymbol> layers;
private SerialCompositeElementSymbol body;
protected UnrollDeclarationSymbol(String name) {
super(name, KIND);
......@@ -53,14 +50,6 @@ public class UnrollDeclarationSymbol extends CommonScopeSpanningSymbol {
return (UnrollDeclarationScope) super.getSpannedScope();
}
/**
 * Replaces the stored list of layers.
 *
 * @param layers the list to store; kept by reference, not copied
 */
protected void setLayers(List<LayerSymbol> layers) {
this.layers = layers;
}
/**
 * Returns the stored list of layers.
 *
 * @return the internal list itself (no defensive copy is made)
 */
public List<LayerSymbol> getLayers() {
return layers;
}
/**
 * Returns the stored list of parameters.
 *
 * @return the internal list itself (no defensive copy is made)
 */
public List<ParameterSymbol> getParameters() {
return parameters;
}
......@@ -82,23 +71,9 @@ public class UnrollDeclarationSymbol extends CommonScopeSpanningSymbol {
this.parameters.add(forParam);
forParam.putInScope(getSpannedScope());
}
}
/**
 * Returns the body of this declaration.
 *
 * @return the stored body symbol
 */
public SerialCompositeElementSymbol getBody() {
return body;
}
/**
 * Replaces the body of this declaration.
 *
 * @param body the body symbol to store
 */
protected void setBody(SerialCompositeElementSymbol body) {
this.body = body;
}
/**
 * Indicates whether this declaration is predefined.
 *
 * @return always {@code false} in this implementation
 */
public boolean isPredefined() {
// Overridden by PredefinedUnrollDeclaration
return false;
}
public boolean isTrainable() {
return body.isTrainable();
for (ParameterSymbol param : parameters){
param.putInScope(getSpannedScope());
}
}
public Optional<ParameterSymbol> getParameter(String name) {
......@@ -141,66 +116,4 @@ public class UnrollDeclarationSymbol extends CommonScopeSpanningSymbol {
throw new IllegalArgumentException("Arguments with sequence expressions have to be resolved first before calling the layer method.");
}
}
/**
 * Creates a copy of this declaration under the same name.
 *
 * <p>The copy references the same AST node as this symbol (when one is present). Every
 * parameter is deep-copied and the body is copied via {@code preResolveDeepCopy()}; both the
 * parameter copies and the body copy are put into the scope spanned by the new declaration.
 *
 * @return the newly created copy
 */
public UnrollDeclarationSymbol deepCopy() {
    UnrollDeclarationSymbol copy = new UnrollDeclarationSymbol(getName());
    if (getAstNode().isPresent()) {
        copy.setAstNode(getAstNode().get());
    }

    // Copy each parameter and register it in the scope spanned by the copy.
    List<ParameterSymbol> parameterCopies = new ArrayList<>(getParameters().size());
    for (ParameterSymbol parameter : getParameters()) {
        ParameterSymbol parameterCopy = parameter.deepCopy();
        parameterCopies.add(parameterCopy);
        parameterCopy.putInScope(copy.getSpannedScope());
    }
    copy.setParameters(parameterCopies);

    // The body is copied in its pre-resolve form and attached to the copy's scope.
    copy.setBody(getBody().preResolveDeepCopy());
    copy.getBody().putInScope(copy.getSpannedScope());
    return copy;
}
/*public static class Builder{
private List<VariableSymbol> parameters = new ArrayList<>();
private CompositeElementSymbol body;
private String name = "";
public Builder parameters(List<VariableSymbol> parameters) {
this.parameters = parameters;
return this;
}
public Builder parameters(VariableSymbol... parameters) {
this.parameters = new ArrayList<>(Arrays.asList(parameters));
return this;
}
public Builder body(CompositeElementSymbol body) {
this.body = body;
return this;
}
public Builder name(String name) {
this.name = name;
return this;
}
public UnrollDeclarationSymbol build(){
if (name == null || name.equals("")){
throw new IllegalStateException("Missing or empty name for UnrollDeclarationSymbol");
}
UnrollDeclarationSymbol sym = new UnrollDeclarationSymbol(name);
sym.setBody(body);
if (body != null){
body.putInScope(sym.getSpannedScope());
}
for (VariableSymbol param : parameters){
param.putInScope(sym.getSpannedScope());
}
sym.setParameters(parameters);
return sym;
}
}*/
}
......@@ -93,6 +93,7 @@ public class UnrollSymbol extends ResolvableSymbol {
/**
 * Stores the time parameter of this symbol and puts it into this symbol's spanned scope.
 *
 * @param timeParameter the parameter symbol to store
 */
protected void setTimeParameter(ParameterSymbol timeParameter) {
    this.timeParameter = timeParameter;
    timeParameter.putInScope(getSpannedScope());
}
protected void putInScope(Scope scope){
......@@ -114,45 +115,25 @@ public class UnrollSymbol extends ResolvableSymbol {
getDeclaration();
resolveExpressions();
getBody().resolveOrError();
for (int timestep = this.getIntValue(AllPredefinedLayers.T_NAME).get(); timestep < this.getIntValue(AllPredefinedLayers.MAX_LENGTH_NAME).get(); timestep++) {
SerialCompositeElementSymbol newBody = new SerialCompositeElementSymbol();
List<ArchitectureElementSymbol> newBodyList = new ArrayList<>();
SerialCompositeElementSymbol body = getBody().preResolveDeepCopy();
body.putInScope(getBody().getSpannedScope());
for (ArchitectureElementSymbol element : body.getElements()) {
if (element.getEnclosingScope() == null) {
element.setEnclosingScope(getEnclosingScope().getAsMutableScope());
}
((ParameterSymbol)((ArrayList<Symbol>)getSpannedScope().getLocalSymbols().get(AllPredefinedLayers.T_NAME)).get(0)).getExpression().setValue(timestep);
int startValue = getTimeParameter().getDefaultExpression().get().getIntValue().get();
int endValue = getIntValue(AllPredefinedLayers.MAX_LENGTH_NAME).get();
try {
this.resolveExpressions();
for (ParameterSymbol p:declaration.getParameters()) {
if (p.getEnclosingScope() == null) {
p.putInScope(getSpannedScope());
}
}
getTimeParameter().getExpression().setValue(1); // TODO: Change constant
getBody().resolveOrError();
element.resolve();
}
catch (ArchResolveException e) {
e.printStackTrace();
}
for (int timestep = startValue; timestep < endValue; timestep++) {
SerialCompositeElementSymbol currentBody = getBody().preResolveDeepCopy();
currentBody.putInScope(getBody().getSpannedScope());
newBodyList.add(element);
}
getTimeParameter().getExpression().setValue(timestep);
getTimeParameter().putInScope(currentBody.getEnclosingScope());
newBody.putInScope(this.getBody().getSpannedScope());
newBody.setElements(newBodyList);
currentBody.resolveOrError();
bodies.add(newBody);
bodies.add(currentBody);
}
UnrollSymbol resolvedUnroll = getDeclaration().call(this);
setResolvedThis(resolvedUnroll);
}
......
......@@ -46,7 +46,7 @@ public class AllPredefinedLayers {
public static final String CONCATENATE_NAME = "Concatenate";
public static final String FLATTEN_NAME = "Flatten";
public static final String ONE_HOT_NAME = "OneHot";
public static final String BEAMSEARCH_NAME = "BeamSearchStart";
public static final String BEAMSEARCH_NAME = "BeamSearch";
public static final String GREEDYSEARCH_NAME = "GreedySearch";
public static final String RNN_NAME = "RNN";
public static final String LSTM_NAME = "LSTM";
......@@ -78,7 +78,6 @@ public class AllPredefinedLayers {
public static final String FLATTEN_PARAMETER_NAME = "flatten";
public static final String MAX_LENGTH_NAME = "max_length";
public static final String WIDTH_NAME = "width";
public static final String T_NAME = "t";
//possible String values
public static final String PADDING_VALID = "valid";
......@@ -118,7 +117,7 @@ public class AllPredefinedLayers {
public static List<UnrollDeclarationSymbol> createUnrollList(){
return Arrays.asList(
GreedySearch.create(),
BeamSearchStart.create());
BeamSearch.create());
}
}
......@@ -26,34 +26,19 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
public class BeamSearchStart extends PredefinedUnrollDeclaration {
public class BeamSearch extends UnrollDeclarationSymbol {
private BeamSearchStart() {
private BeamSearch() {
super(AllPredefinedLayers.BEAMSEARCH_NAME);
}
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, UnrollSymbol layer, VariableSymbol.Member member) {
return new ArrayList<ArchTypeSymbol>();
}
@Override
public void checkInput(List<ArchTypeSymbol> inputTypes, UnrollSymbol layer, VariableSymbol.Member member) {
errorIfInputSizeIsNotOne(inputTypes, layer);
}
public static BeamSearchStart create(){
BeamSearchStart declaration = new BeamSearchStart();
public static BeamSearch create(){
BeamSearch declaration = new BeamSearch();
List<ParameterSymbol> parameters = new ArrayList<>(Arrays.asList(
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.MAX_LENGTH_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.T_NAME)
.constraints(Constraints.INTEGER, Constraints.NON_NEGATIVE)
.defaultValue(1)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.WIDTH_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
......
......@@ -26,33 +26,18 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
public class GreedySearch extends PredefinedUnrollDeclaration {
public class GreedySearch extends UnrollDeclarationSymbol {
private GreedySearch() {
super(AllPredefinedLayers.GREEDYSEARCH_NAME);
}
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, UnrollSymbol layer, VariableSymbol.Member member) {
return new ArrayList<ArchTypeSymbol>();
}
@Override
public void checkInput(List<ArchTypeSymbol> inputTypes, UnrollSymbol layer, VariableSymbol.Member member) {
errorIfInputSizeIsNotOne(inputTypes, layer);
}
public static GreedySearch create(){
GreedySearch declaration = new GreedySearch();
List<ParameterSymbol> parameters = new ArrayList<>(Arrays.asList(
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.MAX_LENGTH_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.T_NAME)
.constraints(Constraints.INTEGER, Constraints.NON_NEGATIVE)
.defaultValue(1)
.build()));
declaration.setParameters(parameters);
return declaration;
......
......@@ -4,7 +4,7 @@ architecture RNNtest(max_length=50, vocabulary_size=30000, hidden_size=1000){
source -> Softmax() -> target[0];
timed <t> BeamSearchStart(max_length=5){
timed <t> BeamSearch(max_length=5){
(target[0] | target[t-1]) ->
Concatenate() ->
FullyConnected(units=30000) ->
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment