Commit 8c07ee90 authored by Christian Fuß's avatar Christian Fuß
Browse files

added GreedySearch

parent f1df76ef
Pipeline #179481 passed with stages
in 21 minutes and 25 seconds
......@@ -298,7 +298,7 @@ public class UnrollSymbol extends CommonScopeSpanningSymbol {
ArchitectureSymbol architecture = getBody().getArchitecture();
List<ArchitectureElementSymbol> newBodyList;
for (timestep = this.getIntValue(AllPredefinedLayers.BEAMSEARCH_T_NAME).get(); timestep < this.getIntValue(AllPredefinedLayers.BEAMSEARCH_MAX_LENGTH).get(); timestep++) {
for (timestep = this.getIntValue(AllPredefinedLayers.T_NAME).get(); timestep < this.getIntValue(AllPredefinedLayers.MAX_LENGTH_NAME).get(); timestep++) {
newBody = new SerialCompositeElementSymbol();
newBodyList = new ArrayList<>();
SerialCompositeElementSymbol body = this.getBody().preResolveDeepCopy();
......@@ -310,7 +310,7 @@ public class UnrollSymbol extends CommonScopeSpanningSymbol {
element.setEnclosingScope(getEnclosingScope().getAsMutableScope());
}
((ParameterSymbol)((ArrayList<Symbol>)getSpannedScope().getLocalSymbols().get(AllPredefinedLayers.BEAMSEARCH_T_NAME)).get(0)).getExpression().setValue(timestep);
((ParameterSymbol)((ArrayList<Symbol>)getSpannedScope().getLocalSymbols().get(AllPredefinedLayers.T_NAME)).get(0)).getExpression().setValue(timestep);
try {
this.resolveExpressions();
......@@ -369,7 +369,7 @@ public class UnrollSymbol extends CommonScopeSpanningSymbol {
copy.setTimeParameter(getTimeParameter().deepCopy());
// TODO: currently only the defaultValue for t is used, make it use the assigned value instead
((ParameterSymbol)((ArrayList<Symbol>)getSpannedScope().getLocalSymbols().get("t")).get(0)).getExpression().setValue(this.getIntValue(AllPredefinedLayers.BEAMSEARCH_T_NAME));
((ParameterSymbol)((ArrayList<Symbol>)getSpannedScope().getLocalSymbols().get("t")).get(0)).getExpression().setValue(this.getIntValue(AllPredefinedLayers.T_NAME));
copy.getTimeParameter().putInScope(getSpannedScope());
......
......@@ -47,6 +47,7 @@ public class AllPredefinedLayers {
public static final String FLATTEN_NAME = "Flatten";
public static final String ONE_HOT_NAME = "OneHot";
public static final String BEAMSEARCH_NAME = "BeamSearchStart";
public static final String GREEDYSEARCH_NAME = "GreedySearch";
public static final String RNN_NAME = "RNN";
public static final String LSTM_NAME = "LSTM";
public static final String GRU_NAME = "GRU";
......@@ -74,9 +75,9 @@ public class AllPredefinedLayers {
public static final String OUTPUT_DIM_NAME = "output_dim";
public static final String BIDIRECTIONAL_NAME = "bidirectional";
public static final String FLATTEN_PARAMETER_NAME = "flatten";
public static final String BEAMSEARCH_MAX_LENGTH = "max_length";
public static final String BEAMSEARCH_WIDTH_NAME = "width";
public static final String BEAMSEARCH_T_NAME = "t";
public static final String MAX_LENGTH_NAME = "max_length";
public static final String WIDTH_NAME = "width";
public static final String T_NAME = "t";
//possible String values
public static final String PADDING_VALID = "valid";
......@@ -114,7 +115,8 @@ public class AllPredefinedLayers {
public static List<UnrollDeclarationSymbol> createUnrollList(){
return Arrays.asList(
BeamSearchStart.create());
GreedySearch.create(),
BeamSearchStart.create());
}
}
......@@ -46,16 +46,16 @@ public class BeamSearchStart extends PredefinedUnrollDeclaration {
BeamSearchStart declaration = new BeamSearchStart();
List<ParameterSymbol> parameters = new ArrayList<>(Arrays.asList(
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.BEAMSEARCH_MAX_LENGTH)
.name(AllPredefinedLayers.MAX_LENGTH_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.BEAMSEARCH_T_NAME)
.name(AllPredefinedLayers.T_NAME)
.constraints(Constraints.INTEGER, Constraints.NON_NEGATIVE)
.defaultValue(1)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.BEAMSEARCH_WIDTH_NAME)
.name(AllPredefinedLayers.WIDTH_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.build()));
declaration.setParameters(parameters);
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch.predefined;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
public class GreedySearch extends PredefinedUnrollDeclaration {
private GreedySearch() {
super(AllPredefinedLayers.GREEDYSEARCH_NAME);
}
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, UnrollSymbol layer, VariableSymbol.Member member) {
return new ArrayList<ArchTypeSymbol>();
}
@Override
public void checkInput(List<ArchTypeSymbol> inputTypes, UnrollSymbol layer, VariableSymbol.Member member) {
errorIfInputSizeIsNotOne(inputTypes, layer);
}
public static GreedySearch create(){
GreedySearch declaration = new GreedySearch();
List<ParameterSymbol> parameters = new ArrayList<>(Arrays.asList(
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.MAX_LENGTH_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.T_NAME)
.constraints(Constraints.INTEGER, Constraints.NON_NEGATIVE)
.defaultValue(1)
.build()));
declaration.setParameters(parameters);
return declaration;
}
}
......@@ -14,7 +14,7 @@ architecture RNNencdec(max_length=50, vocabulary_size=30000, hidden_size=1000){
encoder.state -> decoder.state;
timed<t=1> BeamSearchStart(max_length=50) {
timed<t=1> GreedySearch(max_length=50) {
target[t-1] ->
Embedding(output_dim=hidden_size) ->
decoder ->
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment