Working LargeLayer

parent 3862b755
......@@ -191,6 +191,33 @@ public enum Constraints {
+ AllPredefinedLayers.POOL_AVG;
}
},
ACTIVATION_TYPE {
    /**
     * Checks that the expression is a string naming one of the predefined
     * memory-layer activation functions (linear, relu, tanh, sigmoid,
     * softrelu, softsign).
     *
     * @param exp the architecture expression to validate
     * @return true iff the expression holds a recognized activation name
     */
    @Override
    public boolean isValid(ArchSimpleExpressionSymbol exp) {
        Optional<String> optString = exp.getStringValue();
        // Guard clause: non-string expressions can never satisfy this constraint.
        if (!optString.isPresent()) {
            return false;
        }
        // Unwrap once instead of calling optString.get() per comparison.
        String value = optString.get();
        return value.equals(AllPredefinedLayers.MEMORY_ACTIVATION_LINEAR)
                || value.equals(AllPredefinedLayers.MEMORY_ACTIVATION_RELU)
                || value.equals(AllPredefinedLayers.MEMORY_ACTIVATION_TANH)
                || value.equals(AllPredefinedLayers.MEMORY_ACTIVATION_SIGMOID)
                || value.equals(AllPredefinedLayers.MEMORY_ACTIVATION_SOFTRELU)
                || value.equals(AllPredefinedLayers.MEMORY_ACTIVATION_SOFTSIGN);
    }

    /**
     * Human-readable list of the allowed activation names, used when
     * reporting a constraint violation.
     */
    @Override
    protected String msgString() {
        return AllPredefinedLayers.MEMORY_ACTIVATION_LINEAR + " or "
                + AllPredefinedLayers.MEMORY_ACTIVATION_RELU + " or "
                + AllPredefinedLayers.MEMORY_ACTIVATION_TANH + " or "
                + AllPredefinedLayers.MEMORY_ACTIVATION_SIGMOID + " or "
                + AllPredefinedLayers.MEMORY_ACTIVATION_SOFTRELU + " or "
                + AllPredefinedLayers.MEMORY_ACTIVATION_SOFTSIGN;
    }
},
NULLABLE_AXIS {
@Override
public boolean isValid(ArchSimpleExpressionSymbol exp) {
......
......@@ -53,7 +53,7 @@ public class AllPredefinedLayers {
public static final String SWAPAXES_NAME = "SwapAxes";
public static final String BROADCAST_ADD_NAME = "BroadcastAdd";
public static final String RESHAPE_NAME = "Reshape";
public static final String LARGE_MEMORY_NAME = "LargeMemory";
public static final String MEMORY_NAME = "Memory";
//predefined argument names
public static final String KERNEL_NAME = "kernel";
......@@ -89,9 +89,11 @@ public class AllPredefinedLayers {
public static final String SHAPE_NAME = "shape";
public static final String RNN_DROPOUT_NAME = "dropout";
//parameters for memory layers
public static final String NUM_SUB_KEYS_NAME = "numSubKeys";
public static final String QUERRY_SIZE_NAME = "querrySize";
public static final String SUB_KEY_SIZE_NAME = "subKeySize";
public static final String QUERY_SIZE_NAME = "querySize";
public static final String ACT_QUERY_NAME = "actQuery";
public static final String K_NAME = "k";
public static final String NUM_HEADS_NAME = "numHeads";
//possible String values
public static final String PADDING_VALID = "valid";
......@@ -99,7 +101,14 @@ public class AllPredefinedLayers {
public static final String PADDING_NO_LOSS = "no_loss";
public static final String POOL_MAX = "max";
public static final String POOL_AVG = "avg";
//possible activation values for the query network in the memory layer
public static final String MEMORY_ACTIVATION_LINEAR = "linear";
public static final String MEMORY_ACTIVATION_RELU = "relu";
public static final String MEMORY_ACTIVATION_TANH = "tanh";
public static final String MEMORY_ACTIVATION_SIGMOID = "sigmoid";
public static final String MEMORY_ACTIVATION_SOFTRELU = "softrelu";
public static final String MEMORY_ACTIVATION_SOFTSIGN = "softsign";
//list with all predefined layers
public static List<LayerDeclarationSymbol> createList(){
......@@ -137,7 +146,7 @@ public class AllPredefinedLayers {
SwapAxes.create(),
BroadcastAdd.create(),
Reshape.create(),
LargeMemory.create());
Memory.create());
}
public static List<UnrollDeclarationSymbol> createUnrollList(){
......
......@@ -15,20 +15,20 @@ import java.util.Arrays;
import java.util.Collections;
import java.util.List;
public class LargeMemory extends PredefinedLayerDeclaration {
public class Memory extends PredefinedLayerDeclaration {
private LargeMemory() {
super(AllPredefinedLayers.LARGE_MEMORY_NAME);
private Memory() {
super(AllPredefinedLayers.MEMORY_NAME);
}
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
int querrySize = layer.getIntValue(AllPredefinedLayers.QUERRY_SIZE_NAME).get();
int querySize = layer.getIntValue(AllPredefinedLayers.QUERY_SIZE_NAME).get();
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(1)
.height(querrySize)
.height(querySize)
.width(1)
.elementType("-oo", "oo")
.build());
......@@ -39,21 +39,30 @@ public class LargeMemory extends PredefinedLayerDeclaration {
errorIfInputSizeIsNotOne(inputTypes, layer);
}
public static LargeMemory create(){
LargeMemory declaration = new LargeMemory();
public static Memory create(){
Memory declaration = new Memory();
List<ParameterSymbol> parameters = new ArrayList<>(Arrays.asList(
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.NUM_SUB_KEYS_NAME)
.name(AllPredefinedLayers.SUB_KEY_SIZE_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.QUERRY_SIZE_NAME)
.name(AllPredefinedLayers.QUERY_SIZE_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.defaultValue(512)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.ACT_QUERY_NAME)
.constraints(Constraints.ACTIVATION_TYPE)
.defaultValue("linear")
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.K_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.NUM_HEADS_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.build()));
declaration.setParameters(parameters);
return declaration;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment