Change Memory Layer names; Added Load Network Layer; Multiple input support for episodic Memory

parent 04607854
......@@ -60,8 +60,8 @@ public class CNNArchCocos {
.addCoCo(new CheckLayerVariableDeclarationLayerType())
.addCoCo(new CheckLayerVariableDeclarationIsUsed())
.addCoCo(new CheckConstants())
.addCoCo(new CheckMemoryLayer())
.addCoCo(new CheckReplayMemoryLayer())
.addCoCo(new CheckLargeMemoryLayer())
.addCoCo(new CheckEpisodicMemoryLayer())
.addCoCo(new CheckUnrollInputsOutputsTooMany());
}
......
......@@ -21,7 +21,7 @@ import java.util.Optional;
import java.util.List;
import java.io.File;
public class CheckReplayMemoryLayer extends CNNArchSymbolCoCo {
public class CheckEpisodicMemoryLayer extends CNNArchSymbolCoCo {
@Override
public void check(StreamInstructionSymbol stream) {
......@@ -29,22 +29,22 @@ public class CheckReplayMemoryLayer extends CNNArchSymbolCoCo {
for (ArchitectureElementSymbol element : elements) {
if (element instanceof ParallelCompositeElementSymbol) {
checkForReplayMemory((ParallelCompositeElementSymbol) element);
} else if (element.getName().equals("ReplayMemory")) {
checkForEpisodicMemory((ParallelCompositeElementSymbol) element);
} else if (element.getName().equals("EpisodicMemory")) {
checkParameters((LayerSymbol) element);
}
}
}
public void checkForReplayMemory(ParallelCompositeElementSymbol parallelElement) {
public void checkForEpisodicMemory(ParallelCompositeElementSymbol parallelElement) {
for (ArchitectureElementSymbol subStream : parallelElement.getElements()) {
if (subStream instanceof SerialCompositeElementSymbol) { //should always be the case
for (ArchitectureElementSymbol element : ((SerialCompositeElementSymbol) subStream).getElements()) {
if (element instanceof ParallelCompositeElementSymbol) {
checkForReplayMemory((ParallelCompositeElementSymbol) element);
} else if (element.getName().equals("ReplayMemory")) {
Log.error("0" + ErrorCodes.INVALID_REPLAY_MEMORY_LAYER_PLACEMENT +
" Invalid placement of ReplayMemory layer. It can't be placed inside a Prallalel execution block.",
checkForEpisodicMemory((ParallelCompositeElementSymbol) element);
} else if (element.getName().equals("EpisodicMemory")) {
Log.error("0" + ErrorCodes.INVALID_EPISODIC_MEMORY_LAYER_PLACEMENT +
" Invalid placement of EpisodicMemory layer. It can't be placed inside a Prallalel execution block.",
element.getSourcePosition());
}
}
......@@ -74,7 +74,7 @@ public class CheckReplayMemoryLayer extends CNNArchSymbolCoCo {
}
}
}
Log.error("0" + ErrorCodes.INVALID_REPLAY_QUERY_NET_PATH_OR_PREFIX +
Log.error("0" + ErrorCodes.INVALID_EPISODIC_QUERY_NET_PATH_OR_PREFIX +
" For the concatination of queryNetDir and queryNetPrefix exists no file wich path has this as prefix.",
layer.getSourcePosition());
}
......
......@@ -17,16 +17,16 @@ import java.util.Optional;
import java.util.List;
public class CheckMemoryLayer extends CNNArchSymbolCoCo {
public class CheckLargeMemoryLayer extends CNNArchSymbolCoCo {
@Override
public void check(ArchitectureElementSymbol sym) {
if (sym instanceof LayerSymbol && sym.getName().equals("Memory")) {
checkMemoryLayer((LayerSymbol) sym);
if (sym instanceof LayerSymbol && sym.getName().equals("LargeMemory")) {
checkLargeMemoryLayer((LayerSymbol) sym);
}
}
public void checkMemoryLayer(LayerSymbol layer) {
public void checkLargeMemoryLayer(LayerSymbol layer) {
List<ArgumentSymbol> arguments = layer.getArguments();
Integer subKeySize = new Integer(0);
Integer k = new Integer(0);
......@@ -40,8 +40,8 @@ public class CheckMemoryLayer extends CNNArchSymbolCoCo {
}
if (subKeySize < k) {
Log.error("0" + ErrorCodes.INVALID_MEMORY_LAYER_PARAMETERS +
" Invalid Memory layer Parameter values, subKeySize has to be greater or equal to k. ",
Log.error("0" + ErrorCodes.INVALID_LARGE_MEMORY_LAYER_PARAMETERS +
" Invalid LargeMemory layer Parameter values, subKeySize has to be greater or equal to k. ",
layer.getSourcePosition());
}
}
......
......@@ -30,7 +30,6 @@ public class ArchTypeSymbol extends CommonSymbol {
private int widthIndex = -1;
private List<ArchSimpleExpressionSymbol> dimensions = new ArrayList<>();
public ArchTypeSymbol() {
super("", KIND);
ASTElementType elementType = new ASTElementType();
......@@ -146,7 +145,7 @@ public class ArchTypeSymbol extends CommonSymbol {
}
return dimensionList;
}
public Set<ParameterSymbol> resolve() {
if (!isResolved()){
if (isResolvable()){
......
......@@ -214,15 +214,15 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
return copy;
}
public void processForReplayMemory(){
public void processForEpisodicReplayMemory(){
for(NetworkInstructionSymbol networkInstruction : networkInstructions){
List<ArchitectureElementSymbol> elements = networkInstruction.getBody().getElements();
List<ArchitectureElementSymbol> elementsNew = new ArrayList<>();
List<List<ArchitectureElementSymbol>> replaySubNetworks = new ArrayList<>(new ArrayList<>());
List<ArchitectureElementSymbol> currentReplaySubNetworkElements = new ArrayList<>();
List<List<ArchitectureElementSymbol>> episodicSubNetworks = new ArrayList<>(new ArrayList<>());
List<ArchitectureElementSymbol> currentEpisodicSubNetworkElements = new ArrayList<>();
for (ArchitectureElementSymbol element : elements){
if (AllPredefinedLayers.REPLAY_LAYER_NAMES.contains(element.getName())) {
if (AllPredefinedLayers.EPISODIC_REPLAY_LAYER_NAMES.contains(element.getName())) {
boolean use_replay = false;
boolean use_local_adaption = false;
......@@ -251,18 +251,18 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
}
if (use_replay || use_local_adaption){
if (!currentReplaySubNetworkElements.isEmpty()){
replaySubNetworks.add(currentReplaySubNetworkElements);
if (!currentEpisodicSubNetworkElements.isEmpty()){
episodicSubNetworks.add(currentEpisodicSubNetworkElements);
}
currentReplaySubNetworkElements = new ArrayList<>();
currentEpisodicSubNetworkElements = new ArrayList<>();
}
}
currentReplaySubNetworkElements.add(element);
currentEpisodicSubNetworkElements.add(element);
}
if (!currentReplaySubNetworkElements.isEmpty() && !replaySubNetworks.isEmpty()){
replaySubNetworks.add(currentReplaySubNetworkElements);
if (!currentEpisodicSubNetworkElements.isEmpty() && !episodicSubNetworks.isEmpty()){
episodicSubNetworks.add(currentEpisodicSubNetworkElements);
}
networkInstruction.getBody().setReplaySubNetworks(replaySubNetworks);
networkInstruction.getBody().setEpisodicSubNetworks(episodicSubNetworks);
}
}
}
......@@ -12,7 +12,7 @@ import java.util.*;
public class SerialCompositeElementSymbol extends CompositeElementSymbol {
protected List<List<ArchitectureElementSymbol>> replaySubNetworks = new ArrayList<>(new ArrayList<>());
protected List<List<ArchitectureElementSymbol>> episodicSubNetworks = new ArrayList<>(new ArrayList<>());
protected void setElements(List<ArchitectureElementSymbol> elements) {
ArchitectureElementSymbol previous = null;
......@@ -34,8 +34,8 @@ public class SerialCompositeElementSymbol extends CompositeElementSymbol {
this.elements = elements;
}
protected void setReplaySubNetworks(List<List<ArchitectureElementSymbol>> replaySubNetworks){
for (List<ArchitectureElementSymbol> subElements: replaySubNetworks){
protected void setEpisodicSubNetworks(List<List<ArchitectureElementSymbol>> episodicSubNetworks){
for (List<ArchitectureElementSymbol> subElements: episodicSubNetworks){
ArchitectureElementSymbol previous = null;
for (ArchitectureElementSymbol current : subElements){
if (previous != null){
......@@ -53,11 +53,11 @@ public class SerialCompositeElementSymbol extends CompositeElementSymbol {
previous = current;
}
}
this.replaySubNetworks = replaySubNetworks;
this.episodicSubNetworks = episodicSubNetworks;
}
public List<List<ArchitectureElementSymbol>> getReplaySubNetworks() {
return replaySubNetworks;
public List<List<ArchitectureElementSymbol>> getEpisodicSubNetworks() {
return episodicSubNetworks;
}
@Override
......
......@@ -37,9 +37,9 @@ public class ErrorCodes {
public static final String ILLEGAL_LAYER_USE = "x04845";
public static final String UNUSED_LAYER = "x04847";
public static final String INVALID_CONSTANT = "x04856";
public static final String INVALID_MEMORY_LAYER_PARAMETERS = "x04866";
public static final String INVALID_REPLAY_MEMORY_LAYER_PLACEMENT = "x04876";
public static final String INVALID_REPLAY_QUERY_NET_PATH_OR_PREFIX = "x04877";
public static final String INVALID_LARGE_MEMORY_LAYER_PARAMETERS = "x04866";
public static final String INVALID_EPISODIC_MEMORY_LAYER_PLACEMENT = "x04876";
public static final String INVALID_EPISODIC_QUERY_NET_PATH_OR_PREFIX = "x04877";
public static final String OUTPUT_WRITTEN_TO_MULTIPLE_TIMES = "x04836";
public static final String UNROLL_INPUTS_TOO_MANY = "x02384";
public static final String UNROLL_OUTPUTS_TOO_MANY = "x02385";
......
......@@ -55,11 +55,12 @@ public class AllPredefinedLayers {
public static final String BROADCAST_ADD_NAME = "BroadcastAdd";
public static final String RESHAPE_NAME = "Reshape";
public static final String DOT_PRODUCT_SELF_ATTENTION_NAME = "DotProductSelfAttention";
public static final String LOAD_NETWORK_NAME = "LoadNetwork";
//replay layers
public static final String MEMORY_NAME = "Memory";
public static final String REPLAY_MEMORY_NAME = "ReplayMemory";
public static final List<String> REPLAY_LAYER_NAMES = new ArrayList<String>(Arrays.asList(REPLAY_MEMORY_NAME));
public static final String LARGE_MEMORY_NAME = "LargeMemory";
public static final String EPISODIC_MEMORY_NAME = "EpisodicMemory";
public static final List<String> EPISODIC_REPLAY_LAYER_NAMES = new ArrayList<String>(Arrays.asList(EPISODIC_MEMORY_NAME));
//predefined argument names
......@@ -96,6 +97,11 @@ public class AllPredefinedLayers {
public static final String SHAPE_NAME = "shape";
public static final String RNN_DROPOUT_NAME = "dropout";
//parameters LoadNetwork layer
public static final String NETWORK_DIR_NAME = "networkDir";
public static final String NETWORK_PREFIX_NAME = "networkPrefix";
public static final String NUM_INPUTS_NAME = "numInputs";
public static final String OUTPUT_SHAPE_NAME = "outputShape";
//parameters DotProductSelfAttention
public static final String SCALE_FACTOR_NAME="scaleFactor";
......@@ -103,7 +109,7 @@ public class AllPredefinedLayers {
public static final String DIM_VALUES_NAME="dimValues";
public static final String USE_PROJ_BIAS_NAME="useProjBias";
//shared parameters replay layers
//shared parameters episodic replay layers
public static final String USE_REPLAY_NAME = "useReplay";
public static final String REPLAY_INTERVAL_NAME = "replayInterval";
public static final String REPLAY_BATCH_SIZE_NAME = "replayBatchSize";
......@@ -115,7 +121,7 @@ public class AllPredefinedLayers {
public static final String LOCAL_ADAPTION_GRADIENT_STEPS_NAME = "localAdaptionGradientSteps";
public static final String LOCAL_ADAPTION_MEMORY_STORE_DIST_MEASURE_NAME = "localAdaptionMemoryStoreDistMeasure";
//parameters for memory layer
//parameters for episodic memory layer
public static final String SUB_KEY_SIZE_NAME = "subKeySize";
public static final String QUERY_SIZE_NAME = "querySize";
public static final String QUERY_ACT_NAME = "queryAct";
......@@ -124,11 +130,12 @@ public class AllPredefinedLayers {
public static final String STORE_DIST_MEASURE_NAME = "storeDistMeasure";
public static final String VALUES_DIM_NAME = "valuesDim";
//parameters for replay memory layer
//parameters for episodic memory layer
public static final String MAX_STORED_SAMPLES_NAME = "maxStoredSamples";
public static final String REPLAY_MEMORY_STORE_PROB_NAME = "replayMemoryStoreProb";
public static final String QUERY_NET_DIR_NAME = "queryNetDir";
public static final String QUERY_NET_PREFIX_NAME = "queryNetPrefix";
public static final String QUERY_NET_NUM_INPUTS_NAME = "queryNetNumInputs";
//possible String values
public static final String PADDING_VALID = "valid";
......@@ -184,9 +191,10 @@ public class AllPredefinedLayers {
SwapAxes.create(),
BroadcastAdd.create(),
Reshape.create(),
LoadNetwork.create(),
DotProductSelfAttention.create(),
Memory.create(),
ReplayMemory.create());
LargeMemory.create(),
EpisodicMemory.create());
}
public static List<UnrollDeclarationSymbol> createUnrollList(){
......
......@@ -16,31 +16,39 @@ import java.util.Collections;
import java.util.List;
import java.util.Optional;
public class ReplayMemory extends PredefinedLayerDeclaration {
public class EpisodicMemory extends PredefinedLayerDeclaration {
private ReplayMemory() {
super(AllPredefinedLayers.REPLAY_MEMORY_NAME);
private EpisodicMemory() {
super(AllPredefinedLayers.EPISODIC_MEMORY_NAME);
}
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(1)
.height(1)
.width(1)
.elementType("-oo", "oo")
.build());
List<ArchTypeSymbol> outputShapes = new ArrayList<>(layer.getInputTypes().size());
for (int i = 0; i < layer.getInputTypes().size(); i++) {
ArchTypeSymbol inputShape = layer.getInputTypes().get(i);
int inputHeight = inputShape.getHeight();
int inputWidth = inputShape.getWidth();
int inputChannels = inputShape.getChannels();
outputShapes.add(new ArchTypeSymbol.Builder()
.height(inputHeight)
.width(inputWidth)
.channels(inputChannels)
.elementType(layer.getInputTypes().get(i).getDomain())
.build());
}
return outputShapes;
}
@Override
public void checkInput(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
errorIfInputSizeIsNotOne(inputTypes, layer);
errorIfInputIsEmpty(inputTypes, layer);
}
public static ReplayMemory create(){
ReplayMemory declaration = new ReplayMemory();
public static EpisodicMemory create(){
EpisodicMemory declaration = new EpisodicMemory();
List<ParameterSymbol> parameters = new ArrayList<>(Arrays.asList(
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.USE_REPLAY_NAME)
......@@ -104,6 +112,11 @@ public class ReplayMemory extends PredefinedLayerDeclaration {
.name(AllPredefinedLayers.QUERY_NET_PREFIX_NAME)
.constraints(Constraints.STRING)
.defaultValue(-1)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.QUERY_NET_NUM_INPUTS_NAME)
.constraints(Constraints.INTEGER)
.defaultValue(1)
.build()));
declaration.setParameters(parameters);
return declaration;
......
......@@ -16,10 +16,10 @@ import java.util.Collections;
import java.util.List;
import java.util.Optional;
public class Memory extends PredefinedLayerDeclaration {
public class LargeMemory extends PredefinedLayerDeclaration {
private Memory() {
super(AllPredefinedLayers.MEMORY_NAME);
private LargeMemory() {
super(AllPredefinedLayers.LARGE_MEMORY_NAME);
}
@Override
......@@ -57,8 +57,8 @@ public class Memory extends PredefinedLayerDeclaration {
errorIfInputSizeIsNotOne(inputTypes, layer);
}
public static Memory create(){
Memory declaration = new Memory();
public static LargeMemory create(){
LargeMemory declaration = new LargeMemory();
List<ParameterSymbol> parameters = new ArrayList<>(Arrays.asList(
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.STORE_DIST_MEASURE_NAME)
......
/**
*
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnnarch.predefined;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
public class LoadNetwork extends PredefinedLayerDeclaration {
private LoadNetwork() {
super(AllPredefinedLayers.LOAD_NETWORK_NAME);
}
@Override
public List<ArchTypeSymbol> computeOutputTypes(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
Optional<List<Integer>> optValue = layer.getIntTupleValue(AllPredefinedLayers.OUTPUT_SHAPE_NAME);
List<Integer> shapeList = Arrays.asList(1, 1, 1);
if (optValue.isPresent()) {
List<Integer> outputShape = optValue.get();
for (int i = 0; i < outputShape.size() && i < 3; i++) {
shapeList.set(i, outputShape.get(i));
}
}
return Collections.singletonList(new ArchTypeSymbol.Builder()
.channels(shapeList.get(0))
.height(shapeList.get(1))
.width(shapeList.get(2))
.elementType("-oo", "oo")
.build());
}
@Override
public void checkInput(List<ArchTypeSymbol> inputTypes, LayerSymbol layer, VariableSymbol.Member member) {
errorIfInputIsEmpty(inputTypes, layer);
}
public static LoadNetwork create(){
LoadNetwork declaration = new LoadNetwork();
List<ParameterSymbol> parameters = new ArrayList<>(Arrays.asList(
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.NETWORK_DIR_NAME)
.constraints(Constraints.STRING)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.NETWORK_PREFIX_NAME)
.constraints(Constraints.STRING)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.NUM_INPUTS_NAME)
.constraints(Constraints.INTEGER)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.OUTPUT_SHAPE_NAME)
.constraints(Constraints.INTEGER_TUPLE)
.build()));
declaration.setParameters(parameters);
return declaration;
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment