Commit 9dd5761f authored by Julian Johannes Steinsberger-Dührßen's avatar Julian Johannes Steinsberger-Dührßen
Browse files

Added generation of replay subnets for ReplayMemory.

parent de6687e2
......@@ -157,6 +157,7 @@ public abstract class ArchitectureElementSymbol extends ResolvableSymbol {
else {
return Optional.empty();
}
}
/**
......
......@@ -214,5 +214,26 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
return copy;
}
public void processForReplayMemory(){
for(NetworkInstructionSymbol networkInstruction : networkInstructions){
List<ArchitectureElementSymbol> elements = networkInstruction.getBody().getElements();
List<ArchitectureElementSymbol> elementsNew = new ArrayList<>();
List<List<ArchitectureElementSymbol>> replaySubNetworks = new ArrayList<>(new ArrayList<>());
List<ArchitectureElementSymbol> currentReplaySubNetworkElements = new ArrayList<>();
for (ArchitectureElementSymbol element : elements){
if (element.getName().equals("ReplayMemory")) {
if (!currentReplaySubNetworkElements.isEmpty()){
replaySubNetworks.add(currentReplaySubNetworkElements);
}
currentReplaySubNetworkElements = new ArrayList<>();
}
currentReplaySubNetworkElements.add(element);
}
if (!currentReplaySubNetworkElements.isEmpty() && !replaySubNetworks.isEmpty()){
replaySubNetworks.add(currentReplaySubNetworkElements);
}
networkInstruction.getBody().setReplaySubNetworks(replaySubNetworks);
}
}
}
......@@ -17,17 +17,18 @@ public abstract class CompositeElementSymbol extends ArchitectureElementSymbol {
protected List<ArchitectureElementSymbol> elements = new ArrayList<>();
public CompositeElementSymbol() {
super("");
setResolvedThis(this);
}
abstract protected void setElements(List<ArchitectureElementSymbol> elements);
public List<ArchitectureElementSymbol> getElements() {
return elements;
}
abstract protected void setElements(List<ArchitectureElementSymbol> elements);
@Override
public boolean isAtomic() {
return getElements().isEmpty();
......
......@@ -68,6 +68,16 @@ public enum Constraints {
return "a tuple of integers";
}
},
INTEGER_OR_INTEGER_TUPLE {
@Override
public boolean isValid(ArchSimpleExpressionSymbol exp) {
return exp.isInt().get() || exp.isIntTuple().get();
}
@Override
public String msgString() {
return "an integer or tuple of integers";
}
},
POSITIVE {
@Override
public boolean isValid(ArchSimpleExpressionSymbol exp) {
......@@ -90,6 +100,28 @@ public enum Constraints {
return "a positive number";
}
},
POSITIVE_OR_MINUS_ONE {
@Override
public boolean isValid(ArchSimpleExpressionSymbol exp) {
if (exp.getDoubleValue().isPresent()){
return exp.getDoubleValue().get() > 0 || exp.getDoubleValue().get() == -1;
}
else if (exp.getDoubleTupleValues().isPresent()){
boolean isPositive = true;
for (double value : exp.getDoubleTupleValues().get()){
if (value < -1 || value == 0){
isPositive = false;
}
}
return isPositive;
}
return false;
}
@Override
public String msgString() {
return "a positive number";
}
},
NON_NEGATIVE {
@Override
public boolean isValid(ArchSimpleExpressionSymbol exp) {
......@@ -207,7 +239,6 @@ public enum Constraints {
}
return false;
}
@Override
protected String msgString() {
return AllPredefinedLayers.MEMORY_ACTIVATION_LINEAR + " or "
......
......@@ -9,6 +9,8 @@ package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.symboltable.SymbolKind;
import java.util.*;
public abstract class NetworkInstructionSymbol extends ResolvableSymbol {
private SerialCompositeElementSymbol body;
......@@ -16,7 +18,7 @@ public abstract class NetworkInstructionSymbol extends ResolvableSymbol {
protected NetworkInstructionSymbol(String name, SymbolKind kind) {
super(name, kind);
}
public SerialCompositeElementSymbol getBody() {
return body;
}
......
......@@ -12,6 +12,8 @@ import java.util.*;
public class SerialCompositeElementSymbol extends CompositeElementSymbol {
protected List<List<ArchitectureElementSymbol>> replaySubNetworks = new ArrayList<>(new ArrayList<>());
protected void setElements(List<ArchitectureElementSymbol> elements) {
ArchitectureElementSymbol previous = null;
for (ArchitectureElementSymbol current : elements){
......@@ -32,6 +34,32 @@ public class SerialCompositeElementSymbol extends CompositeElementSymbol {
this.elements = elements;
}
protected void setReplaySubNetworks(List<List<ArchitectureElementSymbol>> replaySubNetworks){
for (List<ArchitectureElementSymbol> subElements: replaySubNetworks){
ArchitectureElementSymbol previous = null;
for (ArchitectureElementSymbol current : subElements){
if (previous != null){
current.setInputElement(previous);
previous.setOutputElement(current);
}
else {
if (getInputElement().isPresent()){
current.setInputElement(getInputElement().get());
}
if (getOutputElement().isPresent()){
current.setOutputElement(getOutputElement().get());
}
}
previous = current;
}
}
this.replaySubNetworks = replaySubNetworks;
}
public List<List<ArchitectureElementSymbol>> getReplaySubNetworks() {
return replaySubNetworks;
}
@Override
public void setInputElement(ArchitectureElementSymbol inputElement) {
super.setInputElement(inputElement);
......
......@@ -54,6 +54,7 @@ public class AllPredefinedLayers {
public static final String BROADCAST_ADD_NAME = "BroadcastAdd";
public static final String RESHAPE_NAME = "Reshape";
public static final String MEMORY_NAME = "Memory";
public static final String REPLAY_MEMORY_NAME = "ReplayMemory";
//predefined argument names
public static final String KERNEL_NAME = "kernel";
......@@ -88,13 +89,21 @@ public class AllPredefinedLayers {
public static final String BEAMSEARCH_WIDTH_NAME = "width";
public static final String SHAPE_NAME = "shape";
public static final String RNN_DROPOUT_NAME = "dropout";
//parameters for memory layers
//parameters for memory layer
public static final String SUB_KEY_SIZE_NAME = "subKeySize";
public static final String QUERY_SIZE_NAME = "querySize";
public static final String ACT_QUERY_NAME = "actQuery";
public static final String QUERY_ACT_NAME = "queryAct";
public static final String K_NAME = "k";
public static final String NUM_HEADS_NAME = "numHeads";
public static final String VALUE_SHAPE_NAME = "valueShape";
//parameters for replay memory layer
public static final String REPLAY_INTERVAL_NAME = "replayInterval";
public static final String REPLAY_BATCH_SIZE_NAME = "replayBatchSize";
public static final String REPLAY_STEPS_NAME = "replaySteps";
public static final String REPLAY_GRADIENT_STEPS_NAME = "replayGradientSteps";
public static final String STORE_PROB_NAME = "storeProb";
public static final String MAX_STORED_SAMPLES_NAME = "maxStoredSamples";
//possible String values
public static final String PADDING_VALID = "valid";
public static final String PADDING_SAME = "same";
......@@ -146,7 +155,8 @@ public class AllPredefinedLayers {
SwapAxes.create(),
BroadcastAdd.create(),
Reshape.create(),
Memory.create());
Memory.create(),
ReplayMemory.create());
}
public static List<UnrollDeclarationSymbol> createUnrollList(){
......
......@@ -14,6 +14,7 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
public class Memory extends PredefinedLayerDeclaration {
......@@ -24,14 +25,31 @@ public class Memory extends PredefinedLayerDeclaration {
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
int querySize = layer.getIntValue(AllPredefinedLayers.QUERY_SIZE_NAME).get();
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(1)
.height(querySize)
.width(1)
.elementType("-oo", "oo")
.build());
Optional<Integer> optValue = layer.getIntValue(AllPredefinedLayers.QUERY_SIZE_NAME);
if (optValue.isPresent()){
int querySize = optValue.get();
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(1)
.height(querySize)
.width(1)
.elementType("-oo", "oo")
.build());
}else{
Optional<List<Integer>> optTupleValue = layer.getIntTupleValue(AllPredefinedLayers.QUERY_SIZE_NAME);
List<Integer> list = new ArrayList<>();
for (Object value : optTupleValue.get()) {
list.add((Integer) value);
}
int listLen = list.size();
int lastEntry = list.get(listLen-1);
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(1)
.height(lastEntry)
.width(1)
.elementType("-oo", "oo")
.build());
}
}
@Override
......@@ -48,11 +66,11 @@ public class Memory extends PredefinedLayerDeclaration {
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.QUERY_SIZE_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.constraints(Constraints.INTEGER_OR_INTEGER_TUPLE, Constraints.POSITIVE)
.defaultValue(512)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.ACT_QUERY_NAME)
.name(AllPredefinedLayers.QUERY_ACT_NAME)
.constraints(Constraints.ACTIVATION_TYPE)
.defaultValue("linear")
.build(),
......@@ -63,6 +81,12 @@ public class Memory extends PredefinedLayerDeclaration {
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.NUM_HEADS_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.defaultValue(1)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.VALUE_SHAPE_NAME)
.constraints(Constraints.INTEGER_OR_INTEGER_TUPLE, Constraints.POSITIVE_OR_MINUS_ONE)
.defaultValue(-1)
.build()));
declaration.setParameters(parameters);
return declaration;
......
/**
*
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnnarch.predefined;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
public class ReplayMemory extends PredefinedLayerDeclaration {
private ReplayMemory() {
super(AllPredefinedLayers.REPLAY_MEMORY_NAME);
}
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(1)
.height(1)
.width(1)
.elementType("-oo", "oo")
.build());
}
@Override
public void checkInput(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
errorIfInputSizeIsNotOne(inputTypes, layer);
}
public static ReplayMemory create(){
ReplayMemory declaration = new ReplayMemory();
List<ParameterSymbol> parameters = new ArrayList<>(Arrays.asList(
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.REPLAY_INTERVAL_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.REPLAY_BATCH_SIZE_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE_OR_MINUS_ONE)
.defaultValue(-1)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.REPLAY_STEPS_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.defaultValue("linear")
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.REPLAY_GRADIENT_STEPS_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.defaultValue(1)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.STORE_PROB_NAME)
.constraints(Constraints.NUMBER, Constraints.BETWEEN_ZERO_AND_ONE)
.defaultValue(1)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.MAX_STORED_SAMPLES_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE_OR_MINUS_ONE)
.defaultValue(-1)
.build()));
declaration.setParameters(parameters);
return declaration;
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment