Commit 0eb98f6f authored by Sebastian Nickels's avatar Sebastian Nickels

Implemented bidirectional RNNs

parent afb3db82
......@@ -71,6 +71,7 @@ public class AllPredefinedLayers {
public static final String LAYERS_NAME = "layers";
public static final String INPUT_DIM_NAME = "input_dim";
public static final String OUTPUT_DIM_NAME = "output_dim";
public static final String BIDIRECTIONAL_NAME = "bidirectional";
public static final String BEAMSEARCH_MAX_LENGTH = "max_length";
public static final String BEAMSEARCH_WIDTH_NAME = "width";
......
......@@ -22,6 +22,8 @@ package de.monticore.lang.monticar.cnnarch.predefined;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
......@@ -38,13 +40,14 @@ abstract public class BaseRNN extends PredefinedLayerDeclaration {
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
boolean bidirectional = layer.getBooleanValue(AllPredefinedLayers.BIDIRECTIONAL_NAME).get();
int units = layer.getIntValue(AllPredefinedLayers.UNITS_NAME).get();
if (member == VariableSymbol.Member.STATE) {
int layers = layer.getIntValue(AllPredefinedLayers.LAYERS_NAME).get();
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(layers)
.channels(bidirectional ? 2 * layers : layers)
.height(units)
.elementType("-oo", "oo")
.build());
......@@ -52,7 +55,7 @@ abstract public class BaseRNN extends PredefinedLayerDeclaration {
else {
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(layer.getInputTypes().get(0).getChannels())
.height(units)
.height(bidirectional ? 2 * units : units)
.elementType("-oo", "oo")
.build());
}
......@@ -60,12 +63,13 @@ abstract public class BaseRNN extends PredefinedLayerDeclaration {
@Override
public void checkInput(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
boolean bidirectional = layer.getBooleanValue(AllPredefinedLayers.BIDIRECTIONAL_NAME).get();
int units = layer.getIntValue(AllPredefinedLayers.UNITS_NAME).get();
int layers = layer.getIntValue(AllPredefinedLayers.LAYERS_NAME).get();
if (member == VariableSymbol.Member.STATE) {
errorIfInputSizeIsNotOne(inputTypes, layer);
errorIfInputChannelSizeIsInvalid(inputTypes, layer, layers);
errorIfInputChannelSizeIsInvalid(inputTypes, layer, bidirectional ? 2 * layers : layers);
errorIfInputHeightIsInvalid(inputTypes, layer, units);
errorIfInputWidthIsInvalid(inputTypes, layer, 1);
}
......@@ -92,4 +96,22 @@ abstract public class BaseRNN extends PredefinedLayerDeclaration {
public boolean canBeOutput(VariableSymbol.Member member) {
return member == VariableSymbol.Member.NONE || member == VariableSymbol.Member.STATE;
}
protected static List<ParameterSymbol> createParameters() {
return new ArrayList<>(Arrays.asList(
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.UNITS_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.LAYERS_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.defaultValue(1)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.BIDIRECTIONAL_NAME)
.constraints(Constraints.BOOLEAN)
.defaultValue(false)
.build()));
}
}
......@@ -20,13 +20,6 @@
*/
package de.monticore.lang.monticar.cnnarch.predefined;
import de.monticore.lang.monticar.cnnarch._symboltable.Constraints;
import de.monticore.lang.monticar.cnnarch._symboltable.ParameterSymbol;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
public class GRU extends BaseRNN {
private GRU() {
......@@ -35,17 +28,7 @@ public class GRU extends BaseRNN {
public static GRU create() {
GRU declaration = new GRU();
List<ParameterSymbol> parameters = new ArrayList<>(Arrays.asList(
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.UNITS_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.LAYERS_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.defaultValue(1)
.build()));
declaration.setParameters(parameters);
declaration.setParameters(createParameters());
return declaration;
}
}
......@@ -20,13 +20,6 @@
*/
package de.monticore.lang.monticar.cnnarch.predefined;
import de.monticore.lang.monticar.cnnarch._symboltable.Constraints;
import de.monticore.lang.monticar.cnnarch._symboltable.ParameterSymbol;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
public class LSTM extends BaseRNN {
private LSTM() {
......@@ -35,17 +28,7 @@ public class LSTM extends BaseRNN {
public static LSTM create() {
LSTM declaration = new LSTM();
List<ParameterSymbol> parameters = new ArrayList<>(Arrays.asList(
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.UNITS_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.LAYERS_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.defaultValue(1)
.build()));
declaration.setParameters(parameters);
declaration.setParameters(createParameters());
return declaration;
}
}
......@@ -20,10 +20,6 @@
*/
package de.monticore.lang.monticar.cnnarch.predefined;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import java.util.*;
public class RNN extends BaseRNN {
private RNN() {
......@@ -32,17 +28,7 @@ public class RNN extends BaseRNN {
public static RNN create() {
RNN declaration = new RNN();
List<ParameterSymbol> parameters = new ArrayList<>(Arrays.asList(
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.UNITS_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.LAYERS_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.defaultValue(1)
.build()));
declaration.setParameters(parameters);
declaration.setParameters(createParameters());
return declaration;
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment