Commit 235254ec authored by Evgeny Kusmenko's avatar Evgeny Kusmenko

Merge branch 'word_embedding' into 'master'

added reshape layer

See merge request !27
parents d4c18ae4 3f74a9ad
Pipeline #201607 passed with stages
in 20 minutes and 47 seconds
......@@ -39,6 +39,7 @@ public class AllPredefinedLayers {
public static final String LSTM_NAME = "LSTM";
public static final String GRU_NAME = "GRU";
public static final String EMBEDDING_NAME = "Embedding";
public static final String RESHAPE_NAME = "Reshape";
//predefined argument names
public static final String KERNEL_NAME = "kernel";
......@@ -64,6 +65,7 @@ public class AllPredefinedLayers {
public static final String FLATTEN_PARAMETER_NAME = "flatten";
public static final String BEAMSEARCH_MAX_LENGTH = "max_length";
public static final String BEAMSEARCH_WIDTH_NAME = "width";
public static final String SHAPE_NAME = "shape";
//possible String values
public static final String PADDING_VALID = "valid";
......@@ -98,6 +100,7 @@ public class AllPredefinedLayers {
LSTM.create(),
GRU.create(),
Embedding.create(),
Reshape.create(),
RNN.create());
}
}
/**
*
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
package de.monticore.lang.monticar.cnnarch.predefined;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
public class Reshape extends PredefinedLayerDeclaration {
public Reshape() {
super(AllPredefinedLayers.RESHAPE_NAME);
}
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
List<Integer> shape = layer.getIntTupleValue(AllPredefinedLayers.SHAPE_NAME).get();
Collections.reverse(shape);
return Collections.singletonList(
new ArchTypeSymbol.Builder()
.channels(shape.size() < 3 ? 0 : shape.get(2))
.height(shape.size() < 2 ? 0 : shape.get(1))
.width(shape.size() < 1 ? 0 : shape.get(0))
.elementType("0", "oo")
.build());
}
@Override
public void checkInput(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
}
public static Reshape create(){
Reshape reshape = new Reshape();
List<ParameterSymbol> parameters = new ArrayList<>(Arrays.asList(
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.SHAPE_NAME)
.constraints(Constraints.INTEGER_TUPLE)
.defaultValue(-2)
.build()));
reshape.setParameters(parameters);
return reshape;
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment