Commit 07b62c4a authored by Julian Johannes Steinsberger-Dührßen's avatar Julian Johannes Steinsberger-Dührßen
Browse files

Bug Fixes

parent c7fbe677
Pipeline #286838 passed with stage
in 9 minutes and 24 seconds
......@@ -122,7 +122,7 @@ public class AllPredefinedLayers {
public static final String K_NAME = "k";
public static final String NUM_HEADS_NAME = "numHeads";
public static final String STORE_DIST_MEASURE_NAME = "storeDistMeasure";
public static final String VALUE_SHAPE_NAME = "valueShape";
public static final String VALUES_DIM_NAME = "valuesDim";
//parameters for replay memory layer
public static final String MAX_STORED_SAMPLES_NAME = "maxStoredSamples";
......@@ -184,6 +184,7 @@ public class AllPredefinedLayers {
SwapAxes.create(),
BroadcastAdd.create(),
Reshape.create(),
DotProductSelfAttention.create(),
Memory.create(),
ReplayMemory.create());
}
......
......@@ -25,7 +25,7 @@ public class DotProductSelfAttention extends PredefinedLayerDeclaration {
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
return layer.getInputTypes();
return Arrays.asList(layer.getInputTypes().get(2));
}
@Override
......
......@@ -89,8 +89,8 @@ public class Memory extends PredefinedLayerDeclaration {
.defaultValue(1)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.VALUE_SHAPE_NAME)
.constraints(Constraints.INTEGER_OR_INTEGER_TUPLE, Constraints.POSITIVE_OR_MINUS_ONE)
.name(AllPredefinedLayers.VALUES_DIM_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE_OR_MINUS_ONE)
.defaultValue(-1)
.build()));
declaration.setParameters(parameters);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment