Commit 89e98056 authored by Jan Philipp Haller's avatar Jan Philipp Haller
Browse files

Changes for new interface,

updated Generationtest,
updated dependency versions.
parent 104949c3
Pipeline #336041 canceled with stages
...@@ -9,17 +9,15 @@ ...@@ -9,17 +9,15 @@
<groupId>de.monticore.lang.monticar</groupId> <groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-caffe2-generator</artifactId> <artifactId>cnnarch-caffe2-generator</artifactId>
<version>0.2.14-SNAPSHOT</version> <version>0.4.0-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= --> <!-- == PROJECT DEPENDENCIES ============================================= -->
<properties> <properties>
<!-- .. SE-Libraries .................................................. --> <!-- .. SE-Libraries .................................................. -->
<CNNArch.version>0.3.4-SNAPSHOT</CNNArch.version> <CNNArch2X.version>0.4.0-SNAPSHOT</CNNArch2X.version>
<CNNTrain.version>0.3.9-SNAPSHOT</CNNTrain.version>
<CNNArch2X.version>0.0.5-SNAPSHOT</CNNArch2X.version>
<embedded-montiarc-math-opt-generator>0.1.4</embedded-montiarc-math-opt-generator>
<!-- .. Libraries .................................................. --> <!-- .. Libraries .................................................. -->
<guava.version>18.0</guava.version> <guava.version>18.0</guava.version>
...@@ -67,41 +65,6 @@ ...@@ -67,41 +65,6 @@
<version>${CNNArch2X.version}</version> <version>${CNNArch2X.version}</version>
</dependency> </dependency>
<dependency>
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnn-arch</artifactId>
<version>${CNNArch.version}</version>
</dependency>
<dependency>
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnn-arch</artifactId>
<version>${CNNArch.version}</version>
<classifier>${grammars.classifier}</classifier>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnn-train</artifactId>
<version>${CNNTrain.version}</version>
</dependency>
<dependency>
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnn-train</artifactId>
<version>${CNNTrain.version}</version>
<classifier>${grammars.classifier}</classifier>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>embedded-montiarc-math-opt-generator</artifactId>
<version>${embedded-montiarc-math-opt-generator}</version>
</dependency>
<!-- .. Test Libraries ............................................... --> <!-- .. Test Libraries ............................................... -->
<dependency> <dependency>
<groupId>junit</groupId> <groupId>junit</groupId>
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
package de.monticore.lang.monticar.cnnarch.caffe2generator; package de.monticore.lang.monticar.cnnarch.caffe2generator;
import de.monticore.lang.monticar.cnnarch.generator.CNNArchGenerator; import de.monticore.lang.monticar.cnnarch.generator.CNNArchGenerator;
import de.monticore.lang.monticar.cnnarch.generator.DataPathConfigParser; import de.monticore.lang.monticar.cnnarch.generator.DataPathConfigParser;
import de.monticore.lang.monticar.cnnarch.generator.Target; import de.monticore.lang.monticar.cnnarch.generator.Target;
...@@ -14,12 +15,16 @@ import de.monticore.lang.monticar.generator.FileContent; ...@@ -14,12 +15,16 @@ import de.monticore.lang.monticar.generator.FileContent;
import de.monticore.lang.monticar.generator.cmake.CMakeConfig; import de.monticore.lang.monticar.generator.cmake.CMakeConfig;
import de.monticore.lang.monticar.generator.cmake.CMakeFindModule; import de.monticore.lang.monticar.generator.cmake.CMakeFindModule;
import de.monticore.lang.monticar.generator.cpp.GeneratorCPP; import de.monticore.lang.monticar.generator.cpp.GeneratorCPP;
import de.monticore.lang.tagging._symboltable.TaggingResolver;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
import java.util.List;
import java.util.Map; import java.util.Map;
public class CNNArch2Caffe2 extends CNNArchGenerator { public class CNNArch2Caffe2 extends CNNArchGenerator {
CMakeConfig cMakeConfig;
public CNNArch2Caffe2() { public CNNArch2Caffe2() {
architectureSupportChecker = new CNNArch2Caffe2ArchitectureSupportChecker(); architectureSupportChecker = new CNNArch2Caffe2ArchitectureSupportChecker();
...@@ -27,29 +32,31 @@ public class CNNArch2Caffe2 extends CNNArchGenerator { ...@@ -27,29 +32,31 @@ public class CNNArch2Caffe2 extends CNNArchGenerator {
} }
//check cocos with CNNArchCocos.checkAll(architecture) before calling this method. //check cocos with CNNArchCocos.checkAll(architecture) before calling this method.
public Map<String, String> generateStrings(ArchitectureSymbol architecture){ public List<FileContent> generateStrings(TaggingResolver taggingResolver, ArchitectureSymbol architecture){
Map<String, String> fileContentMap = new HashMap<>(); List<FileContent> fileContents = new ArrayList<>();
CNNArchTemplateController archTc = new CNNArchTemplateController(architecture); CNNArchTemplateController archTc = new CNNArchTemplateController(architecture);
Map.Entry<String, String> temp; FileContent temp;
temp = archTc.process("CNNPredictor", Target.CPP); temp = archTc.process("CNNPredictor", Target.CPP);
fileContentMap.put(temp.getKey(), temp.getValue()); fileContents.add(temp);
temp = archTc.process("CNNCreator", Target.PYTHON); temp = archTc.process("CNNCreator", Target.PYTHON);
fileContentMap.put(temp.getKey(), temp.getValue()); fileContents.add(temp);
temp = archTc.process("execute", Target.CPP); temp = archTc.process("execute", Target.CPP);
fileContentMap.put(temp.getKey().replace(".h", ""), temp.getValue()); temp.setFileName(temp.getFileName().replace(".h", ""));
fileContents.add(temp);
return fileContentMap; return fileContents;
} }
public Map<String, String> generateCMakeContent(String rootModelName) { public List<FileContent> generateCMakeContent(String rootModelName) {
List<FileContent> fileContents = new ArrayList<>();
// model name should start with a lower case letter. If it is a component, replace dot . by _ // model name should start with a lower case letter. If it is a component, replace dot . by _
rootModelName = rootModelName.replace('.', '_').replace('[', '_').replace(']', '_'); rootModelName = rootModelName.replace('.', '_').replace('[', '_').replace(']', '_');
rootModelName = rootModelName.substring(0, 1).toLowerCase() + rootModelName.substring(1); rootModelName = rootModelName.substring(0, 1).toLowerCase() + rootModelName.substring(1);
CMakeConfig cMakeConfig = new CMakeConfig(rootModelName); cMakeConfig = new CMakeConfig(rootModelName);
cMakeConfig.addModuleDependency(new CMakeFindModule("Armadillo", true)); cMakeConfig.addModuleDependency(new CMakeFindModule("Armadillo", true));
cMakeConfig.addModuleDependency(new CMakeFindModule("Caffe2", true)); cMakeConfig.addModuleDependency(new CMakeFindModule("Caffe2", true));
cMakeConfig.addCMakeCommand("set(LIBS ${LIBS} -lprotobuf -lglog -lgflags)"); cMakeConfig.addCMakeCommand("set(LIBS ${LIBS} -lprotobuf -lglog -lgflags)");
...@@ -64,10 +71,13 @@ public class CNNArch2Caffe2 extends CNNArchGenerator { ...@@ -64,10 +71,13 @@ public class CNNArch2Caffe2 extends CNNArchGenerator {
+ " set(LIBS ${LIBS} caffe2)" + "\n" + " set(LIBS ${LIBS} caffe2)" + "\n"
+ "endif()"); + "endif()");
Map<String,String> fileContentMap = new HashMap<>();
for (FileContent fileContent : cMakeConfig.generateCMakeFiles()){ for (FileContent fileContent : cMakeConfig.generateCMakeFiles()){
fileContentMap.put(fileContent.getFileName(), fileContent.getFileContent()); fileContents.add(fileContent);
} }
return fileContentMap; return fileContents;
}
public CMakeConfig getCmakeConfig() {
return this.cMakeConfig;
} }
} }
...@@ -7,6 +7,7 @@ import de.monticore.lang.monticar.cnnarch.generator.Target; ...@@ -7,6 +7,7 @@ import de.monticore.lang.monticar.cnnarch.generator.Target;
import de.monticore.lang.monticar.cnnarch._symboltable.*; import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.predefined.Sigmoid; import de.monticore.lang.monticar.cnnarch.predefined.Sigmoid;
import de.monticore.lang.monticar.cnnarch.predefined.Softmax; import de.monticore.lang.monticar.cnnarch.predefined.Softmax;
import de.monticore.lang.monticar.generator.FileContent;
import java.io.StringWriter; import java.io.StringWriter;
import java.io.Writer; import java.io.Writer;
...@@ -192,17 +193,15 @@ public class CNNArchTemplateController { ...@@ -192,17 +193,15 @@ public class CNNArchTemplateController {
include(architectureElement, writer); include(architectureElement, writer);
} }
public Map.Entry<String,String> process(String templateNameWithoutEnding, Target targetLanguage){ public FileContent process(String templateNameWithoutEnding, Target targetLanguage) {
StringWriter newWriter = new StringWriter(); StringWriter newWriter = new StringWriter();
this.mainTemplateNameWithoutEnding = templateNameWithoutEnding; this.mainTemplateNameWithoutEnding = templateNameWithoutEnding;
this.targetLanguage = targetLanguage; this.targetLanguage = targetLanguage;
this.writer = newWriter; this.writer = newWriter;
this.include("", templateNameWithoutEnding, newWriter);
include("", templateNameWithoutEnding, newWriter);
String fileEnding = targetLanguage.toString(); String fileEnding = targetLanguage.toString();
String fileName = getFileNameWithoutEnding() + fileEnding; String fileName = this.getFileNameWithoutEnding() + fileEnding;
Map.Entry<String,String> fileContent = new AbstractMap.SimpleEntry<>(fileName, newWriter.toString()); FileContent fileContent = new FileContent(newWriter.toString(), fileName);
this.mainTemplateNameWithoutEnding = null; this.mainTemplateNameWithoutEnding = null;
this.targetLanguage = null; this.targetLanguage = null;
this.writer = null; this.writer = null;
......
...@@ -28,12 +28,12 @@ public class CNNTrain2Caffe2 extends CNNTrainGenerator { ...@@ -28,12 +28,12 @@ public class CNNTrain2Caffe2 extends CNNTrainGenerator {
@Override @Override
public void generate(Path modelsDirPath, String rootModelName) { public void generate(Path modelsDirPath, String rootModelName) {
ConfigurationSymbol configuration = getConfigurationSymbol(modelsDirPath, rootModelName); ConfigurationSymbol configuration = getConfigurationSymbol(modelsDirPath, rootModelName);
Map<String, String> fileContents = generateStrings(configuration); List<FileContent> fileContents= generateStrings(configuration);
GeneratorCPP genCPP = new GeneratorCPP(); GeneratorCPP genCPP = new GeneratorCPP();
genCPP.setGenerationTargetPath(getGenerationTargetPath()); genCPP.setGenerationTargetPath(getGenerationTargetPath());
try { try {
for (String fileName : fileContents.keySet()){ for (FileContent fileContent : fileContents){
genCPP.generateFile(new FileContent(fileContents.get(fileName), fileName)); genCPP.generateFile(fileContent);
} }
} catch (IOException e) { } catch (IOException e) {
Log.error("CNNTrainer file could not be generated" + e.getMessage()); Log.error("CNNTrainer file could not be generated" + e.getMessage());
...@@ -41,13 +41,16 @@ public class CNNTrain2Caffe2 extends CNNTrainGenerator { ...@@ -41,13 +41,16 @@ public class CNNTrain2Caffe2 extends CNNTrainGenerator {
} }
@Override @Override
public Map<String, String> generateStrings(ConfigurationSymbol configuration) { public List<FileContent> generateStrings(ConfigurationSymbol configuration) {
ConfigurationData configData = new ConfigurationData(configuration, getInstanceName()); ConfigurationData configData = new ConfigurationData(configuration, getInstanceName());
List<ConfigurationData> configDataList = new ArrayList<>(); List<ConfigurationData> configDataList = new ArrayList<>();
configDataList.add(configData); configDataList.add(configData);
Map<String, Object> ftlContext = Collections.singletonMap("configurations", configDataList); Map<String, Object> ftlContext = Collections.singletonMap("configurations", configDataList);
String templateContent = TemplateConfiguration.processTemplate(ftlContext, "CNNTrainer.ftl"); String templateContent = TemplateConfiguration.processTemplate(ftlContext, "CNNTrainer.ftl");
return Collections.singletonMap("CNNTrainer_" + getInstanceName() + ".py", templateContent); List<FileContent> fileContents = new ArrayList<>();
FileContent temp = new FileContent(templateContent, "CNNTrainer_" + getInstanceName() + ".py");
fileContents.add(temp);
return fileContents;
} }
} }
...@@ -193,9 +193,7 @@ public class GenerationTest extends AbstractSymtabTest{ ...@@ -193,9 +193,7 @@ public class GenerationTest extends AbstractSymtabTest{
String rootModelName = "alexnet"; String rootModelName = "alexnet";
CNNArch2Caffe2 generator = new CNNArch2Caffe2(); CNNArch2Caffe2 generator = new CNNArch2Caffe2();
generator.setGenerationTargetPath("./target/generated-sources-cnnarch"); generator.setGenerationTargetPath("./target/generated-sources-cnnarch");
if(generator.isCMakeRequired()){ generator.generateCMake(rootModelName);
generator.generateCMake(rootModelName);
}
assertTrue(Log.getFindings().isEmpty()); assertTrue(Log.getFindings().isEmpty());
......
...@@ -10,15 +10,19 @@ set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${CMAKE_CURRENT_SOURCE_DIR}/cmake) ...@@ -10,15 +10,19 @@ set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
find_package(Armadillo REQUIRED) find_package(Armadillo REQUIRED)
set(INCLUDE_DIRS ${INCLUDE_DIRS} ${Armadillo_INCLUDE_DIRS}) set(INCLUDE_DIRS ${INCLUDE_DIRS} ${Armadillo_INCLUDE_DIRS})
set(LIBS ${LIBS} ${Armadillo_LIBRARIES}) set(LIBS ${LIBS} ${Armadillo_LIBRARIES})
find_package(Caffe2 REQUIRED) find_package(Caffe2 REQUIRED)
set(INCLUDE_DIRS ${INCLUDE_DIRS} ${Caffe2_INCLUDE_DIRS}) set(INCLUDE_DIRS ${INCLUDE_DIRS} ${Caffe2_INCLUDE_DIRS})
set(LIBS ${LIBS} ${Caffe2_LIBRARIES}) set(LIBS ${LIBS} ${Caffe2_LIBRARIES})
# additional library linkage
# additional commands # additional commands
set(LIBS ${LIBS} -lprotobuf -lglog -lgflags) set(LIBS ${LIBS} -lprotobuf -lglog -lgflags)
find_package(CUDA) find_package(CUDA)
if(CUDA_FOUND) if(CUDA_FOUND)
set(LIBS ${LIBS} caffe2 caffe2_gpu) set(LIBS ${LIBS} caffe2 caffe2_gpu)
set(INCLUDE_DIRS ${INCLUDE_DIRS} ${CUDA_INCLUDE_DIRS}) set(INCLUDE_DIRS ${INCLUDE_DIRS} ${CUDA_INCLUDE_DIRS})
set(LIBS ${LIBS} ${CUDA_LIBRARIES} ${CUDA_curand_LIBRARY}) set(LIBS ${LIBS} ${CUDA_LIBRARIES} ${CUDA_curand_LIBRARY})
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment