
Commit 0c5691b4 authored by Julian Johannes Steinsberger-Dührßen

added LargeMemory layer

parent 9e697195
@@ -89,6 +89,7 @@ public class AllPredefinedLayers {
public static final String BEAMSEARCH_WIDTH_NAME = "width";
public static final String SHAPE_NAME = "shape";
public static final String RNN_DROPOUT_NAME = "dropout";
//parameters for memory layer
public static final String SUB_KEY_SIZE_NAME = "subKeySize";
public static final String QUERY_SIZE_NAME = "querySize";
@@ -96,6 +97,7 @@ public class AllPredefinedLayers {
public static final String K_NAME = "k";
public static final String NUM_HEADS_NAME = "numHeads";
public static final String VALUE_SHAPE_NAME = "valueShape";
//parameters for replay memory layer
public static final String REPLAY_INTERVAL_NAME = "replayInterval";
public static final String REPLAY_BATCH_SIZE_NAME = "replayBatchSize";
@@ -103,7 +105,7 @@ public class AllPredefinedLayers {
public static final String REPLAY_GRADIENT_STEPS_NAME = "replayGradientSteps";
public static final String STORE_PROB_NAME = "storeProb";
public static final String MAX_STORED_SAMPLES_NAME = "maxStoredSamples";
//possible String values
public static final String PADDING_VALID = "valid";
public static final String PADDING_SAME = "same";
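
These constants only name the layer parameters; a predefined layer declaration typically looks the corresponding values up on the resolved LayerSymbol. A minimal, hypothetical sketch of that pattern (the getIntValue(...).get() accessor is the same one used in LargeMemory below; the example class itself is illustrative and not part of this commit):

// Hypothetical illustration, not part of this commit.
package de.monticore.lang.monticar.cnnarch.predefined;

import de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol;

class ReplayParameterExample {
    // Reads the replay interval by its constant name; getIntValue(...).get()
    // is the same accessor pattern used in LargeMemory.computeOutputTypes below.
    static int readReplayInterval(LayerSymbol layer) {
        return layer.getIntValue(AllPredefinedLayers.REPLAY_INTERVAL_NAME).get();
    }
}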
de/monticore/lang/monticar/cnnarch/predefined/LargeMemory.java
/**
*
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
package de.monticore.lang.monticar.cnnarch.predefined;

import de.monticore.lang.monticar.cnnarch._symboltable.*;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
public class LargeMemory extends PredefinedLayerDeclaration {

    private LargeMemory() {
        super(AllPredefinedLayers.LARGE_MEMORY_NAME);
    }

    @Override
    public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
        // The memory layer outputs a single-channel tensor of height querySize and width 1.
        int querySize = layer.getIntValue(AllPredefinedLayers.QUERY_SIZE_NAME).get();
        return Collections.singletonList(new ArchTypeSymbol.Builder()
                .channels(1)
                .height(querySize)
                .width(1)
                .elementType("-oo", "oo")
                .build());
    }

    @Override
    public void checkInput(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
        // The layer accepts exactly one input stream.
        errorIfInputSizeIsNotOne(inputTypes, layer);
    }

    public static LargeMemory create() {
        LargeMemory declaration = new LargeMemory();
        // Declared parameters: the sub-key count and k have no default; the query size defaults to 512.
        List<ParameterSymbol> parameters = new ArrayList<>(Arrays.asList(
                new ParameterSymbol.Builder()
                        .name(AllPredefinedLayers.NUM_SUB_KEYS_NAME)
                        .constraints(Constraints.INTEGER, Constraints.POSITIVE)
                        .build(),
                new ParameterSymbol.Builder()
                        .name(AllPredefinedLayers.QUERY_SIZE_NAME)
                        .constraints(Constraints.INTEGER, Constraints.POSITIVE)
                        .defaultValue(512)
                        .build(),
                new ParameterSymbol.Builder()
                        .name(AllPredefinedLayers.K_NAME)
                        .constraints(Constraints.INTEGER, Constraints.POSITIVE)
                        .build()));
        declaration.setParameters(parameters);
        return declaration;
    }
}
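
For orientation, a minimal sketch of how the new declaration could be instantiated and inspected. LargeMemory.create() comes from the class above; the example class itself and the getParameters()/getName() accessors are assumptions about the surrounding symbol-table API, not part of this commit:

// Hypothetical illustration, not part of this commit.
package de.monticore.lang.monticar.cnnarch.predefined;

import de.monticore.lang.monticar.cnnarch._symboltable.ParameterSymbol;

class LargeMemoryExample {
    public static void main(String[] args) {
        // Build the predefined declaration via the factory added above.
        LargeMemory declaration = LargeMemory.create();
        // Print the declared parameter names (getParameters/getName are assumed
        // accessors on the symbol classes and are not shown in this diff).
        for (ParameterSymbol parameter : declaration.getParameters()) {
            System.out.println(parameter.getName());
        }
    }
}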