Commit 44088e98 authored by Evgeny Kusmenko's avatar Evgeny Kusmenko

Merge branch 'develop' into 'master'

Develop

See merge request !9
parents 931b3386 9d0ed96a
Pipeline #267607 passed with stage
in 3 minutes and 1 second
# (c) https://github.com/MontiCore/monticore
stages:
- windows
#- windows
- linux
masterJobLinux:
......@@ -14,12 +14,12 @@ masterJobLinux:
only:
- master
masterJobWindows:
stage: windows
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml
tags:
- Windows10
#masterJobWindows:
# stage: windows
# script:
# - mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml
# tags:
# - Windows10
BranchJobLinux:
stage: linux
......
......@@ -9,7 +9,7 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-generator</artifactId>
<version>0.0.5-SNAPSHOT</version>
<version>0.0.6-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
......@@ -17,8 +17,8 @@
<!-- .. SE-Libraries .................................................. -->
<CNNArch.version>0.3.4-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.3.9-SNAPSHOT</CNNTrain.version>
<CNNArch.version>0.3.5-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.3.10-SNAPSHOT</CNNTrain.version>
<embedded-montiarc-math-opt-generator>0.1.6</embedded-montiarc-math-opt-generator>
<!-- .. Libraries .................................................. -->
......
......@@ -91,6 +91,10 @@ public class ArchitectureElementData {
return getLayerSymbol().getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get();
}
public int getGroups(){
return getLayerSymbol().getIntValue(AllPredefinedLayers.GROUPS_NAME).get();
}
public int getUnits(){
return getLayerSymbol().getIntValue(AllPredefinedLayers.UNITS_NAME).get();
}
......@@ -225,10 +229,8 @@ public class ArchitectureElementData {
if(pad.equals("same")){
return getTransPadding(getLayerSymbol()); //The padding calculated here is only used in the gluon/ mxnet backend, in the tensorlflow one it is interpreted as "same"
}else if(pad.equals("valid")){
} else { // padding valid
return Arrays.asList(0,0);
}else{ //"no loss"
return Arrays.asList(0,0,-1,0,0,0,0,0);
}
}
......
......@@ -70,10 +70,14 @@ public abstract class CNNArchGenerator {
ArchitectureSymbol architectureSymbol = symbolCompiler.compileArchitectureSymbol(scope, rootModelName);
try{
String confPath = getModelsDirPath() + "/data_paths.txt";
DataPathConfigParser newParserConfig = new DataPathConfigParser(confPath);
String dataPath = newParserConfig.getDataPath(rootModelName);
String dataConfPath = getModelsDirPath() + "/data_paths.txt";
DataPathConfigParser dataParserConfig = new DataPathConfigParser(dataConfPath);
String dataPath = dataParserConfig.getDataPath(rootModelName);
architectureSymbol.setDataPath(dataPath);
String weightsConfPath = getModelsDirPath() + "/weights_paths.txt";
WeightsPathConfigParser weightsParserConfig = new WeightsPathConfigParser(weightsConfPath);
String weightsPath = weightsParserConfig.getWeightsPath(rootModelName);
architectureSymbol.setWeightsPath(weightsPath);
architectureSymbol.setComponentName(rootModelName);
generateFiles(architectureSymbol);
} catch (IOException e){
......
......@@ -10,7 +10,7 @@ import java.io.Writer;
import java.util.*;
public abstract class CNNArchTemplateController {
public static final String FTL_FILE_ENDING = ".ftl";
public static final String TEMPLATE_ELEMENTS_DIR_PATH = "elements/";
public static final String TEMPLATE_CONTROLLER_KEY = "tc";
......@@ -116,7 +116,11 @@ public abstract class CNNArchTemplateController {
public String getDataPath(){
return getArchitecture().getDataPath();
}
public String getWeightsPath(){
return getArchitecture().getWeightsPath();
}
public List<String> getLayerInputs(ArchitectureElementSymbol layer){
List<String> inputNames = new ArrayList<>();
......
......@@ -43,6 +43,62 @@ public class ConfigurationData {
return String.valueOf(getConfiguration().getEntry("batch_size") .getValue());
}
public String getKValue() {
if (!getConfiguration().getEntryMap().containsKey("k_value")) {
return null;
}
return String.valueOf(getConfiguration().getEntry("k_value") .getValue());
}
public String getGeneratorLossWeight() {
if (!getConfiguration().getEntryMap().containsKey("generator_loss_weight")) {
return null;
}
return String.valueOf(getConfiguration().getEntry("generator_loss_weight") .getValue());
}
public String getDiscriminatorLossWeight() {
if (!getConfiguration().getEntryMap().containsKey("discriminator_loss_weight")) {
return null;
}
return String.valueOf(getConfiguration().getEntry("discriminator_loss_weight") .getValue());
}
public String getSpeedPeriod() {
if (!getConfiguration().getEntryMap().containsKey("speed_period")) {
return null;
}
return String.valueOf(getConfiguration().getEntry("speed_period") .getValue());
}
public Boolean getPrintImages() {
if (!getConfiguration().getEntryMap().containsKey("print_images")) {
return null;
}
return (Boolean) getConfiguration().getEntry("print_images").getValue().getValue();
}
public String getGeneratorLoss() {
if (!getConfiguration().getEntryMap().containsKey("generator_loss")) {
return null;
}
return String.valueOf(getConfiguration().getEntry("generator_loss") .getValue());
}
public String getGeneratorTargetName() {
if (!getConfiguration().getEntryMap().containsKey("generator_target_name")) {
return null;
}
return String.valueOf(getConfiguration().getEntry("generator_target_name") .getValue());
}
public String getNoiseInput() {
if (!getConfiguration().getEntryMap().containsKey("noise_input")) {
return null;
}
return String.valueOf(getConfiguration().getEntry("noise_input") .getValue());
}
public Boolean getLoadCheckpoint() {
if (!getConfiguration().getEntryMap().containsKey("load_checkpoint")) {
return null;
......@@ -64,6 +120,13 @@ public class ConfigurationData {
return String.valueOf(getConfiguration().getEntry("log_period").getValue());
}
public Boolean getLoadPretrained() {
if (!getConfiguration().getEntryMap().containsKey("load_pretrained")) {
return null;
}
return (Boolean) getConfiguration().getEntry("load_pretrained").getValue().getValue();
}
public Boolean getNormalize() {
if (!getConfiguration().getEntryMap().containsKey("normalize")) {
return null;
......@@ -92,6 +155,10 @@ public class ConfigurationData {
return (String) getConfiguration().getEntry("preprocessing_name").getValue().toString();
}
public Boolean getPreprocessor() {
return (Boolean) configuration.hasPreprocessor();
}
public String getContext() {
if (!getConfiguration().getEntryMap().containsKey("context")) {
return null;
......@@ -128,7 +195,7 @@ public class ConfigurationData {
} else{
return mapToStrings;}
}
public String getLossWeights() {
if (!getConfiguration().getEntryMap().containsKey("loss_weights")) {
return null;
......
......@@ -40,12 +40,14 @@ public abstract class TrainParamSupportChecker implements CNNTrainVisitor {
public void visit(ASTLoadCheckpointEntry node){}
public void visit(ASTLoadPretrainedEntry node){}
public void visit(ASTNormalizeEntry node){}
public void visit(ASTTrainContextEntry node){}
public void visit(ASTEvalMetricEntry node){}
public void visit(ASTSGDOptimizer node){}
public void visit(ASTAdamOptimizer node){}
......
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnnarch.generator;
import de.se_rwth.commons.logging.Log;
import java.io.*;
import java.util.Properties;
public class WeightsPathConfigParser{
private String configTargetPath;
private Properties properties;
public WeightsPathConfigParser(String configPath) {
setConfigPath(configPath);
properties = new Properties();
try
{
properties.load(new FileInputStream(configTargetPath));
} catch(IOException e)
{
Log.error("Config file " + configPath + " could not be found");
}
}
public String getConfigPath() {
if (configTargetPath.charAt(configTargetPath.length() - 1) != '/') {
this.configTargetPath = configTargetPath + "/";
}
return configTargetPath;
}
public void setConfigPath(String configTargetPath){
this.configTargetPath = configTargetPath;
}
public String getWeightsPath(String modelName) {
String path = properties.getProperty(modelName);
if(path == null) {
Log.warn("Weights path config file did not specify a path for component '" + modelName + "'");
return path;
}
return path;
}
}
/* (c) https://github.com/MontiCore/monticore */
package de.monticore.lang.monticar.cnnarch.generator;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedVariables;
import de.monticore.symboltable.Scope;
import de.se_rwth.commons.logging.Log;
import org.junit.Before;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
public class WeightsPathConfigParserTest extends AbstractSymtabTest {
@Before
public void setUp() {
// ensure an empty log
Log.getFindings().clear();
Log.enableFailQuick(false);
}
@Test
public void testWeightsPathConfigParserValidComponent() {
WeightsPathConfigParser parser = new WeightsPathConfigParser("src/test/resources/architectures/weights_paths.txt");
String weights_path = parser.getWeightsPath("ComponentName");
assertTrue("Wrong weights path returned", weights_path.equals("/path/to/training/weights"));
}
@Test
public void testWeightsPathConfigParserInvalidComponent() {
WeightsPathConfigParser parser = new WeightsPathConfigParser("src/test/resources/architectures/weights_paths.txt");
String weights_path = parser.getWeightsPath("NotExistingComponent");
assertTrue("For not listed components, null should be returned", weights_path == null);
assertTrue(Log.getFindings().size() == 1);
}
@Test
public void testWeightsPathConfigParserInvalidPath() {
WeightsPathConfigParser parser = new WeightsPathConfigParser("invalid/path/weights_paths.txt");
assertTrue(Log.getFindings().size() == 1);
}
}
ComponentName /path/to/training/weights
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment