Commit 8a9fa0a9 authored by Julian Johannes Steinsberger-Dührßen's avatar Julian Johannes Steinsberger-Dührßen
Browse files

Finished LayerPathParameter tagging

parent a5e1bb87
Pipeline #319034 passed with stage
in 11 minutes and 4 seconds
......@@ -30,13 +30,11 @@ public class CheckEpisodicMemoryLayer extends CNNArchSymbolCoCo {
for (ArchitectureElementSymbol element : elements) {
if (element instanceof ParallelCompositeElementSymbol) {
checkForEpisodicMemory((ParallelCompositeElementSymbol) element);
} else if (element.getName().equals("EpisodicMemory")) {
checkParameters((LayerSymbol) element);
}
}
}
public void checkForEpisodicMemory(ParallelCompositeElementSymbol parallelElement) {
protected void checkForEpisodicMemory(ParallelCompositeElementSymbol parallelElement) {
for (ArchitectureElementSymbol subStream : parallelElement.getElements()) {
if (subStream instanceof SerialCompositeElementSymbol) { //should always be the case
for (ArchitectureElementSymbol element : ((SerialCompositeElementSymbol) subStream).getElements()) {
......@@ -51,31 +49,4 @@ public class CheckEpisodicMemoryLayer extends CNNArchSymbolCoCo {
}
}
}
public void checkParameters(LayerSymbol layer) {
List<ArgumentSymbol> arguments = layer.getArguments();
String queryNetDir = new String("");
String queryNetPrefix = new String("");
for (ArgumentSymbol arg : arguments) {
if (arg.getName().equals("queryNetDir")) {
queryNetDir = arg.getRhs().getStringValue().get();
} else if (arg.getName().equals("queryNetPrefix")) {
queryNetPrefix = arg.getRhs().getStringValue().get();
}
}
File dir = new File(queryNetDir);
if (dir.exists()) {
for (File file : dir.listFiles()) {
String file_name = file.getName();
if (file_name.startsWith(queryNetPrefix)) {
return;
}
}
}
Log.error("0" + ErrorCodes.INVALID_EPISODIC_QUERY_NET_PATH_OR_PREFIX +
" For the concatination of queryNetDir and queryNetPrefix exists no file wich path has this as prefix.",
layer.getSourcePosition());
}
}
......@@ -26,7 +26,7 @@ public class CheckLargeMemoryLayer extends CNNArchSymbolCoCo {
}
}
public void checkLargeMemoryLayer(LayerSymbol layer) {
protected void checkLargeMemoryLayer(LayerSymbol layer) {
List<ArgumentSymbol> arguments = layer.getArguments();
Integer subKeySize = new Integer(0);
Integer k = new Integer(0);
......
/**
*
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch.helper.ErrorCodes;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ArgumentSymbol;
import de.se_rwth.commons.logging.Log;
import java.io.File;
import java.util.*;
public class CheckLayerPathParameter {
public static void check(LayerSymbol sym, String path, String tag, HashMap layerPathParameterTags) {
checkTag(sym, tag, layerPathParameterTags);
checkPath(sym, path);
}
protected static void checkTag(LayerSymbol layer, String tag, HashMap layerPathParameterTags){
if (!tag.equals("") && !layerPathParameterTags.containsKey(tag)) {
Log.error("0" + ErrorCodes.INVALID_LAYER_PATH_PARAMETER_TAG +
"The LayerPathParameter tag " + tag + " was not found.",
layer.getSourcePosition());
}
}
protected static void checkPath(LayerSymbol layer, String path) {
File dir = new File(path);
if (dir.exists()) {
return;
}
Log.error("0" + ErrorCodes.INVALID_LAYER_PATH_PARAMETER_PATH +
" For the concatination of queryNetDir and queryNetPrefix exists no file which path has this as prefix.",
layer.getSourcePosition());
}
}
......@@ -74,7 +74,18 @@ abstract public class ArchExpressionSymbol extends CommonSymbol {
public boolean isString(){
if (getValue().isPresent()){
return getStringValue().isPresent();
Optional<String> stringValue = getStringValue();
if (stringValue.isPresent() && !stringValue.get().startsWith("tag:")) {
return getStringValue().isPresent();
}
}
return false;
}
public boolean isStringTag(){
if (getValue().isPresent()) {
Optional<String> stringValue = getStringValue();
return stringValue.isPresent() && stringValue.get().startsWith("tag:");
}
return false;
}
......
......@@ -14,11 +14,13 @@ package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.lang.monticar.cnnarch.helper.Utils;
import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers;
import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedVariables;
import de.monticore.lang.monticar.cnnarch._cocos.CheckLayerPathParameter;
import de.monticore.symboltable.CommonScopeSpanningSymbol;
import de.monticore.symboltable.Scope;
import de.monticore.symboltable.Symbol;
import org.apache.commons.math3.ml.neuralnet.Network;
import java.lang.RuntimeException;
import java.util.*;
public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
......@@ -32,7 +34,6 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
private List<ConstantSymbol> constants = new ArrayList<>();
private String dataPath;
private String weightsPath;
private List<LayerPathParameterTagSymbol> layerPathParameterTagSymbols = new ArrayList<>();
private String componentName;
public ArchitectureSymbol() {
......@@ -83,12 +84,6 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
this.weightsPath = weightsPath;
}
public void setLayerPathParameterTagSymbols(List<LayerPathParameterTagSymbol> layerPathParameterTagSymbols) { this.layerPathParameterTagSymbols = layerPathParameterTagSymbols; }
public List<LayerPathParameterTagSymbol> getLayerPathParameterTagSymbols() {
return layerPathParameterTagSymbols;
}
public void setComponentName(String componentName){
this.componentName = componentName;
}
......@@ -221,12 +216,37 @@ public class ArchitectureSymbol extends CommonScopeSpanningSymbol {
return copy;
}
public void processLayerPathParameterTags(){
public void processLayerPathParameterTags(HashMap layerPathParameterTags){
for(NetworkInstructionSymbol networkInstruction : networkInstructions){
List<ArchitectureElementSymbol> elements = networkInstruction.getBody().getElements();
processElementsLayerPathParameterTags(elements, layerPathParameterTags);
}
}
for (ArchitectureElementSymbol element : elements){
public void processElementsLayerPathParameterTags(List<ArchitectureElementSymbol> elements, HashMap layerPathParameterTags){
for (ArchitectureElementSymbol element : elements){
if (element instanceof SerialCompositeElementSymbol || element instanceof ParallelCompositeElementSymbol){
processElementsLayerPathParameterTags(((CompositeElementSymbol) element).getElements(), layerPathParameterTags);
}else if (element instanceof LayerSymbol){
for (ArgumentSymbol param : ((LayerSymbol) element).getArguments()){
boolean isPathParam = false;
for (Constraints constr : param.getParameter().getConstraints()){
if (constr.name().equals("PATH_TAG_OR_PATH")){
isPathParam = true;
}
}
if (isPathParam){
String paramValue = param.getRhs().getStringValue().get();
if (paramValue.startsWith("tag:")) {
String pathTag = param.getRhs().getStringValue().get().split(":")[1];
String path = (String) layerPathParameterTags.get(pathTag);
param.setRhs(ArchSimpleExpressionSymbol.of(path));
CheckLayerPathParameter.check((LayerSymbol) element, path, pathTag, layerPathParameterTags);
}else{
CheckLayerPathParameter.check((LayerSymbol) element, paramValue, "", layerPathParameterTags);
}
}
}
}
}
}
......
......@@ -58,6 +58,16 @@ public enum Constraints {
return "a string";
}
},
PATH_TAG_OR_PATH {
@Override
public boolean isValid(ArchSimpleExpressionSymbol exp) {
return exp.isString() || exp.isStringTag();
}
@Override
public String msgString() {
return "a path tag or a path string";
}
},
TUPLE {
@Override
public boolean isValid(ArchSimpleExpressionSymbol exp) {
......
/**
*
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.symboltable.SymbolKind;
public class LayerPathParameterTagKind implements SymbolKind {
private static final String NAME = "de.monticore.lang.monticar.cnnarch._symboltable.LayerPathParameterTagKind";
@Override
public String getName() {
return NAME;
}
@Override
public boolean isKindOf(SymbolKind kind) {
return NAME.equals(kind.getName()) || SymbolKind.super.isKindOf(kind);
}
}
/**
*
* (c) https://github.com/MontiCore/monticore
*
* The license generally applicable for this project
* can be found under https://github.com/MontiCore/monticore.
*/
package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.symboltable.SymbolKind;
import de.monticore.symboltable.CommonSymbol;
import java.util.*;
public abstract class LayerPathParameterTagSymbol extends CommonSymbol {
public static final LayerPathParameterTagKind KIND = new LayerPathParameterTagKind();
private String path;
private String id;
protected LayerPathParameterTagSymbol(String name) {
super(name, KIND);
}
public String getPath() {
return path;
}
public void setPath(String path) {
this.path = path;
}
public String getId() {
return id;
}
public void setId(String id) {
this.id = id;
}
}
......@@ -39,7 +39,8 @@ public class ErrorCodes {
public static final String INVALID_CONSTANT = "x04856";
public static final String INVALID_LARGE_MEMORY_LAYER_PARAMETERS = "x04866";
public static final String INVALID_EPISODIC_MEMORY_LAYER_PLACEMENT = "x04876";
public static final String INVALID_EPISODIC_QUERY_NET_PATH_OR_PREFIX = "x04877";
public static final String INVALID_LAYER_PATH_PARAMETER_PATH = "x04887";
public static final String INVALID_LAYER_PATH_PARAMETER_TAG = "x04888";
public static final String OUTPUT_WRITTEN_TO_MULTIPLE_TIMES = "x04836";
public static final String UNROLL_INPUTS_TOO_MANY = "x02384";
public static final String UNROLL_OUTPUTS_TOO_MANY = "x02385";
......
......@@ -105,7 +105,7 @@ public class EpisodicMemory extends PredefinedLayerDeclaration {
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.QUERY_NET_DIR_NAME)
.constraints(Constraints.STRING)
.constraints(Constraints.PATH_TAG_OR_PATH)
.defaultValue(-1)
.build(),
new ParameterSymbol.Builder()
......
......@@ -53,7 +53,7 @@ public class LoadNetwork extends PredefinedLayerDeclaration {
List<ParameterSymbol> parameters = new ArrayList<>(Arrays.asList(
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.NETWORK_DIR_NAME)
.constraints(Constraints.STRING)
.constraints(Constraints.PATH_TAG_OR_PATH)
.build(),
new ParameterSymbol.Builder()
.name(AllPredefinedLayers.NETWORK_PREFIX_NAME)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment