Commit 8243a33c authored by Sebastian Nickels's avatar Sebastian Nickels

Added Embedding layer

parent f4aa25ae
......@@ -30,7 +30,7 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnn-arch</artifactId>
<version>0.3.2-SNAPSHOT</version>
<version>0.3.3-SNAPSHOT</version>
......
......@@ -49,6 +49,7 @@ public class AllPredefinedLayers {
public static final String RNN_NAME = "RNN";
public static final String LSTM_NAME = "LSTM";
public static final String GRU_NAME = "GRU";
public static final String EMBEDDING_NAME = "Embedding";
//predefined argument names
public static final String KERNEL_NAME = "kernel";
......@@ -68,6 +69,8 @@ public class AllPredefinedLayers {
public static final String POOL_TYPE_NAME = "pool_type";
public static final String SIZE_NAME = "size";
public static final String LAYERS_NAME = "layers";
public static final String INPUT_DIM_NAME = "input_dim";
public static final String OUTPUT_DIM_NAME = "output_dim";
public static final String BEAMSEARCH_MAX_LENGTH = "max_length";
public static final String BEAMSEARCH_WIDTH_NAME = "width";
......@@ -101,7 +104,8 @@ public class AllPredefinedLayers {
OneHot.create(),
RNN.create(),
LSTM.create(),
GRU.create());
GRU.create(),
Embedding.create());
}
}
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch.predefined;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.monticore.lang.monticar.types2._ast.ASTElementType;
import de.se_rwth.commons.logging.Log;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
public class Embedding extends PredefinedLayerDeclaration {
private Embedding() {
super(AllPredefinedLayers.EMBEDDING_NAME);
}
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
int outputDim = layer.getIntValue(AllPredefinedLayers.OUTPUT_DIM_NAME).get();
ArchTypeSymbol inputType = layer.getInputTypes().get(0);
if (inputType.getHeight() == 1) {
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(inputType.getChannels())
.height(outputDim)
.elementType("-oo", "oo")
.build());
}
else {
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(inputType.getChannels())
.height(inputType.getHeight())
.width(outputDim)
.elementType("-oo", "oo")
.build());
}
}
private static void inferInputDim(LayerSymbol layer) {
ASTElementType domain = layer.getInputTypes().get(0).getDomain();
// Only infer when not already done and upper limit is available
if (layer.getIntValue(AllPredefinedLayers.INPUT_DIM_NAME).get() == 0
&& domain.isPresentRange()
&& !domain.getRange().hasNoUpperLimit()) {
int inputDim = domain.getRange().getEndValue().intValue();
layer.setIntValue(AllPredefinedLayers.INPUT_DIM_NAME, inputDim);
}
}
@Override
public void checkInput(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
inferInputDim(layer);
errorIfInputSizeIsNotOne(inputTypes, layer);
// Only up to 3 dimensions are supported so the input needs to be at maximum 2-dimensional as the output has one
// more dimension than the output
errorIfInputWidthIsInvalid(inputTypes, layer, 1);
int inputDim = layer.getIntValue(AllPredefinedLayers.INPUT_DIM_NAME).get();
if (inputDim == 0) {
Log.error("0" + ErrorCodes.MISSING_ARGUMENT + " Missing argument. The argument 'input_dim' is in this case required. ", layer.getSourcePosition());
}
ASTElementType domain = layer.getInputTypes().get(0).getDomain();
if (!domain.isNaturalNumber()) {
Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_DOMAIN + " Invalid layer input domain: Input must be natural. ", layer.getSourcePosition());
}
else if (!domain.isPresentRange() || domain.getRange().hasNoUpperLimit()) {
Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_DOMAIN + " Invalid layer input domain: Input range must have an upper limit. ", layer.getSourcePosition());
}
else if (domain.getRange().getEndValue().intValue() > inputDim) {
Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_DOMAIN + " Invalid layer input domain: Input upper limit must be smaller than 'input_dim'. ", layer.getSourcePosition());
}
}
public static Embedding create(){
Embedding declaration = new Embedding();
List<ParameterSymbol> parameters = new ArrayList<>(Arrays.asList(
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.INPUT_DIM_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.defaultValue(0)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.OUTPUT_DIM_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.build()));
declaration.setParameters(parameters);
return declaration;
}
}
......@@ -50,14 +50,15 @@ public class OneHot extends PredefinedLayerDeclaration {
}
private static void inferSizeFromOutput(LayerSymbol layer){
int outputChannels = 0;
if (layer.getOutputElement().isPresent() && layer.getOutputElement().get().isOutput()) {
outputChannels = ((VariableSymbol) layer.getOutputElement().get()).getIoDeclaration().getType().getChannels();
}
// Only infer when not already done and next element is output
if (layer.getIntValue(AllPredefinedLayers.SIZE_NAME).get() == 0
&& layer.getOutputElement().isPresent()
&& layer.getOutputElement().get().isOutput()) {
int outputChannels = ((VariableSymbol) layer.getOutputElement().get()).getIoDeclaration().getType().getChannels();
layer.setIntValue(AllPredefinedLayers.SIZE_NAME, outputChannels);
}
}
@Override
public void checkInput(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment