Commit ca5e7210 authored by Sebastian Nickels's avatar Sebastian Nickels
Browse files

Added array length for states (replaced isValidMember with getArrayLength() ==...

Added array length for states (replaced isValidMember with getArrayLength() == 1) and added RNNsearch
parent 79cfdeba
......@@ -27,6 +27,8 @@ import de.monticore.lang.monticar.cnnarch._symboltable.VariableSymbol;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import java.util.Optional;
public class CheckVariableMember extends CNNArchSymbolCoCo {
@Override
......@@ -40,14 +42,32 @@ public class CheckVariableMember extends CNNArchSymbolCoCo {
if (variable.getType() == VariableSymbol.Type.LAYER) {
LayerDeclarationSymbol layerDeclaration = variable.getLayerVariableDeclaration().getLayer().getDeclaration();
if (layerDeclaration.isPredefined() && !((PredefinedLayerDeclaration) layerDeclaration).isValidMember(variable.getMember())) {
if (layerDeclaration.isPredefined() && ((PredefinedLayerDeclaration) layerDeclaration).getArrayLength(variable.getMember()) == 0) {
Log.error("0" + ErrorCodes.INVALID_MEMBER + " Layer has no member " + variable.getMember().toString().toLowerCase() + ". ",
variable.getSourcePosition());
}
if (variable.getArrayAccess().isPresent()) {
Log.error("0" + ErrorCodes.INVALID_MEMBER + " Currently layer variable array access is not implemented. ",
variable.getSourcePosition());
Optional<Integer> arrayAccess = variable.getArrayAccess().get().getIntValue();
int arrayLength = 0;
if (layerDeclaration.isPredefined()) {
arrayLength = ((PredefinedLayerDeclaration) layerDeclaration).getArrayLength(variable.getMember());
}
String name = variable.getName() + "." + variable.getMember().toString().toLowerCase();
if (arrayAccess.isPresent() && arrayLength == 1) {
Log.error("0" + ErrorCodes.INVALID_ARRAY_ACCESS + " The layer variable '" + name +
"' does not support array access. "
, variable.getSourcePosition());
} else if (!arrayAccess.isPresent() || arrayAccess.get() < 0 || arrayAccess.get() >= arrayLength) {
Log.error("0" + ErrorCodes.INVALID_ARRAY_ACCESS + " The layer variable array access value of '" + name +
"' must be an integer between 0 and " + (arrayLength - 1) + ". " +
"The current value is: " + variable.getArrayAccess().get().getValue().get().toString()
, variable.getSourcePosition());
}
//
}
}
......
......@@ -222,9 +222,9 @@ public class ArchTypeSymbol extends CommonSymbol {
}
public static class Builder{
private int height = 0;
private int width = 0;
private int channels = 0;
private int height = 1;
private int width = 1;
private int channels = 1;
private ASTElementType domain = null;
public Builder height(int height){
......@@ -264,26 +264,10 @@ public class ArchTypeSymbol extends CommonSymbol {
public ArchTypeSymbol build(){
ArchTypeSymbol sym = new ArchTypeSymbol();
int index = 0;
List<Integer> dimensions = new ArrayList<>();
if (channels != 0) {
sym.setChannelIndex(index++);
dimensions.add(channels);
}
if (height != 0) {
sym.setHeightIndex(index++);
dimensions.add(height);
}
if (width != 0) {
sym.setWidthIndex(index);
dimensions.add(width);
}
sym.setDimensions(dimensions);
sym.setChannelIndex(0);
sym.setHeightIndex(1);
sym.setWidthIndex(2);
sym.setDimensions(Arrays.asList(channels, height, width));
if (domain == null){
domain = new ASTElementType();
......
......@@ -85,8 +85,12 @@ abstract public class PredefinedLayerDeclaration extends LayerDeclarationSymbol
LayerSymbol layer,
VariableSymbol.Member member);
public boolean isValidMember(VariableSymbol.Member member) {
return member == VariableSymbol.Member.NONE || member == VariableSymbol.Member.OUTPUT;
public int getArrayLength(VariableSymbol.Member member) {
if (member == VariableSymbol.Member.NONE || member == VariableSymbol.Member.OUTPUT) {
return 1;
}
return 0;
}
public boolean canBeInput(VariableSymbol.Member member) {
......
......@@ -269,13 +269,40 @@ public class VariableSymbol extends ArchitectureElementSymbol {
// to the layer's state via member == STATE)
getLayerVariableDeclaration().getLayer().resolveOrError();
LayerDeclarationSymbol layerDeclaration = getLayerVariableDeclaration().getLayer().getDeclaration();
setResolvedThis(this);
if (!getArrayAccess().isPresent() && layerDeclaration.isPredefined() && ((PredefinedLayerDeclaration) layerDeclaration).getArrayLength(getMember()) > 1) {
List<ArchitectureElementSymbol> parallelElements = createExpandedParallelElements();
ParallelCompositeElementSymbol composite = new ParallelCompositeElementSymbol();
composite.setElements(parallelElements);
getSpannedScope().getAsMutableScope().add(composite);
composite.setAstNode(getAstNode().get());
for (ArchitectureElementSymbol element : parallelElements) {
element.putInScope(composite.getSpannedScope());
element.setAstNode(getAstNode().get());
}
if (getInputElement().isPresent()) {
composite.setInputElement(getInputElement().get());
}
if (getOutputElement().isPresent()) {
composite.setOutputElement(getOutputElement().get());
}
composite.resolveOrError();
setResolvedThis(composite);
}
else {
setResolvedThis(this);
}
}
else {
throw new ArchResolveException();
}
}
}
......@@ -317,6 +344,44 @@ public class VariableSymbol extends ArchitectureElementSymbol {
}
}
}
else if (getType() == Type.LAYER) {
LayerVariableDeclarationSymbol layerVariableDeclaration = getLayerVariableDeclaration();
PredefinedLayerDeclaration predefinedLayerDeclaration =
(PredefinedLayerDeclaration) layerVariableDeclaration.getLayer().getDeclaration();
int layerArrayLength = predefinedLayerDeclaration.getArrayLength(getMember());
if (!getInputElement().isPresent()) {
for (int i = 0; i < layerArrayLength; ++i) {
VariableSymbol element = new VariableSymbol(getName());
element.setArrayAccess(i);
element.setMember(getMember());
parallelElements.add(element);
}
}
else {
for (int i = 0; i < layerArrayLength; ++i) {
SerialCompositeElementSymbol serialComposite = new SerialCompositeElementSymbol();
VariableSymbol element = new VariableSymbol(getName());
element.setArrayAccess(i);
element.setMember(getMember());
element.setAstNode(getAstNode().get());
LayerSymbol getLayer = new LayerSymbol(AllPredefinedLayers.GET_NAME);
getLayer.setArguments(Collections.singletonList(
new ArgumentSymbol.Builder()
.parameter(AllPredefinedLayers.INDEX_NAME)
.value(ArchSimpleExpressionSymbol.of(i))
.build()));
getLayer.setAstNode(getAstNode().get());
serialComposite.setElements(Arrays.asList(getLayer, element));
parallelElements.add(serialComposite);
}
}
}
return parallelElements;
}
......@@ -405,7 +470,6 @@ public class VariableSymbol extends ArchitectureElementSymbol {
"Actual type: " + inputType.getName() + ".");
}
}
}
}
else if (getType() == Type.LAYER) {
......@@ -439,7 +503,9 @@ public class VariableSymbol extends ArchitectureElementSymbol {
return Optional.of(ioDeclaration.getArrayLength());
}
else if (getType() == Type.LAYER) {
return Optional.of(1);
PredefinedLayerDeclaration predefinedLayerDeclaration =
(PredefinedLayerDeclaration) getLayerVariableDeclaration().getLayer().getDeclaration();
return Optional.of(predefinedLayerDeclaration.getArrayLength(getMember()));
}
else {
return Optional.empty();
......
......@@ -29,8 +29,6 @@ import java.util.List;
abstract public class BaseRNN extends PredefinedLayerDeclaration {
protected int numberOfStates = 1;
public BaseRNN(String name) {
super(name);
}
......@@ -40,6 +38,17 @@ abstract public class BaseRNN extends PredefinedLayerDeclaration {
return member == VariableSymbol.Member.NONE;
}
@Override
public int getArrayLength(VariableSymbol.Member member) {
if (member == VariableSymbol.Member.NONE ||
member == VariableSymbol.Member.STATE ||
member == VariableSymbol.Member.OUTPUT) {
return 1;
}
return 0;
}
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
boolean bidirectional = layer.getBooleanValue(AllPredefinedLayers.BIDIRECTIONAL_NAME).get();
......@@ -49,9 +58,9 @@ abstract public class BaseRNN extends PredefinedLayerDeclaration {
int layers = layer.getIntValue(AllPredefinedLayers.LAYERS_NAME).get();
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(numberOfStates)
.height(bidirectional ? 2 * layers : layers)
.width(units)
.channels(bidirectional ? 2 * layers : layers)
.height(units)
.width(1)
.elementType("-oo", "oo")
.build());
}
......@@ -59,6 +68,7 @@ abstract public class BaseRNN extends PredefinedLayerDeclaration {
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(layer.getInputTypes().get(0).getChannels())
.height(bidirectional ? 2 * units : units)
.width(1)
.elementType("-oo", "oo")
.build());
}
......@@ -70,26 +80,18 @@ abstract public class BaseRNN extends PredefinedLayerDeclaration {
int units = layer.getIntValue(AllPredefinedLayers.UNITS_NAME).get();
int layers = layer.getIntValue(AllPredefinedLayers.LAYERS_NAME).get();
errorIfInputSizeIsNotOne(inputTypes, layer);
errorIfInputWidthIsInvalid(inputTypes, layer, 1);
if (member == VariableSymbol.Member.STATE) {
errorIfInputSizeIsNotOne(inputTypes, layer);
errorIfInputChannelSizeIsInvalid(inputTypes, layer, numberOfStates);
errorIfInputHeightIsInvalid(inputTypes, layer, bidirectional ? 2 * layers : layers);
errorIfInputWidthIsInvalid(inputTypes, layer, units);
errorIfInputChannelSizeIsInvalid(inputTypes, layer, bidirectional ? 2 * layers : layers);
errorIfInputHeightIsInvalid(inputTypes, layer, units);
}
else {
errorIfInputSizeIsNotOne(inputTypes, layer);
errorIfInputChannelSizeIsInvalid(inputTypes, layer, layer.getInputTypes().get(0).getChannels());
errorIfInputWidthIsInvalid(inputTypes, layer, 1);
}
}
@Override
public boolean isValidMember(VariableSymbol.Member member) {
return member == VariableSymbol.Member.NONE ||
member == VariableSymbol.Member.OUTPUT ||
member == VariableSymbol.Member.STATE;
}
@Override
public boolean canBeInput(VariableSymbol.Member member) {
return member == VariableSymbol.Member.OUTPUT || member == VariableSymbol.Member.STATE;
......
......@@ -20,12 +20,24 @@
*/
package de.monticore.lang.monticar.cnnarch.predefined;
import de.monticore.lang.monticar.cnnarch._symboltable.VariableSymbol;
public class LSTM extends BaseRNN {
private LSTM() {
super(AllPredefinedLayers.LSTM_NAME);
}
@Override
public int getArrayLength(VariableSymbol.Member member) {
if (member == VariableSymbol.Member.STATE) {
return 2;
}
else if (member == VariableSymbol.Member.NONE || member == VariableSymbol.Member.OUTPUT) {
return 1;
}
numberOfStates = 2;
return 0;
}
public static LSTM create() {
......
......@@ -67,6 +67,7 @@ public class AllCoCoTest extends AbstractCoCoTest {
checkValid("valid_tests", "MultipleOutputs");
checkValid("valid_tests", "MultipleStreams");
checkValid("valid_tests", "RNNencdec");
checkValid("valid_tests", "RNNsearch");
checkValid("valid_tests", "RNNtest");
}
......
......@@ -3,30 +3,30 @@ architecture RNNsearch{
def output Z(0:49999)^{1} target[30]
layer GRU(units=1000, bidirectional=true) encoder;
source -> Embedding(output_dim=620) -> encoder;
layer FullyConnected(units=1000, flatten=false) fc;
encoder.output -> fc;
source -> Embedding(output_dim=620) -> encoder -> fc;
1 -> target[0];
layer GRU(units=1000) decoder;
encoder.state -> SwapAxes(axes=(0, 1)) -> Split(n=2) -> [1] -> SwapAxes(axes=(0, 1)) -> decoder.state;
encoder.state -> Split(n=2) -> [1] -> decoder.state;
timed<t> GreedySearch(max_length=30) {
(
(
(
decoder.state ->
Repeat(n=30, axis=1)
Repeat(n=30, axis=0)
|
fc.output
) ->
Concatenate(dim=2) ->
Concatenate(axis=1) ->
FullyConnected(units=1000, flatten=false) ->
Tanh() ->
FullyConnected(units=30) ->
Softmax()
Softmax() ->
ExpandDims(axis=0)
|
fc.output
) ->
......@@ -35,12 +35,11 @@ architecture RNNsearch{
target[t-1] ->
Embedding(output_dim=620)
) ->
Concatenate() ->
Concatenate(axis=1) ->
decoder ->
FullyConnected(units=50000) ->
Softmax() ->
ArgMax() ->
target[t]
};
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment