Commit 74018915 authored by Julian Johannes Steinsberger-Dührßen's avatar Julian Johannes Steinsberger-Dührßen
Browse files

added parameters and coco for replay memory

parent 00062a76
Pipeline #282263 passed with stage
in 8 minutes and 46 seconds
...@@ -61,6 +61,7 @@ public class CNNArchCocos { ...@@ -61,6 +61,7 @@ public class CNNArchCocos {
.addCoCo(new CheckLayerVariableDeclarationIsUsed()) .addCoCo(new CheckLayerVariableDeclarationIsUsed())
.addCoCo(new CheckConstants()) .addCoCo(new CheckConstants())
.addCoCo(new CheckMemoryLayer()) .addCoCo(new CheckMemoryLayer())
.addCoCo(new CheckReplayMemoryLayer())
.addCoCo(new CheckUnrollInputsOutputsTooMany()); .addCoCo(new CheckUnrollInputsOutputsTooMany());
} }
...@@ -70,7 +71,6 @@ public class CNNArchCocos { ...@@ -70,7 +71,6 @@ public class CNNArchCocos {
.addCoCo(new CheckVariableDeclarationName()) .addCoCo(new CheckVariableDeclarationName())
.addCoCo(new CheckVariableName()) .addCoCo(new CheckVariableName())
.addCoCo(new CheckArgmaxLayer()) .addCoCo(new CheckArgmaxLayer())
.addCoCo(new CheckReplayMemoryLayer())
.addCoCo(new CheckExpressions()); .addCoCo(new CheckExpressions());
} }
......
...@@ -12,39 +12,70 @@ import de.monticore.lang.monticar.cnnarch._symboltable.StreamInstructionSymbol; ...@@ -12,39 +12,70 @@ import de.monticore.lang.monticar.cnnarch._symboltable.StreamInstructionSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol; import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ParallelCompositeElementSymbol; import de.monticore.lang.monticar.cnnarch._symboltable.ParallelCompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.SerialCompositeElementSymbol; import de.monticore.lang.monticar.cnnarch._symboltable.SerialCompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ArgumentSymbol;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes; import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.se_rwth.commons.logging.Log; import de.se_rwth.commons.logging.Log;
import java.util.Optional; import java.util.Optional;
import java.util.List; import java.util.List;
import java.io.File;
public class CheckReplayMemoryLayer extends CNNArchSymbolCoCo { public class CheckReplayMemoryLayer extends CNNArchSymbolCoCo {
@Override @Override
public void check(StreamInstructionSymbol stream) { public void check(StreamInstructionSymbol stream) {
List<ArchitectureElementSymbol> elements = stream.getBody().getElements(); List<ArchitectureElementSymbol> elements = stream.getBody().getElements();
for (ArchitectureElementSymbol element : elements){ for (ArchitectureElementSymbol element : elements) {
if (element instanceof ParallelCompositeElementSymbol){ if (element instanceof ParallelCompositeElementSymbol) {
checkForReplayMemory((ParallelCompositeElementSymbol) element); checkForReplayMemory((ParallelCompositeElementSymbol) element);
} else if (element.getName().equals("ReplayMemory")) {
checkParameters((LayerSymbol) element);
} }
} }
} }
public void checkForReplayMemory(ParallelCompositeElementSymbol parallelElement) { public void checkForReplayMemory(ParallelCompositeElementSymbol parallelElement) {
for (ArchitectureElementSymbol subStream: parallelElement.getElements()){ for (ArchitectureElementSymbol subStream : parallelElement.getElements()) {
if (subStream instanceof SerialCompositeElementSymbol){ //should always be the case if (subStream instanceof SerialCompositeElementSymbol) { //should always be the case
for (ArchitectureElementSymbol element : ((SerialCompositeElementSymbol) subStream).getElements()){ for (ArchitectureElementSymbol element : ((SerialCompositeElementSymbol) subStream).getElements()) {
if (element instanceof ParallelCompositeElementSymbol){ if (element instanceof ParallelCompositeElementSymbol) {
checkForReplayMemory((ParallelCompositeElementSymbol) element); checkForReplayMemory((ParallelCompositeElementSymbol) element);
} } else if (element.getName().equals("ReplayMemory")) {
else if (element.getName().equals("ReplayMemory")) {
Log.error("0" + ErrorCodes.INVALID_REPLAY_MEMORY_LAYER_PLACEMENT + Log.error("0" + ErrorCodes.INVALID_REPLAY_MEMORY_LAYER_PLACEMENT +
" Invalid placement of ReplayMemory layer. It can't be placed inside a Prallalel execution block.", " Invalid placement of ReplayMemory layer. It can't be placed inside a Prallalel execution block.",
element.getSourcePosition()); element.getSourcePosition());
} }
} }
} }
} }
} }
public void checkParameters(LayerSymbol layer) {
List<ArgumentSymbol> arguments = layer.getArguments();
String queryNetDir = new String("");
String queryNetPrefix = new String("");
for (ArgumentSymbol arg : arguments) {
if (arg.getName().equals("queryNetDir")) {
queryNetDir = arg.getRhs().getStringValue().get();
} else if (arg.getName().equals("queryNetPrefix")) {
queryNetPrefix = arg.getRhs().getStringValue().get();
}
}
File dir = new File(queryNetDir);
if (dir.exists()) {
for (File file : dir.listFiles()) {
String file_name = file.getName();
if (file_name.startsWith(queryNetPrefix)) {
return;
}
}
}
Log.error("0" + ErrorCodes.INVALID_REPLAY_QUERY_NET_PATH_OR_PREFIX +
" For the concatination of queryNetDir and queryNetPrefix exists no file wich path has this as prefix.",
layer.getSourcePosition());
}
} }
...@@ -48,6 +48,16 @@ public enum Constraints { ...@@ -48,6 +48,16 @@ public enum Constraints {
return "a boolean"; return "a boolean";
} }
}, },
STRING {
@Override
public boolean isValid(ArchSimpleExpressionSymbol exp) {
return exp.isString();
}
@Override
public String msgString() {
return "a string";
}
},
TUPLE { TUPLE {
@Override @Override
public boolean isValid(ArchSimpleExpressionSymbol exp) { public boolean isValid(ArchSimpleExpressionSymbol exp) {
......
...@@ -39,6 +39,7 @@ public class ErrorCodes { ...@@ -39,6 +39,7 @@ public class ErrorCodes {
public static final String INVALID_CONSTANT = "x04856"; public static final String INVALID_CONSTANT = "x04856";
public static final String INVALID_MEMORY_LAYER_PARAMETERS = "x04866"; public static final String INVALID_MEMORY_LAYER_PARAMETERS = "x04866";
public static final String INVALID_REPLAY_MEMORY_LAYER_PLACEMENT = "x04876"; public static final String INVALID_REPLAY_MEMORY_LAYER_PLACEMENT = "x04876";
public static final String INVALID_REPLAY_QUERY_NET_PATH_OR_PREFIX = "x04877";
public static final String OUTPUT_WRITTEN_TO_MULTIPLE_TIMES = "x04836"; public static final String OUTPUT_WRITTEN_TO_MULTIPLE_TIMES = "x04836";
public static final String UNROLL_INPUTS_TOO_MANY = "x02384"; public static final String UNROLL_INPUTS_TOO_MANY = "x02384";
public static final String UNROLL_OUTPUTS_TOO_MANY = "x02385"; public static final String UNROLL_OUTPUTS_TOO_MANY = "x02385";
......
...@@ -102,9 +102,13 @@ public class AllPredefinedLayers { ...@@ -102,9 +102,13 @@ public class AllPredefinedLayers {
public static final String REPLAY_INTERVAL_NAME = "replayInterval"; public static final String REPLAY_INTERVAL_NAME = "replayInterval";
public static final String REPLAY_BATCH_SIZE_NAME = "replayBatchSize"; public static final String REPLAY_BATCH_SIZE_NAME = "replayBatchSize";
public static final String REPLAY_STEPS_NAME = "replaySteps"; public static final String REPLAY_STEPS_NAME = "replaySteps";
public static final String REPLAY_GRADIENT_STEPS_NAME = "replayGradientSteps"; public static final String REPLAY_GRADIENT_STEPS_TRAINING_NAME = "replayGradientStepsTraining";
public static final String STORE_PROB_NAME = "storeProb"; public static final String STORE_PROB_NAME = "storeProb";
public static final String MAX_STORED_SAMPLES_NAME = "maxStoredSamples"; public static final String MAX_STORED_SAMPLES_NAME = "maxStoredSamples";
public static final String REPLAY_K_NAME = "replayK";
public static final String REPLAY_GRADIENT_STEPS_PREDICTION_NAME = "replayGradientStepsPrediction";
public static final String QUERY_NET_DIR_NAME = "queryNetDir";
public static final String QUERY_NET_PREFIX_NAME = "queryNetPrefix";
//possible String values //possible String values
public static final String PADDING_VALID = "valid"; public static final String PADDING_VALID = "valid";
......
...@@ -57,7 +57,7 @@ public class ReplayMemory extends PredefinedLayerDeclaration { ...@@ -57,7 +57,7 @@ public class ReplayMemory extends PredefinedLayerDeclaration {
.defaultValue("linear") .defaultValue("linear")
.build(), .build(),
new ParameterSymbol.Builder() new ParameterSymbol.Builder()
.name(AllPredefinedLayers.REPLAY_GRADIENT_STEPS_NAME) .name(AllPredefinedLayers.REPLAY_GRADIENT_STEPS_TRAINING_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE) .constraints(Constraints.INTEGER, Constraints.POSITIVE)
.defaultValue(1) .defaultValue(1)
.build(), .build(),
...@@ -70,6 +70,26 @@ public class ReplayMemory extends PredefinedLayerDeclaration { ...@@ -70,6 +70,26 @@ public class ReplayMemory extends PredefinedLayerDeclaration {
.name(AllPredefinedLayers.MAX_STORED_SAMPLES_NAME) .name(AllPredefinedLayers.MAX_STORED_SAMPLES_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE_OR_MINUS_ONE) .constraints(Constraints.INTEGER, Constraints.POSITIVE_OR_MINUS_ONE)
.defaultValue(-1) .defaultValue(-1)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.REPLAY_K_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.defaultValue(1)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.REPLAY_GRADIENT_STEPS_PREDICTION_NAME)
.constraints(Constraints.INTEGER, Constraints.POSITIVE)
.defaultValue(1)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.QUERY_NET_DIR_NAME)
.constraints(Constraints.STRING)
.defaultValue(-1)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.QUERY_NET_PREFIX_NAME)
.constraints(Constraints.STRING)
.defaultValue(-1)
.build())); .build()));
declaration.setParameters(parameters); declaration.setParameters(parameters);
return declaration; return declaration;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment