Commit 6bf25ee8 authored by Evgeny Kusmenko's avatar Evgeny Kusmenko

Merge branch 'oneclick_nn_training' into 'master'

Oneclick nn training

See merge request !24
parents 7db86343 48cedc75
Pipeline #125341 canceled with stages
......@@ -8,14 +8,14 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-mxnet-generator</artifactId>
<version>0.2.13-SNAPSHOT</version>
<version>0.2.14-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
<properties>
<!-- .. SE-Libraries .................................................. -->
<CNNArch.version>0.2.9</CNNArch.version>
<CNNArch.version>0.3.0-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.2.6</CNNTrain.version>
<embedded-montiarc-math-opt-generator>0.1.4</embedded-montiarc-math-opt-generator>
......
......@@ -100,4 +100,4 @@
<activeProfiles>
<activeProfile>se-nexus</activeProfile>
</activeProfiles>
</settings>
\ No newline at end of file
</settings>
......@@ -26,6 +26,8 @@ import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchCompilationUnitSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchLanguage;
import de.monticore.lang.monticar.cnnarch.DataPathConfigParser;
import de.monticore.lang.monticar.generator.FileContent;
import de.monticore.lang.monticar.generator.cmake.CMakeConfig;
import de.monticore.lang.monticar.generator.cmake.CMakeFindModule;
......@@ -34,6 +36,8 @@ import de.monticore.symboltable.Scope;
import de.se_rwth.commons.logging.Log;
import java.io.IOException;
import java.lang.System;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
......@@ -87,6 +91,11 @@ public class CNNArch2MxNet extends CNNArchGenerator {
}
try{
String confPath = getModelsDirPath() + "/data_paths.txt";
DataPathConfigParser newParserConfig = new DataPathConfigParser(confPath);
String dataPath = newParserConfig.getDataPath(rootModelName);
compilationUnit.get().getArchitecture().setDataPath(dataPath);
compilationUnit.get().getArchitecture().setComponentName(rootModelName);
generateFiles(compilationUnit.get().getArchitecture());
} catch (IOException e){
Log.error(e.toString());
......
......@@ -132,6 +132,10 @@ public abstract class CNNArchTemplateController {
return getArchitecture().getEnclosingScope().getSpanningSymbol().get().getFullName().replaceAll("\\.","_");
}
public String getDataPath(){
return getArchitecture().getDataPath();
}
public List<String> getLayerInputs(ArchitectureElementSymbol layer){
List<String> inputNames = new ArrayList<>();
......@@ -168,7 +172,11 @@ public abstract class CNNArchTemplateController {
return list;
}
public void include(String relativePath, String templateWithoutFileEnding, Writer writer) {
public String getComponentName(){
return getArchitecture().getComponentName();
}
public void include(String relativePath, String templateWithoutFileEnding, Writer writer){
String templatePath = relativePath + templateWithoutFileEnding + FTL_FILE_ENDING;
Map<String, Object> ftlContext = new HashMap<>();
ftlContext.put(TEMPLATE_CONTROLLER_KEY, this);
......
......@@ -18,9 +18,9 @@ class MyConstant(mx.init.Initializer):
class ${tc.fileNameWithoutEnding}:
module = None
_data_dir_ = "data/${tc.fullArchitectureName}/"
_model_dir_ = "model/${tc.fullArchitectureName}/"
_model_prefix_ = "${tc.architectureName}"
_data_dir_ = "${tc.dataPath}/"
_model_dir_ = "model/${tc.componentName}/"
_model_prefix_ = "model"
_input_names_ = [${tc.join(tc.architectureInputs, ",", "'", "'")}]
_input_shapes_ = [<#list tc.architecture.inputs as input>(${tc.join(input.definition.type.dimensions, ",")})</#list>]
_output_names_ = [${tc.join(tc.architectureOutputs, ",", "'", "_label'")}]
......
......@@ -11,8 +11,8 @@
class ${tc.fileNameWithoutEnding}{
public:
const std::string json_file = "model/${tc.fullArchitectureName}/${tc.architectureName}_newest-symbol.json";
const std::string param_file = "model/${tc.fullArchitectureName}/${tc.architectureName}_newest-0000.params";
const std::string json_file = "model/${tc.componentName}/model_newest-symbol.json";
const std::string param_file = "model/${tc.componentName}/model_newest-0000.params";
//const std::vector<std::string> input_keys = {"data"};
const std::vector<std::string> input_keys = {${tc.join(tc.architectureInputs, ",", "\"", "\"")}};
const std::vector<std::vector<mx_uint>> input_shapes = {<#list tc.architecture.inputs as input>{1,${tc.join(input.definition.type.dimensions, ",")}}<#if input?has_next>,</#if></#list>};
......
Alexnet data/Alexnet
VGG16 data/VGG16
ThreeInputCNN_M14 data/ThreeInputCNN_M14
ResNeXt50 data/ResNeXt50
\ No newline at end of file
Alexnet data/Alexnet
CifarClassifierNetwork data/CifarClassifierNetwork
VGG16 data/VGG16
\ No newline at end of file
......@@ -20,7 +20,7 @@ class CNNCreator_Alexnet:
module = None
_data_dir_ = "data/Alexnet/"
_model_dir_ = "model/Alexnet/"
_model_prefix_ = "Alexnet"
_model_prefix_ = "model"
_input_names_ = ['data']
_input_shapes_ = [(3,224,224)]
_output_names_ = ['predictions_label']
......
......@@ -20,7 +20,7 @@ class CNNCreator_CifarClassifierNetwork:
module = None
_data_dir_ = "data/CifarClassifierNetwork/"
_model_dir_ = "model/CifarClassifierNetwork/"
_model_prefix_ = "CifarClassifierNetwork"
_model_prefix_ = "model"
_input_names_ = ['data']
_input_shapes_ = [(3,32,32)]
_output_names_ = ['softmax_label']
......
......@@ -20,7 +20,7 @@ class CNNCreator_VGG16:
module = None
_data_dir_ = "data/VGG16/"
_model_dir_ = "model/VGG16/"
_model_prefix_ = "VGG16"
_model_prefix_ = "model"
_input_names_ = ['data']
_input_shapes_ = [(3,224,224)]
_output_names_ = ['predictions_label']
......
......@@ -11,8 +11,8 @@
class CNNPredictor_Alexnet{
public:
const std::string json_file = "model/Alexnet/Alexnet_newest-symbol.json";
const std::string param_file = "model/Alexnet/Alexnet_newest-0000.params";
const std::string json_file = "model/Alexnet/model_newest-symbol.json";
const std::string param_file = "model/Alexnet/model_newest-0000.params";
//const std::vector<std::string> input_keys = {"data"};
const std::vector<std::string> input_keys = {"data"};
const std::vector<std::vector<mx_uint>> input_shapes = {{1,3,224,224}};
......
......@@ -11,8 +11,8 @@
class CNNPredictor_CifarClassifierNetwork{
public:
const std::string json_file = "model/CifarClassifierNetwork/CifarClassifierNetwork_newest-symbol.json";
const std::string param_file = "model/CifarClassifierNetwork/CifarClassifierNetwork_newest-0000.params";
const std::string json_file = "model/CifarClassifierNetwork/model_newest-symbol.json";
const std::string param_file = "model/CifarClassifierNetwork/model_newest-0000.params";
//const std::vector<std::string> input_keys = {"data"};
const std::vector<std::string> input_keys = {"data"};
const std::vector<std::vector<mx_uint>> input_shapes = {{1,3,32,32}};
......
......@@ -11,8 +11,8 @@
class CNNPredictor_VGG16{
public:
const std::string json_file = "model/VGG16/VGG16_newest-symbol.json";
const std::string param_file = "model/VGG16/VGG16_newest-0000.params";
const std::string json_file = "model/VGG16/model_newest-symbol.json";
const std::string param_file = "model/VGG16/model_newest-0000.params";
//const std::vector<std::string> input_keys = {"data"};
const std::vector<std::string> input_keys = {"data"};
const std::vector<std::vector<mx_uint>> input_shapes = {{1,3,224,224}};
......
CifarClassifierNetwork data/CifarClassifierNetwork
MultipleOutputs data/MultipleOutputs
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment