Commit 59254c06 authored by Julian Dierkes's avatar Julian Dierkes

Merge branch 'develop' of...

Merge branch 'develop' of git.rwth-aachen.de:monticore/EmbeddedMontiArc/generators/cnnarch2x into develop
parents 28de2022 8dc1bb5c
Pipeline #264302 failed with stages
......@@ -91,6 +91,10 @@ public class ArchitectureElementData {
return getLayerSymbol().getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get();
}
public int getGroups(){
return getLayerSymbol().getIntValue(AllPredefinedLayers.GROUPS_NAME).get();
}
public int getUnits(){
return getLayerSymbol().getIntValue(AllPredefinedLayers.UNITS_NAME).get();
}
......
......@@ -70,10 +70,14 @@ public abstract class CNNArchGenerator {
ArchitectureSymbol architectureSymbol = symbolCompiler.compileArchitectureSymbol(scope, rootModelName);
try{
String confPath = getModelsDirPath() + "/data_paths.txt";
DataPathConfigParser newParserConfig = new DataPathConfigParser(confPath);
String dataPath = newParserConfig.getDataPath(rootModelName);
String dataConfPath = getModelsDirPath() + "/data_paths.txt";
DataPathConfigParser dataParserConfig = new DataPathConfigParser(dataConfPath);
String dataPath = dataParserConfig.getDataPath(rootModelName);
architectureSymbol.setDataPath(dataPath);
String weightsConfPath = getModelsDirPath() + "/weights_paths.txt";
WeightsPathConfigParser weightsParserConfig = new WeightsPathConfigParser(weightsConfPath);
String weightsPath = weightsParserConfig.getWeightsPath(rootModelName);
architectureSymbol.setWeightsPath(weightsPath);
architectureSymbol.setComponentName(rootModelName);
generateFiles(architectureSymbol);
} catch (IOException e){
......
......@@ -10,7 +10,7 @@ import java.io.Writer;
import java.util.*;
public abstract class CNNArchTemplateController {
public static final String FTL_FILE_ENDING = ".ftl";
public static final String TEMPLATE_ELEMENTS_DIR_PATH = "elements/";
public static final String TEMPLATE_CONTROLLER_KEY = "tc";
......@@ -116,7 +116,11 @@ public abstract class CNNArchTemplateController {
public String getDataPath(){
return getArchitecture().getDataPath();
}
public String getWeightsPath(){
return getArchitecture().getWeightsPath();
}
public List<String> getLayerInputs(ArchitectureElementSymbol layer){
List<String> inputNames = new ArrayList<>();
......
......@@ -120,6 +120,13 @@ public class ConfigurationData {
return String.valueOf(getConfiguration().getEntry("log_period").getValue());
}
public Boolean getLoadPretrained() {
if (!getConfiguration().getEntryMap().containsKey("load_pretrained")) {
return null;
}
return (Boolean) getConfiguration().getEntry("load_pretrained").getValue().getValue();
}
public Boolean getNormalize() {
if (!getConfiguration().getEntryMap().containsKey("normalize")) {
return null;
......@@ -188,7 +195,7 @@ public class ConfigurationData {
} else{
return mapToStrings;}
}
public String getLossWeights() {
if (!getConfiguration().getEntryMap().containsKey("loss_weights")) {
return null;
......
......@@ -40,12 +40,14 @@ public abstract class TrainParamSupportChecker implements CNNTrainVisitor {
public void visit(ASTLoadCheckpointEntry node){}
public void visit(ASTLoadPretrainedEntry node){}
public void visit(ASTNormalizeEntry node){}
public void visit(ASTTrainContextEntry node){}
public void visit(ASTEvalMetricEntry node){}
public void visit(ASTSGDOptimizer node){}
public void visit(ASTAdamOptimizer node){}
......
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnnarch.generator;
import de.se_rwth.commons.logging.Log;
import java.io.*;
import java.util.Properties;
public class WeightsPathConfigParser{
private String configTargetPath;
private Properties properties;
public WeightsPathConfigParser(String configPath) {
setConfigPath(configPath);
properties = new Properties();
try
{
properties.load(new FileInputStream(configTargetPath));
} catch(IOException e)
{
Log.error("Config file " + configPath + " could not be found");
}
}
public String getConfigPath() {
if (configTargetPath.charAt(configTargetPath.length() - 1) != '/') {
this.configTargetPath = configTargetPath + "/";
}
return configTargetPath;
}
public void setConfigPath(String configTargetPath){
this.configTargetPath = configTargetPath;
}
public String getWeightsPath(String modelName) {
String path = properties.getProperty(modelName);
if(path == null) {
Log.warn("Weights path config file did not specify a path for component '" + modelName + "'");
return path;
}
return path;
}
}
ComponentName /path/to/training/weights
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment