Commit 89e98056 authored by Jan Philipp Haller's avatar Jan Philipp Haller
Browse files

Changes for new interface,

updated Generationtest,
updated dependency versions.
parent 104949c3
Pipeline #336041 canceled with stages
......@@ -9,17 +9,15 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-caffe2-generator</artifactId>
<version>0.2.14-SNAPSHOT</version>
<version>0.4.0-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
<properties>
<!-- .. SE-Libraries .................................................. -->
<CNNArch.version>0.3.4-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.3.9-SNAPSHOT</CNNTrain.version>
<CNNArch2X.version>0.0.5-SNAPSHOT</CNNArch2X.version>
<embedded-montiarc-math-opt-generator>0.1.4</embedded-montiarc-math-opt-generator>
<CNNArch2X.version>0.4.0-SNAPSHOT</CNNArch2X.version>
<!-- .. Libraries .................................................. -->
<guava.version>18.0</guava.version>
......@@ -67,41 +65,6 @@
<version>${CNNArch2X.version}</version>
</dependency>
<dependency>
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnn-arch</artifactId>
<version>${CNNArch.version}</version>
</dependency>
<dependency>
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnn-arch</artifactId>
<version>${CNNArch.version}</version>
<classifier>${grammars.classifier}</classifier>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnn-train</artifactId>
<version>${CNNTrain.version}</version>
</dependency>
<dependency>
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnn-train</artifactId>
<version>${CNNTrain.version}</version>
<classifier>${grammars.classifier}</classifier>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>embedded-montiarc-math-opt-generator</artifactId>
<version>${embedded-montiarc-math-opt-generator}</version>
</dependency>
<!-- .. Test Libraries ............................................... -->
<dependency>
<groupId>junit</groupId>
......
......@@ -2,6 +2,7 @@
package de.monticore.lang.monticar.cnnarch.caffe2generator;
import de.monticore.lang.monticar.cnnarch.generator.CNNArchGenerator;
import de.monticore.lang.monticar.cnnarch.generator.DataPathConfigParser;
import de.monticore.lang.monticar.cnnarch.generator.Target;
......@@ -14,12 +15,16 @@ import de.monticore.lang.monticar.generator.FileContent;
import de.monticore.lang.monticar.generator.cmake.CMakeConfig;
import de.monticore.lang.monticar.generator.cmake.CMakeFindModule;
import de.monticore.lang.monticar.generator.cpp.GeneratorCPP;
import de.monticore.lang.tagging._symboltable.TaggingResolver;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class CNNArch2Caffe2 extends CNNArchGenerator {
CMakeConfig cMakeConfig;
public CNNArch2Caffe2() {
architectureSupportChecker = new CNNArch2Caffe2ArchitectureSupportChecker();
......@@ -27,29 +32,31 @@ public class CNNArch2Caffe2 extends CNNArchGenerator {
}
//check cocos with CNNArchCocos.checkAll(architecture) before calling this method.
public Map<String, String> generateStrings(ArchitectureSymbol architecture){
Map<String, String> fileContentMap = new HashMap<>();
public List<FileContent> generateStrings(TaggingResolver taggingResolver, ArchitectureSymbol architecture){
List<FileContent> fileContents = new ArrayList<>();
CNNArchTemplateController archTc = new CNNArchTemplateController(architecture);
Map.Entry<String, String> temp;
FileContent temp;
temp = archTc.process("CNNPredictor", Target.CPP);
fileContentMap.put(temp.getKey(), temp.getValue());
fileContents.add(temp);
temp = archTc.process("CNNCreator", Target.PYTHON);
fileContentMap.put(temp.getKey(), temp.getValue());
fileContents.add(temp);
temp = archTc.process("execute", Target.CPP);
fileContentMap.put(temp.getKey().replace(".h", ""), temp.getValue());
temp.setFileName(temp.getFileName().replace(".h", ""));
fileContents.add(temp);
return fileContentMap;
return fileContents;
}
public Map<String, String> generateCMakeContent(String rootModelName) {
public List<FileContent> generateCMakeContent(String rootModelName) {
List<FileContent> fileContents = new ArrayList<>();
// model name should start with a lower case letter. If it is a component, replace dot . by _
rootModelName = rootModelName.replace('.', '_').replace('[', '_').replace(']', '_');
rootModelName = rootModelName.substring(0, 1).toLowerCase() + rootModelName.substring(1);
CMakeConfig cMakeConfig = new CMakeConfig(rootModelName);
cMakeConfig = new CMakeConfig(rootModelName);
cMakeConfig.addModuleDependency(new CMakeFindModule("Armadillo", true));
cMakeConfig.addModuleDependency(new CMakeFindModule("Caffe2", true));
cMakeConfig.addCMakeCommand("set(LIBS ${LIBS} -lprotobuf -lglog -lgflags)");
......@@ -64,10 +71,13 @@ public class CNNArch2Caffe2 extends CNNArchGenerator {
+ " set(LIBS ${LIBS} caffe2)" + "\n"
+ "endif()");
Map<String,String> fileContentMap = new HashMap<>();
for (FileContent fileContent : cMakeConfig.generateCMakeFiles()){
fileContentMap.put(fileContent.getFileName(), fileContent.getFileContent());
fileContents.add(fileContent);
}
return fileContentMap;
return fileContents;
}
public CMakeConfig getCmakeConfig() {
return this.cMakeConfig;
}
}
......@@ -7,6 +7,7 @@ import de.monticore.lang.monticar.cnnarch.generator.Target;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.predefined.Sigmoid;
import de.monticore.lang.monticar.cnnarch.predefined.Softmax;
import de.monticore.lang.monticar.generator.FileContent;
import java.io.StringWriter;
import java.io.Writer;
......@@ -192,17 +193,15 @@ public class CNNArchTemplateController {
include(architectureElement, writer);
}
public Map.Entry<String,String> process(String templateNameWithoutEnding, Target targetLanguage){
public FileContent process(String templateNameWithoutEnding, Target targetLanguage) {
StringWriter newWriter = new StringWriter();
this.mainTemplateNameWithoutEnding = templateNameWithoutEnding;
this.targetLanguage = targetLanguage;
this.writer = newWriter;
include("", templateNameWithoutEnding, newWriter);
this.include("", templateNameWithoutEnding, newWriter);
String fileEnding = targetLanguage.toString();
String fileName = getFileNameWithoutEnding() + fileEnding;
Map.Entry<String,String> fileContent = new AbstractMap.SimpleEntry<>(fileName, newWriter.toString());
String fileName = this.getFileNameWithoutEnding() + fileEnding;
FileContent fileContent = new FileContent(newWriter.toString(), fileName);
this.mainTemplateNameWithoutEnding = null;
this.targetLanguage = null;
this.writer = null;
......
......@@ -28,12 +28,12 @@ public class CNNTrain2Caffe2 extends CNNTrainGenerator {
@Override
public void generate(Path modelsDirPath, String rootModelName) {
ConfigurationSymbol configuration = getConfigurationSymbol(modelsDirPath, rootModelName);
Map<String, String> fileContents = generateStrings(configuration);
List<FileContent> fileContents= generateStrings(configuration);
GeneratorCPP genCPP = new GeneratorCPP();
genCPP.setGenerationTargetPath(getGenerationTargetPath());
try {
for (String fileName : fileContents.keySet()){
genCPP.generateFile(new FileContent(fileContents.get(fileName), fileName));
for (FileContent fileContent : fileContents){
genCPP.generateFile(fileContent);
}
} catch (IOException e) {
Log.error("CNNTrainer file could not be generated" + e.getMessage());
......@@ -41,13 +41,16 @@ public class CNNTrain2Caffe2 extends CNNTrainGenerator {
}
@Override
public Map<String, String> generateStrings(ConfigurationSymbol configuration) {
public List<FileContent> generateStrings(ConfigurationSymbol configuration) {
ConfigurationData configData = new ConfigurationData(configuration, getInstanceName());
List<ConfigurationData> configDataList = new ArrayList<>();
configDataList.add(configData);
Map<String, Object> ftlContext = Collections.singletonMap("configurations", configDataList);
String templateContent = TemplateConfiguration.processTemplate(ftlContext, "CNNTrainer.ftl");
return Collections.singletonMap("CNNTrainer_" + getInstanceName() + ".py", templateContent);
List<FileContent> fileContents = new ArrayList<>();
FileContent temp = new FileContent(templateContent, "CNNTrainer_" + getInstanceName() + ".py");
fileContents.add(temp);
return fileContents;
}
}
......@@ -193,9 +193,7 @@ public class GenerationTest extends AbstractSymtabTest{
String rootModelName = "alexnet";
CNNArch2Caffe2 generator = new CNNArch2Caffe2();
generator.setGenerationTargetPath("./target/generated-sources-cnnarch");
if(generator.isCMakeRequired()){
generator.generateCMake(rootModelName);
}
generator.generateCMake(rootModelName);
assertTrue(Log.getFindings().isEmpty());
......
......@@ -10,15 +10,19 @@ set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
find_package(Armadillo REQUIRED)
set(INCLUDE_DIRS ${INCLUDE_DIRS} ${Armadillo_INCLUDE_DIRS})
set(LIBS ${LIBS} ${Armadillo_LIBRARIES})
find_package(Caffe2 REQUIRED)
set(INCLUDE_DIRS ${INCLUDE_DIRS} ${Caffe2_INCLUDE_DIRS})
set(LIBS ${LIBS} ${Caffe2_LIBRARIES})
# additional library linkage
# additional commands
set(LIBS ${LIBS} -lprotobuf -lglog -lgflags)
find_package(CUDA)
set(LIBS ${LIBS} -lprotobuf -lglog -lgflags)
find_package(CUDA)
if(CUDA_FOUND)
if(CUDA_FOUND)
set(LIBS ${LIBS} caffe2 caffe2_gpu)
set(INCLUDE_DIRS ${INCLUDE_DIRS} ${CUDA_INCLUDE_DIRS})
set(LIBS ${LIBS} ${CUDA_LIBRARIES} ${CUDA_curand_LIBRARY})
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment