Commit 74018915 authored by Julian Johannes Steinsberger-Dührßen's avatar Julian Johannes Steinsberger-Dührßen
Browse files

added parameters and coco for replay memory

parent 00062a76
Pipeline #282263 passed with stage
in 8 minutes and 46 seconds
......@@ -61,6 +61,7 @@ public class CNNArchCocos {
.addCoCo(new CheckLayerVariableDeclarationIsUsed())
.addCoCo(new CheckConstants())
.addCoCo(new CheckMemoryLayer())
.addCoCo(new CheckReplayMemoryLayer())
.addCoCo(new CheckUnrollInputsOutputsTooMany());
}
......@@ -70,7 +71,6 @@ public class CNNArchCocos {
.addCoCo(new CheckVariableDeclarationName())
.addCoCo(new CheckVariableName())
.addCoCo(new CheckArgmaxLayer())
.addCoCo(new CheckReplayMemoryLayer())
.addCoCo(new CheckExpressions());
}
......
......@@ -12,11 +12,14 @@ import de.monticore.lang.monticar.cnnarch._symboltable.StreamInstructionSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ParallelCompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.SerialCompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ArgumentSymbol;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log;
import java.util.Optional;
import java.util.List;
import java.io.File;
public class CheckReplayMemoryLayer extends CNNArchSymbolCoCo {
......@@ -24,21 +27,22 @@ public class CheckReplayMemoryLayer extends CNNArchSymbolCoCo {
public void check(StreamInstructionSymbol stream) {
List<ArchitectureElementSymbol> elements = stream.getBody().getElements();
for (ArchitectureElementSymbol element : elements){
if (element instanceof ParallelCompositeElementSymbol){
for (ArchitectureElementSymbol element : elements) {
if (element instanceof ParallelCompositeElementSymbol) {
checkForReplayMemory((ParallelCompositeElementSymbol) element);
} else if (element.getName().equals("ReplayMemory")) {
checkParameters((LayerSymbol) element);
}
}
}
public void checkForReplayMemory(ParallelCompositeElementSymbol parallelElement) {
for (ArchitectureElementSymbol subStream: parallelElement.getElements()){
if (subStream instanceof SerialCompositeElementSymbol){ //should always be the case
for (ArchitectureElementSymbol element : ((SerialCompositeElementSymbol) subStream).getElements()){
if (element instanceof ParallelCompositeElementSymbol){
for (ArchitectureElementSymbol subStream : parallelElement.getElements()) {
if (subStream instanceof SerialCompositeElementSymbol) { //should always be the case
for (ArchitectureElementSymbol element : ((SerialCompositeElementSymbol) subStream).getElements()) {
if (element instanceof ParallelCompositeElementSymbol) {
checkForReplayMemory((ParallelCompositeElementSymbol) element);
}
else if (element.getName().equals("ReplayMemory")) {
} else if (element.getName().equals("ReplayMemory")) {
Log.error("0" + ErrorCodes.INVALID_REPLAY_MEMORY_LAYER_PLACEMENT +
" Invalid placement of ReplayMemory layer. It can't be placed inside a Prallalel execution block.",
element.getSourcePosition());
......@@ -47,4 +51,31 @@ public class CheckReplayMemoryLayer extends CNNArchSymbolCoCo {
}
}
}
public void checkParameters(LayerSymbol layer) {
List<ArgumentSymbol> arguments = layer.getArguments();
String queryNetDir = new String("");
String queryNetPrefix = new String("");
for (ArgumentSymbol arg : arguments) {
if (arg.getName().equals("queryNetDir")) {
queryNetDir = arg.getRhs().getStringValue().get();
} else if (arg.getName().equals("queryNetPrefix")) {
queryNetPrefix = arg.getRhs().getStringValue().get();
}
}
File dir = new File(queryNetDir);
if (dir.exists()) {
for (File file : dir.listFiles()) {
String file_name = file.getName();
if (file_name.startsWith(queryNetPrefix)) {
return;
}
}
}
Log.error("0" + ErrorCodes.INVALID_REPLAY_QUERY_NET_PATH_OR_PREFIX +
" For the concatination of queryNetDir and queryNetPrefix exists no file wich path has this as prefix.",
layer.getSourcePosition());
}
}
......@@ -48,6 +48,16 @@ public enum Constraints {
return "a boolean";
}
},
STRING {
@Override
public boolean isValid(ArchSimpleExpressionSymbol exp) {
return exp.isString();
}
@Override
public String msgString() {
return "a string";
}
},
TUPLE {
@Override
public boolean isValid(ArchSimpleExpressionSymbol exp) {
......
......@@ -39,6 +39,7 @@ public class ErrorCodes {
public static final String INVALID_CONSTANT = "x04856";
public static final String INVALID_MEMORY_LAYER_PARAMETERS = "x04866";
public static final String INVALID_REPLAY_MEMORY_LAYER_PLACEMENT = "x04876";
public static final String INVALID_REPLAY_QUERY_NET_PATH_OR_PREFIX = "x04877";
public static final String OUTPUT_WRITTEN_TO_MULTIPLE_TIMES = "x04836";
public static final String UNROLL_INPUTS_TOO_MANY = "x02384";
public static final String UNROLL_OUTPUTS_TOO_MANY = "x02385";
......
......@@ -102,9 +102,13 @@ public class AllPredefinedLayers {
public static final String REPLAY_INTERVAL_NAME = "replayInterval";
public static final String REPLAY_BATCH_SIZE_NAME = "replayBatchSize";
public static final String REPLAY_STEPS_NAME = "replaySteps";
public static final String REPLAY_GRADIENT_STEPS_NAME = "replayGradientSteps";
public static final String REPLAY_GRADIENT_STEPS_TRAINING_NAME = "replayGradientStepsTraining";
public static final String STORE_PROB_NAME = "storeProb";
public static final String MAX_STORED_SAMPLES_NAME = "maxStoredSamples";
public static final String REPLAY_K_NAME = "replayK";
public static final String REPLAY_GRADIENT_STEPS_PREDICTION_NAME = "replayGradientStepsPrediction";
public static final String QUERY_NET_DIR_NAME = "queryNetDir";
public static final String QUERY_NET_PREFIX_NAME = "queryNetPrefix";
//possible String values
public static final String PADDING_VALID = "valid";
......
......@@ -57,7 +57,7 @@ public class ReplayMemory extends PredefinedLayerDeclaration {
.defaultValue("linear")
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.REPLAY_GRADIENT_STEPS_NAME)
.name(AllPredefinedLayers.REPLAY_GRADIENT_STEPS_TRAINING_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.defaultValue(1)
.build(),
......@@ -70,6 +70,26 @@ public class ReplayMemory extends PredefinedLayerDeclaration {
.name(AllPredefinedLayers.MAX_STORED_SAMPLES_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE_OR_MINUS_ONE)
.defaultValue(-1)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.REPLAY_K_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.defaultValue(1)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.REPLAY_GRADIENT_STEPS_PREDICTION_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.defaultValue(1)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.QUERY_NET_DIR_NAME)
.constraints(Constraints.STRING)
.defaultValue(-1)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.QUERY_NET_PREFIX_NAME)
.constraints(Constraints.STRING)
.defaultValue(-1)
.build()));
declaration.setParameters(parameters);
return declaration;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment