Commit 1cfffbee authored by Julian Johannes Steinsberger-Dührßen's avatar Julian Johannes Steinsberger-Dührßen
Browse files

Added mask option to DotProductSelfAttention layer

parent c0945e7f
Pipeline #310122 passed with stage
in 8 minutes and 37 seconds
......@@ -108,6 +108,7 @@ public class AllPredefinedLayers {
public static final String DIM_KEYS_NAME="dimKeys";
public static final String DIM_VALUES_NAME="dimValues";
public static final String USE_PROJ_BIAS_NAME="useProjBias";
public static final String USE_MASK_NAME="useMask";
//shared parameters episodic replay layers
public static final String USE_REPLAY_NAME = "useReplay";
......
......@@ -32,12 +32,14 @@ public class DotProductSelfAttention extends PredefinedLayerDeclaration {
public void checkInput(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
if (inputTypes.size() < 3) {
Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_SHAPE + " To few inputs. " +
"DotProductSelfAttentnion layer expects 3 Inputs: querys, keys, values, but "
"DotProductSelfAttentnion layer expects at least 3 Inputs: querys, keys, values and " +
"optionally a mask if useMask is set to true, but "
+ inputTypes.size() + " were provided."
, layer.getSourcePosition());
} else if (inputTypes.size() > 3) {
} else if (inputTypes.size() > 4) {
Log.error("0" + ErrorCodes.INVALID_ELEMENT_INPUT_SHAPE + " To many inputs. " +
"DotProductSelfAttentnion layer expects 3 Inputs: querys, keys, values, but "
"DotProductSelfAttentnion layer expects 3 or 4 Inputs: querys, keys, values and " +
"optionally a mask if useMask is set to true, but "
+ inputTypes.size() + " were provided."
, layer.getSourcePosition());
}
......@@ -70,6 +72,11 @@ public class DotProductSelfAttention extends PredefinedLayerDeclaration {
.name(AllPredefinedLayers.USE_PROJ_BIAS_NAME)
.constraints(Constraints.BOOLEAN)
.defaultValue(true)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.USE_MASK_NAME)
.constraints(Constraints.BOOLEAN)
.defaultValue(false)
.build()));
declaration.setParameters(parameters);
return declaration;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment