Commit 0eb98f6f authored by Sebastian Nickels
Browse files

Implemented bidirectional RNNs

parent afb3db82
...@@ -71,6 +71,7 @@ public class AllPredefinedLayers { ...@@ -71,6 +71,7 @@ public class AllPredefinedLayers {
public static final String LAYERS_NAME = "layers"; public static final String LAYERS_NAME = "layers";
public static final String INPUT_DIM_NAME = "input_dim"; public static final String INPUT_DIM_NAME = "input_dim";
public static final String OUTPUT_DIM_NAME = "output_dim"; public static final String OUTPUT_DIM_NAME = "output_dim";
public static final String BIDIRECTIONAL_NAME = "bidirectional";
public static final String BEAMSEARCH_MAX_LENGTH = "max_length"; public static final String BEAMSEARCH_MAX_LENGTH = "max_length";
public static final String BEAMSEARCH_WIDTH_NAME = "width"; public static final String BEAMSEARCH_WIDTH_NAME = "width";
......
...@@ -22,6 +22,8 @@ package de.monticore.lang.monticar.cnnarch.predefined; ...@@ -22,6 +22,8 @@ package de.monticore.lang.monticar.cnnarch.predefined;
import de.monticore.lang.monticar.cnnarch._symboltable.*; import de.monticore.lang.monticar.cnnarch._symboltable.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
...@@ -38,13 +40,14 @@ abstract public class BaseRNN extends PredefinedLayerDeclaration { ...@@ -38,13 +40,14 @@ abstract public class BaseRNN extends PredefinedLayerDeclaration {
@Override @Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) { public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
boolean bidirectional = layer.getBooleanValue(AllPredefinedLayers.BIDIRECTIONAL_NAME).get();
int units = layer.getIntValue(AllPredefinedLayers.UNITS_NAME).get(); int units = layer.getIntValue(AllPredefinedLayers.UNITS_NAME).get();
if (member == VariableSymbol.Member.STATE) { if (member == VariableSymbol.Member.STATE) {
int layers = layer.getIntValue(AllPredefinedLayers.LAYERS_NAME).get(); int layers = layer.getIntValue(AllPredefinedLayers.LAYERS_NAME).get();
return Collections.singletonList(new ArchTypeSymbol.Builder() return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(layers) .channels(bidirectional ? 2 * layers : layers)
.height(units) .height(units)
.elementType("-oo", "oo") .elementType("-oo", "oo")
.build()); .build());
...@@ -52,7 +55,7 @@ abstract public class BaseRNN extends PredefinedLayerDeclaration { ...@@ -52,7 +55,7 @@ abstract public class BaseRNN extends PredefinedLayerDeclaration {
else { else {
return Collections.singletonList(new ArchTypeSymbol.Builder() return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(layer.getInputTypes().get(0).getChannels()) .channels(layer.getInputTypes().get(0).getChannels())
.height(units) .height(bidirectional ? 2 * units : units)
.elementType("-oo", "oo") .elementType("-oo", "oo")
.build()); .build());
} }
...@@ -60,12 +63,13 @@ abstract public class BaseRNN extends PredefinedLayerDeclaration { ...@@ -60,12 +63,13 @@ abstract public class BaseRNN extends PredefinedLayerDeclaration {
@Override @Override
public void checkInput(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) { public void checkInput(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
boolean bidirectional = layer.getBooleanValue(AllPredefinedLayers.BIDIRECTIONAL_NAME).get();
int units = layer.getIntValue(AllPredefinedLayers.UNITS_NAME).get(); int units = layer.getIntValue(AllPredefinedLayers.UNITS_NAME).get();
int layers = layer.getIntValue(AllPredefinedLayers.LAYERS_NAME).get(); int layers = layer.getIntValue(AllPredefinedLayers.LAYERS_NAME).get();
if (member == VariableSymbol.Member.STATE) { if (member == VariableSymbol.Member.STATE) {
errorIfInputSizeIsNotOne(inputTypes, layer); errorIfInputSizeIsNotOne(inputTypes, layer);
errorIfInputChannelSizeIsInvalid(inputTypes, layer, layers); errorIfInputChannelSizeIsInvalid(inputTypes, layer, bidirectional ? 2 * layers : layers);
errorIfInputHeightIsInvalid(inputTypes, layer, units); errorIfInputHeightIsInvalid(inputTypes, layer, units);
errorIfInputWidthIsInvalid(inputTypes, layer, 1); errorIfInputWidthIsInvalid(inputTypes, layer, 1);
} }
...@@ -92,4 +96,22 @@ abstract public class BaseRNN extends PredefinedLayerDeclaration { ...@@ -92,4 +96,22 @@ abstract public class BaseRNN extends PredefinedLayerDeclaration {
public boolean canBeOutput(VariableSymbol.Member member) { public boolean canBeOutput(VariableSymbol.Member member) {
return member == VariableSymbol.Member.NONE || member == VariableSymbol.Member.STATE; return member == VariableSymbol.Member.NONE || member == VariableSymbol.Member.STATE;
} }
/**
 * Creates the parameter declarations for a recurrent layer.
 *
 * <p>The returned list contains, in this order:
 * <ul>
 *   <li>{@code AllPredefinedLayers.UNITS_NAME} - integer, positive, no default value;</li>
 *   <li>{@code AllPredefinedLayers.LAYERS_NAME} - integer, positive, default {@code 1};</li>
 *   <li>{@code AllPredefinedLayers.BIDIRECTIONAL_NAME} - boolean, default {@code false}.</li>
 * </ul>
 *
 * @return a new, mutable list holding freshly built parameter symbols
 */
protected static List<ParameterSymbol> createParameters() {
    ParameterSymbol unitsParameter = new ParameterSymbol.Builder()
            .name(AllPredefinedLayers.UNITS_NAME)
            .constraints(Constraints.INTEGER, Constraints.POSITIVE)
            .build();

    ParameterSymbol layersParameter = new ParameterSymbol.Builder()
            .name(AllPredefinedLayers.LAYERS_NAME)
            .constraints(Constraints.INTEGER, Constraints.POSITIVE)
            .defaultValue(1)
            .build();

    ParameterSymbol bidirectionalParameter = new ParameterSymbol.Builder()
            .name(AllPredefinedLayers.BIDIRECTIONAL_NAME)
            .constraints(Constraints.BOOLEAN)
            .defaultValue(false)
            .build();

    // A plain ArrayList keeps the result modifiable for callers.
    List<ParameterSymbol> declared = new ArrayList<>();
    declared.add(unitsParameter);
    declared.add(layersParameter);
    declared.add(bidirectionalParameter);
    return declared;
}
} }
...@@ -20,13 +20,6 @@ ...@@ -20,13 +20,6 @@
*/ */
package de.monticore.lang.monticar.cnnarch.predefined; package de.monticore.lang.monticar.cnnarch.predefined;
import de.monticore.lang.monticar.cnnarch._symboltable.Constraints;
import de.monticore.lang.monticar.cnnarch._symboltable.ParameterSymbol;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
public class GRU extends BaseRNN { public class GRU extends BaseRNN {
private GRU() { private GRU() {
...@@ -35,17 +28,7 @@ public class GRU extends BaseRNN { ...@@ -35,17 +28,7 @@ public class GRU extends BaseRNN {
public static GRU create() { public static GRU create() {
GRU declaration = new GRU(); GRU declaration = new GRU();
List<ParameterSymbol> parameters = new ArrayList<>(Arrays.asList( declaration.setParameters(createParameters());
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.UNITS_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.LAYERS_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.defaultValue(1)
.build()));
declaration.setParameters(parameters);
return declaration; return declaration;
} }
} }
...@@ -20,13 +20,6 @@ ...@@ -20,13 +20,6 @@
*/ */
package de.monticore.lang.monticar.cnnarch.predefined; package de.monticore.lang.monticar.cnnarch.predefined;
import de.monticore.lang.monticar.cnnarch._symboltable.Constraints;
import de.monticore.lang.monticar.cnnarch._symboltable.ParameterSymbol;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
public class LSTM extends BaseRNN { public class LSTM extends BaseRNN {
private LSTM() { private LSTM() {
...@@ -35,17 +28,7 @@ public class LSTM extends BaseRNN { ...@@ -35,17 +28,7 @@ public class LSTM extends BaseRNN {
public static LSTM create() { public static LSTM create() {
LSTM declaration = new LSTM(); LSTM declaration = new LSTM();
List<ParameterSymbol> parameters = new ArrayList<>(Arrays.asList( declaration.setParameters(createParameters());
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.UNITS_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.LAYERS_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.defaultValue(1)
.build()));
declaration.setParameters(parameters);
return declaration; return declaration;
} }
} }
...@@ -20,10 +20,6 @@ ...@@ -20,10 +20,6 @@
*/ */
package de.monticore.lang.monticar.cnnarch.predefined; package de.monticore.lang.monticar.cnnarch.predefined;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import java.util.*;
public class RNN extends BaseRNN { public class RNN extends BaseRNN {
private RNN() { private RNN() {
...@@ -32,17 +28,7 @@ public class RNN extends BaseRNN { ...@@ -32,17 +28,7 @@ public class RNN extends BaseRNN {
public static RNN create() { public static RNN create() {
RNN declaration = new RNN(); RNN declaration = new RNN();
List<ParameterSymbol> parameters = new ArrayList<>(Arrays.asList( declaration.setParameters(createParameters());
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.UNITS_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.LAYERS_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.defaultValue(1)
.build()));
declaration.setParameters(parameters);
return declaration; return declaration;
} }
} }
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment