Commit 1e584cc7 authored by Evgeny Kusmenko's avatar Evgeny Kusmenko
Browse files

Merge branch 'bagherinejad' into 'master'

Add position encoding layer

See merge request !57
parents bafacfc3 9858fcd5
Pipeline #714434 passed with stage
in 10 minutes and 11 seconds
......@@ -62,6 +62,7 @@ public class AllPredefinedLayers {
public static final String CONVOLUTION3D_NAME = "Convolution3D";
public static final String UP_CONVOLUTION3D_NAME = "UpConvolution3D";
public static final String VECTOR_QUANTIZE_NAME = "VectorQuantize";
public static final String POSITION_ENCODING_NAME = "PositionEncoding";
public static final String AdaNet_Name = "AdaNet"; //AdaNet layer
......@@ -248,7 +249,8 @@ public class AllPredefinedLayers {
GraphSumPool.create(),
AdaNet.create(),
Reparameterize.create(),
VectorQuantize.create());
VectorQuantize.create(),
PositionEncoding.create());
}
......
/**
*
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnnarch.predefined;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
public class PositionEncoding extends PredefinedLayerDeclaration {
private PositionEncoding() {
super(AllPredefinedLayers.POSITION_ENCODING_NAME);
}
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
int outputDim = layer.getIntValue(AllPredefinedLayers.OUTPUT_DIM_NAME).get();
int maxLength = layer.getIntValue(AllPredefinedLayers.MAX_LENGTH_NAME).get();
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(maxLength)
.height(outputDim)
.elementType("-oo", "oo")
.build());
}
@Override
public void checkInput(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
;
}
public static PositionEncoding create(){
PositionEncoding declaration = new PositionEncoding();
List<ParameterSymbol> parameters = new ArrayList<>(Arrays.asList(
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.MAX_LENGTH_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.OUTPUT_DIM_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.build()));
declaration.setParameters(parameters);
return declaration;
}
}
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment