Added DotProductSelfAttention layer; removed replay support from the normal Memory layer

parent 766d34c4
AllPredefinedLayers.java
@@ -54,12 +54,12 @@ public class AllPredefinedLayers {
     public static final String SWAPAXES_NAME = "SwapAxes";
     public static final String BROADCAST_ADD_NAME = "BroadcastAdd";
     public static final String RESHAPE_NAME = "Reshape";
+    public static final String DOT_PRODUCT_SELF_ATTENTION_NAME = "DotProductSelfAttention";

     //replay layers
     public static final String MEMORY_NAME = "Memory";
     public static final String REPLAY_MEMORY_NAME = "ReplayMemory";
-    public static final List<String> REPLAY_LAYER_NAMES = new ArrayList<String>(Arrays.asList(MEMORY_NAME,
-            REPLAY_MEMORY_NAME));
+    public static final List<String> REPLAY_LAYER_NAMES = new ArrayList<String>(Arrays.asList(REPLAY_MEMORY_NAME));

     //predefined argument names
@@ -96,13 +96,19 @@ public class AllPredefinedLayers {
     public static final String SHAPE_NAME = "shape";
     public static final String RNN_DROPOUT_NAME = "dropout";

+    //parameters for DotProductSelfAttention
+    public static final String SCALE_FACTOR_NAME = "scaleFactor";
+    public static final String DIM_KEYS_NAME = "dimKeys";
+    public static final String DIM_VALUES_NAME = "dimValues";
+    public static final String USE_PROJ_BIAS_NAME = "useProjBias";
+
     //shared parameters replay layers
     public static final String USE_REPLAY_NAME = "useReplay";
     public static final String REPLAY_INTERVAL_NAME = "replayInterval";
     public static final String REPLAY_BATCH_SIZE_NAME = "replayBatchSize";
     public static final String REPLAY_STEPS_NAME = "replaySteps";
     public static final String REPLAY_GRADIENT_STEPS_NAME = "replayGradientSteps";
     public static final String REPLAY_MEMORY_STORE_PROB_NAME = "replayMemoryStoreProb";
     public static final String REPLAY_MEMORY_STORE_DIST_MEASURE_NAME = "replayMemoryStoreDistMeasure";
     public static final String USE_LOCAL_ADAPTION_NAME = "useLocalAdaption";
     public static final String LOCAL_ADAPTION_K_NAME = "localAdaptionK";
@@ -115,11 +121,12 @@ public class AllPredefinedLayers {
     public static final String QUERY_ACT_NAME = "queryAct";
     public static final String K_NAME = "k";
     public static final String NUM_HEADS_NAME = "numHeads";
+    public static final String STORE_DIST_MEASURE_NAME = "storeDistMeasure";
     public static final String VALUE_SHAPE_NAME = "valueShape";

     //parameters for replay memory layer
     public static final String MAX_STORED_SAMPLES_NAME = "maxStoredSamples";
     public static final String REPLAY_MEMORY_STORE_PROB_NAME = "replayMemoryStoreProb";
     public static final String QUERY_NET_DIR_NAME = "queryNetDir";
     public static final String QUERY_NET_PREFIX_NAME = "queryNetPrefix";
DotProductSelfAttention.java (new file)
/**
 *
 * (c) https://github.com/MontiCore/monticore
 *
 * The license generally applicable for this project
 * can be found under https://github.com/MontiCore/monticore.
 */
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnnarch.predefined;

import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
public class DotProductSelfAttention extends PredefinedLayerDeclaration {

    private DotProductSelfAttention() {
        super(AllPredefinedLayers.DOT_PRODUCT_SELF_ATTENTION_NAME);
    }

    @Override
    public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
        return layer.getInputTypes();
    }

    @Override
    public void checkInput(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
        // The layer expects exactly three inputs: queries, keys and values.
        if (inputTypes.size() < 3) {
            Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_SHAPE + " Too few inputs. " +
                    "DotProductSelfAttention layer expects 3 inputs: queries, keys, values, but " +
                    inputTypes.size() + " were provided.", layer.getSourcePosition());
        } else if (inputTypes.size() > 3) {
            Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_SHAPE + " Too many inputs. " +
                    "DotProductSelfAttention layer expects 3 inputs: queries, keys, values, but " +
                    inputTypes.size() + " were provided.", layer.getSourcePosition());
        }
    }
    public static DotProductSelfAttention create() {
        DotProductSelfAttention declaration = new DotProductSelfAttention();
        List<ParameterSymbol> parameters = new ArrayList<>(Arrays.asList(
                new ParameterSymbol.Builder()
                        .name(AllPredefinedLayers.SCALE_FACTOR_NAME)
                        .constraints(Constraints.INTEGER, Constraints.POSITIVE)
                        // NOTE: -1 acts as an "unset" default; presumably the generator
                        // derives the actual value at code-generation time.
                        .defaultValue(-1)
                        .build(),
                new ParameterSymbol.Builder()
                        .name(AllPredefinedLayers.NUM_HEADS_NAME)
                        .constraints(Constraints.INTEGER, Constraints.POSITIVE)
                        .defaultValue(1)
                        .build(),
                new ParameterSymbol.Builder()
                        .name(AllPredefinedLayers.DIM_KEYS_NAME)
                        .constraints(Constraints.INTEGER, Constraints.POSITIVE)
                        .defaultValue(-1)
                        .build(),
                new ParameterSymbol.Builder()
                        .name(AllPredefinedLayers.DIM_VALUES_NAME)
                        .constraints(Constraints.INTEGER, Constraints.POSITIVE)
                        .defaultValue(-1)
                        .build(),
                new ParameterSymbol.Builder()
                        .name(AllPredefinedLayers.USE_PROJ_BIAS_NAME)
                        .constraints(Constraints.BOOLEAN)
                        .defaultValue(true)
                        .build()));
        declaration.setParameters(parameters);
        return declaration;
    }
}
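
For orientation: the layer name refers to standard scaled dot-product attention (Vaswani et al., 2017). Assuming the generated backend follows that formulation, dimKeys and dimValues correspond to d_k and d_v, numHeads to the number of attention heads, and the default scaleFactor of -1 plausibly means "use 1/sqrt(d_k)"; none of this is confirmed by the diff itself:

    \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V

Here Q, K, V are the three inputs (queries, keys, values) enforced by checkInput above.
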
Memory.java
@@ -61,55 +61,7 @@ public class Memory extends PredefinedLayerDeclaration {
         Memory declaration = new Memory();
         List<ParameterSymbol> parameters = new ArrayList<>(Arrays.asList(
                 new ParameterSymbol.Builder()
-                        .name(AllPredefinedLayers.USE_REPLAY_NAME)
-                        .constraints(Constraints.BOOLEAN)
-                        .defaultValue(false)
-                        .build(),
-                new ParameterSymbol.Builder()
-                        .name(AllPredefinedLayers.REPLAY_INTERVAL_NAME)
-                        .constraints(Constraints.INTEGER, Constraints.POSITIVE)
-                        .build(),
-                new ParameterSymbol.Builder()
-                        .name(AllPredefinedLayers.REPLAY_BATCH_SIZE_NAME)
-                        .constraints(Constraints.INTEGER, Constraints.POSITIVE_OR_MINUS_ONE)
-                        .defaultValue(-1)
-                        .build(),
-                new ParameterSymbol.Builder()
-                        .name(AllPredefinedLayers.REPLAY_STEPS_NAME)
-                        .constraints(Constraints.INTEGER, Constraints.POSITIVE)
-                        .build(),
-                new ParameterSymbol.Builder()
-                        .name(AllPredefinedLayers.REPLAY_GRADIENT_STEPS_NAME)
-                        .constraints(Constraints.INTEGER, Constraints.POSITIVE)
-                        .defaultValue(1)
-                        .build(),
-                new ParameterSymbol.Builder()
-                        .name(AllPredefinedLayers.REPLAY_MEMORY_STORE_PROB_NAME)
-                        .constraints(Constraints.NUMBER, Constraints.BETWEEN_ZERO_AND_ONE)
-                        .defaultValue(1)
-                        .build(),
-                new ParameterSymbol.Builder()
-                        .name(AllPredefinedLayers.USE_LOCAL_ADAPTION_NAME)
-                        .constraints(Constraints.BOOLEAN, Constraints.POSITIVE)
-                        .defaultValue(false)
-                        .build(),
-                new ParameterSymbol.Builder()
-                        .name(AllPredefinedLayers.LOCAL_ADAPTION_K_NAME)
-                        .constraints(Constraints.INTEGER, Constraints.POSITIVE)
-                        .defaultValue(1)
-                        .build(),
-                new ParameterSymbol.Builder()
-                        .name(AllPredefinedLayers.LOCAL_ADAPTION_GRADIENT_STEPS_NAME)
-                        .constraints(Constraints.INTEGER, Constraints.POSITIVE)
-                        .defaultValue(1)
-                        .build(),
-                new ParameterSymbol.Builder()
-                        .name(AllPredefinedLayers.REPLAY_MEMORY_STORE_DIST_MEASURE_NAME)
-                        .constraints(Constraints.DIST_MEASURE_TYPE)
-                        .defaultValue(AllPredefinedLayers.INNER_PROD)
-                        .build(),
-                new ParameterSymbol.Builder()
-                        .name(AllPredefinedLayers.LOCAL_ADAPTION_MEMORY_STORE_DIST_MEASURE_NAME)
+                        .name(AllPredefinedLayers.STORE_DIST_MEASURE_NAME)
                         .constraints(Constraints.DIST_MEASURE_TYPE)
                         .defaultValue(AllPredefinedLayers.INNER_PROD)
                         .build(),
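
As a quick sanity check of the new factory method, here is a minimal, hypothetical smoke test. It assumes a getParameters() accessor mirroring the setParameters(...) call shown in the diff, and the class name and placement are illustrative only:

package de.monticore.lang.monticar.cnnarch.predefined;

import de.monticore.lang.monticar.cnnarch._symboltable.ParameterSymbol;

public class DotProductSelfAttentionSmokeTest {
    public static void main(String[] args) {
        // Build the predefined declaration exactly as the generator would.
        DotProductSelfAttention declaration = DotProductSelfAttention.create();

        // Expected output: scaleFactor, numHeads, dimKeys, dimValues, useProjBias.
        // getParameters() is assumed here; only setParameters(...) appears in this diff.
        for (ParameterSymbol parameter : declaration.getParameters()) {
            System.out.println(parameter.getName());
        }
    }
}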