Commit 914c7491 authored by Julian Johannes Steinsberger-Dührßen's avatar Julian Johannes Steinsberger-Dührßen
Browse files

added LargeMemory layer

parent 0bce8533
File mode changed from 100755 to 100644
......@@ -53,6 +53,7 @@ public class AllPredefinedLayers {
public static final String SWAPAXES_NAME = "SwapAxes";
public static final String BROADCAST_ADD_NAME = "BroadcastAdd";
public static final String RESHAPE_NAME = "Reshape";
public static final String LARGE_MEMORY_NAME = "LargeMemory";
//predefined argument names
public static final String KERNEL_NAME = "kernel";
......@@ -87,6 +88,10 @@ public class AllPredefinedLayers {
public static final String BEAMSEARCH_WIDTH_NAME = "width";
public static final String SHAPE_NAME = "shape";
public static final String RNN_DROPOUT_NAME = "dropout";
//parameters for memory layers
public static final String NUM_SUB_KEYS_NAME = "numSubKeys";
public static final String QUERRY_SIZE_NAME = "querrySize";
public static final String K_NAME = "k";
//possible String values
public static final String PADDING_VALID = "valid";
......@@ -131,7 +136,8 @@ public class AllPredefinedLayers {
BroadcastMultiply.create(),
SwapAxes.create(),
BroadcastAdd.create(),
Reshape.create());
Reshape.create(),
LargeMemory.create());
}
public static List<UnrollDeclarationSymbol> createUnrollList(){
......
/**
*
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnnarch.predefined;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
public class LargeMemory extends PredefinedLayerDeclaration {
private LargeMemory() {
super(AllPredefinedLayers.LARGE_MEMORY_NAME);
}
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
int querrySize = layer.getIntValue(AllPredefinedLayers.QUERRY_SIZE_NAME).get();
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(1)
.height(querrySize)
.width(1)
.elementType("-oo", "oo")
.build());
}
@Override
public void checkInput(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
errorIfInputSizeIsNotOne(inputTypes, layer);
}
public static LargeMemory create(){
LargeMemory declaration = new LargeMemory();
List<ParameterSymbol> parameters = new ArrayList<>(Arrays.asList(
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.NUM_SUB_KEYS_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.QUERRY_SIZE_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.defaultValue(512)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.K_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.build()));
declaration.setParameters(parameters);
return declaration;
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment