Adaptation for CNNArchGenerator interface

parent fb352802
Pipeline #56336 failed with stages
in 14 seconds
#
#
# ******************************************************************************
# MontiCAR Modeling Family, www.se-rwth.de
# Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
# All rights reserved.
#
# This project is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 3.0 of the License, or (at your option) any later version.
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this project. If not, see <http://www.gnu.org/licenses/>.
# *******************************************************************************
#
stages:
- windows
- linux
......
......@@ -8,14 +8,14 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-caffe2-generator</artifactId>
<version>0.2.1-SNAPSHOT</version>
<version>0.2.2-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
<properties>
<!-- .. SE-Libraries .................................................. -->
<CNNArch.version>0.2.1-SNAPSHOT</CNNArch.version>
<CNNArch.version>0.2.2-SNAPSHOT</CNNArch.version>
<!-- .. Libraries .................................................. -->
<guava.version>18.0</guava.version>
......@@ -24,7 +24,6 @@
<jscience.version>4.3.1</jscience.version>
<!-- .. Plugins ....................................................... -->
<monticore.plugin>4.5.3-SNAPSHOT</monticore.plugin>
<assembly.plugin>2.5.4</assembly.plugin>
<compiler.plugin>3.3</compiler.plugin>
<source.plugin>2.4</source.plugin>
......@@ -102,23 +101,6 @@
<plugin>
<artifactId>maven-deploy-plugin</artifactId>
<version>2.8.1</version>
<configuration>
<altDeploymentRepository>internal.repo::default::file://${project.build.directory}/external-dependencies</altDeploymentRepository>
</configuration>
</plugin>
<!-- MontiCore Generation -->
<plugin>
<groupId>de.monticore.mojo</groupId>
<artifactId>monticore-maven-plugin</artifactId>
<version>${monticore.plugin}</version>
<executions>
<execution>
<goals>
<goal>generate</goal>
</goals>
</execution>
</executions>
</plugin>
<!-- Other Configuration -->
......@@ -145,7 +127,7 @@
<configuration>
<archive>
<manifest>
<mainClass>de.monticore.lang.monticar.cnnarchcaffe2.generator.CNNArchGeneratorCliCaffe2</mainClass>
<mainClass>de.monticore.lang.monticar.cnnarch.generator.CNNArch2Caffe2Cli</mainClass>
</manifest>
</archive>
<descriptorRefs>
......@@ -206,16 +188,15 @@
</plugins>
</build>
<distributionManagement>
<repository>
<id>se-nexus</id>
<url>https://nexus.se.rwth-aachen.de/content/repositories/embeddedmontiarc-releases/</url>
</repository>
<snapshotRepository>
<id>se-nexus</id>
<url>https://nexus.se.rwth-aachen.de/content/repositories/embeddedmontiarc-snapshots/</url>
</snapshotRepository>
</distributionManagement>
<distributionManagement>
<repository>
<id>se-nexus</id>
<url>https://nexus.se.rwth-aachen.de/content/repositories/embeddedmontiarc-releases/</url>
</repository>
<snapshotRepository>
<id>se-nexus</id>
<url>https://nexus.se.rwth-aachen.de/content/repositories/embeddedmontiarc-snapshots/</url>
</snapshotRepository>
</distributionManagement>
</project>
<?xml version="1.0" encoding="UTF-8"?>
<!--
******************************************************************************
MontiCAR Modeling Family, www.se-rwth.de
Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
All rights reserved.
This project is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 3.0 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this project. If not, see <http://www.gnu.org/licenses/>.
*******************************************************************************
-->
<settings xmlns="http://maven.apache.org/SETTINGS/1.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
......
......@@ -18,7 +18,7 @@
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarchcaffe2.generator;
package de.monticore.lang.monticar.cnnarch.generator;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchTypeSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol;
......@@ -30,13 +30,13 @@ import javax.annotation.Nullable;
import java.util.Arrays;
import java.util.List;
public class ArchitectureElementDataCaffe2 {
public class ArchitectureElementData {
private String name;
private ArchitectureElementSymbol element;
private CNNArchTemplateControllerCaffe2 templateController;
private CNNArchTemplateController templateController;
public ArchitectureElementDataCaffe2(String name, ArchitectureElementSymbol element, CNNArchTemplateControllerCaffe2 templateController) {
public ArchitectureElementData(String name, ArchitectureElementSymbol element, CNNArchTemplateController templateController) {
this.name = name;
this.element = element;
this.templateController = templateController;
......@@ -58,11 +58,11 @@ public class ArchitectureElementDataCaffe2 {
this.element = element;
}
public CNNArchTemplateControllerCaffe2 getTemplateController() {
public CNNArchTemplateController getTemplateController() {
return templateController;
}
public void setTemplateController(CNNArchTemplateControllerCaffe2 templateController) {
public void setTemplateController(CNNArchTemplateController templateController) {
this.templateController = templateController;
}
......
......@@ -18,13 +18,15 @@
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarchcaffe2.generator;
package de.monticore.lang.monticar.cnnarch.generator;
import de.monticore.io.paths.ModelPath;
import de.monticore.lang.monticar.cnnarch.CNNArchGenerator;
import de.monticore.lang.monticar.cnnarch._cocos.CNNArchCocos;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchCompilationUnitSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchLanguage;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.symboltable.GlobalScope;
import de.monticore.symboltable.Scope;
import de.se_rwth.commons.logging.Log;
......@@ -33,15 +35,13 @@ import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.*;
public class CNNArchGeneratorCaffe2 {
public class CNNArch2Caffe2 implements CNNArchGenerator{
private String generationTargetPath;
public CNNArchGeneratorCaffe2() {
public CNNArch2Caffe2() {
setGenerationTargetPath("./target/generated-sources-cnnarch/");
}
......@@ -79,22 +79,40 @@ public class CNNArchGeneratorCaffe2 {
}
}
@Override
public Map<String, String> generateTrainer(List<ConfigurationSymbol> configurations, List<String> instanceNames, String mainComponentName) {
int numberOfNetworks = configurations.size();
if (configurations.size() != instanceNames.size()){
throw new IllegalStateException(
"The number of configurations and the number of instances for generation of the CNNTrainer is not equal. " +
"This should have been checked previously.");
}
List<ConfigurationData> configDataList = new ArrayList<>();
for(int i = 0; i < numberOfNetworks; i++){
configDataList.add(new ConfigurationData(configurations.get(i), instanceNames.get(i)));
}
Map<String, Object> ftlContext = Collections.singletonMap("configurations", configDataList);
return Collections.singletonMap(
"CNNTrainer_" + mainComponentName + ".py",
TemplateConfiguration.processTemplate(ftlContext, "CNNTrainer.ftl"));
}
//check cocos with CNNArchCocos.checkAll(architecture) before calling this method.
public Map<String, String> generateStrings(ArchitectureSymbol architecture){
Map<String, String> fileContentMap = new HashMap<>();
CNNArchTemplateControllerCaffe2 archTc = new CNNArchTemplateControllerCaffe2(architecture);
CNNArchTemplateController archTc = new CNNArchTemplateController(architecture);
Map.Entry<String, String> temp;
temp = archTc.process("CNNPredictor", TargetCaffe2.CPP);
temp = archTc.process("CNNPredictor", Target.CPP);
fileContentMap.put(temp.getKey(), temp.getValue());
temp = archTc.process("CNNCreator", TargetCaffe2.PYTHON);
temp = archTc.process("CNNCreator", Target.PYTHON);
fileContentMap.put(temp.getKey(), temp.getValue());
temp = archTc.process("execute", TargetCaffe2.CPP);
temp = archTc.process("execute", Target.CPP);
fileContentMap.put(temp.getKey().replace(".h", ""), temp.getValue());
temp = archTc.process("CNNBufferFile", TargetCaffe2.CPP);
temp = archTc.process("CNNBufferFile", Target.CPP);
fileContentMap.put("CNNBufferFile.h", temp.getValue());
checkValidGeneration(architecture);
......@@ -104,15 +122,13 @@ public class CNNArchGeneratorCaffe2 {
private void checkValidGeneration(ArchitectureSymbol architecture){
if (architecture.getInputs().size() > 1){
Log.warn("This cnn architecture has multiple inputs, " +
"which is currently not supported by the generator. " +
"The generated code will not work correctly."
Log.error("This cnn architecture has multiple inputs, " +
"which is currently not supported by the generator. "
, architecture.getSourcePosition());
}
if (architecture.getOutputs().size() > 1){
Log.warn("This cnn architecture has multiple outputs, " +
"which is currently not supported by the generator. " +
"The generated code will not work correctly."
Log.error("This cnn architecture has multiple outputs, " +
"which is currently not supported by the generator. "
, architecture.getSourcePosition());
}
if (architecture.getOutputs().get(0).getDefinition().getType().getWidth() != 1 ||
......@@ -125,7 +141,7 @@ public class CNNArchGeneratorCaffe2 {
//check cocos with CNNArchCocos.checkAll(architecture) before calling this method.
public void generateFiles(ArchitectureSymbol architecture) throws IOException{
CNNArchTemplateControllerCaffe2 archTc = new CNNArchTemplateControllerCaffe2(architecture);
CNNArchTemplateController archTc = new CNNArchTemplateController(architecture);
Map<String, String> fileContentMap = generateStrings(architecture);
for (String fileName : fileContentMap.keySet()){
......@@ -143,5 +159,4 @@ public class CNNArchGeneratorCaffe2 {
writer.close();
}
}
}
......@@ -18,14 +18,14 @@
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarchcaffe2.generator;
package de.monticore.lang.monticar.cnnarch.generator;
import org.apache.commons.cli.*;
import java.nio.file.Path;
import java.nio.file.Paths;
public class CNNArchGeneratorCliCaffe2 {
public class CNNArch2Caffe2Cli {
public static final Option OPTION_MODELS_PATH = Option.builder("m")
.longOpt("models-dir")
......@@ -48,7 +48,7 @@ public class CNNArchGeneratorCliCaffe2 {
.required(false)
.build();
private CNNArchGeneratorCliCaffe2() {
private CNNArch2Caffe2Cli() {
}
public static void main(String[] args) {
......@@ -84,7 +84,7 @@ public class CNNArchGeneratorCliCaffe2 {
Path modelsDirPath = Paths.get(cliArgs.getOptionValue(OPTION_MODELS_PATH.getOpt()));
String rootModelName = cliArgs.getOptionValue(OPTION_ROOT_MODEL.getOpt());
String outputPath = cliArgs.getOptionValue(OPTION_OUTPUT_PATH.getOpt());
CNNArchGeneratorCaffe2 generator = new CNNArchGeneratorCaffe2();
CNNArch2Caffe2 generator = new CNNArch2Caffe2();
if (outputPath != null){
generator.setGenerationTargetPath(outputPath);
}
......
......@@ -18,38 +18,34 @@
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarchcaffe2.generator;
package de.monticore.lang.monticar.cnnarch.generator;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.predefined.Sigmoid;
import de.monticore.lang.monticar.cnnarch.predefined.Softmax;
import de.se_rwth.commons.logging.Log;
import freemarker.template.Configuration;
import freemarker.template.Template;
import freemarker.template.TemplateException;
import java.io.IOException;
import java.io.StringWriter;
import java.io.Writer;
import java.util.*;
public class CNNArchTemplateControllerCaffe2 {
public class CNNArchTemplateController {
public static final String FTL_FILE_ENDING = ".ftl";
public static final String TEMPLATE_ELEMENTS_DIR_PATH = "elements/";
public static final String TEMPLATE_CONTROLLER_KEY = "tc";
public static final String ELEMENT_DATA_KEY = "element";
private LayerNameCreatorCaffe2 nameManager;
private Configuration freemarkerConfig = TemplateConfigurationCaffe2.get();
private LayerNameCreator nameManager;
private ArchitectureSymbol architecture;
//temporary attributes. They are set after calling process()
private Writer writer;
private String mainTemplateNameWithoutEnding;
private TargetCaffe2 targetCaffe2Language;
private ArchitectureElementDataCaffe2 dataElement;
private Target targetLanguage;
private ArchitectureElementData dataElement;
public CNNArchTemplateControllerCaffe2(ArchitectureSymbol architecture) {
public CNNArchTemplateController(ArchitectureSymbol architecture) {
setArchitecture(architecture);
}
......@@ -57,23 +53,15 @@ public class CNNArchTemplateControllerCaffe2 {
return mainTemplateNameWithoutEnding + "_" + getFullArchitectureName();
}
public TargetCaffe2 getTargetCaffe2Language(){
return targetCaffe2Language;
}
public void setTargetCaffe2Language(TargetCaffe2 targetCaffe2Language) {
this.targetCaffe2Language = targetCaffe2Language;
}
public ArchitectureElementDataCaffe2 getCurrentElement() {
public ArchitectureElementData getCurrentElement() {
return dataElement;
}
public void setCurrentElement(ArchitectureElementSymbol layer) {
this.dataElement = new ArchitectureElementDataCaffe2(getName(layer), layer, this);
this.dataElement = new ArchitectureElementData(getName(layer), layer, this);
}
public void setCurrentElement(ArchitectureElementDataCaffe2 dataElement) {
public void setCurrentElement(ArchitectureElementData dataElement) {
this.dataElement = dataElement;
}
......@@ -83,7 +71,7 @@ public class CNNArchTemplateControllerCaffe2 {
public void setArchitecture(ArchitectureSymbol architecture) {
this.architecture = architecture;
this.nameManager = new LayerNameCreatorCaffe2(architecture);
this.nameManager = new LayerNameCreator(architecture);
}
public String getName(ArchitectureElementSymbol layer){
......@@ -137,29 +125,14 @@ public class CNNArchTemplateControllerCaffe2 {
public void include(String relativePath, String templateWithoutFileEnding, Writer writer){
String templatePath = relativePath + templateWithoutFileEnding + FTL_FILE_ENDING;
try {
Template template = freemarkerConfig.getTemplate(templatePath);
Map<String, Object> ftlContext = new HashMap<>();
ftlContext.put(TEMPLATE_CONTROLLER_KEY, this);
ftlContext.put(ELEMENT_DATA_KEY, getCurrentElement());
this.writer = writer;
template.process(ftlContext, writer);
this.writer = null;
}
catch (IOException e) {
Log.error("Freemarker could not find template " + templatePath + " :\n" + e.getMessage());
System.exit(1);
}
catch (TemplateException e){
Log.error("An exception occured in template " + templatePath + " :\n" + e.getMessage());
System.exit(1);
}
Map<String, Object> ftlContext = new HashMap<>();
ftlContext.put(TEMPLATE_CONTROLLER_KEY, this);
ftlContext.put(ELEMENT_DATA_KEY, getCurrentElement());
TemplateConfiguration.processTemplate(ftlContext, templatePath, writer);
}
public void include(IOSymbol ioElement, Writer writer){
ArchitectureElementDataCaffe2 previousElement = getCurrentElement();
ArchitectureElementData previousElement = getCurrentElement();
setCurrentElement(ioElement);
if (ioElement.isAtomic()){
......@@ -178,7 +151,7 @@ public class CNNArchTemplateControllerCaffe2 {
}
public void include(LayerSymbol layer, Writer writer){
ArchitectureElementDataCaffe2 previousElement = getCurrentElement();
ArchitectureElementData previousElement = getCurrentElement();
setCurrentElement(layer);
if (layer.isAtomic()){
......@@ -196,7 +169,7 @@ public class CNNArchTemplateControllerCaffe2 {
}
public void include(CompositeElementSymbol compositeElement, Writer writer){
ArchitectureElementDataCaffe2 previousElement = getCurrentElement();
ArchitectureElementData previousElement = getCurrentElement();
setCurrentElement(compositeElement);
for (ArchitectureElementSymbol element : compositeElement.getElements()){
......@@ -225,22 +198,20 @@ public class CNNArchTemplateControllerCaffe2 {
include(architectureElement, writer);
}
public Map.Entry<String,String> process(String templateNameWithoutEnding, TargetCaffe2 targetCaffe2Language){
public Map.Entry<String,String> process(String templateNameWithoutEnding, Target targetLanguage){
StringWriter writer = new StringWriter();
this.mainTemplateNameWithoutEnding = templateNameWithoutEnding;
this.targetCaffe2Language = targetCaffe2Language;
include("", templateNameWithoutEnding, writer);
this.targetLanguage = targetLanguage;
this.writer = writer;
String fileEnding = targetCaffe2Language.toString();
if (targetCaffe2Language == TargetCaffe2.CPP){
fileEnding = ".h";
}
include("", templateNameWithoutEnding, writer);
String fileEnding = targetLanguage.toString();
String fileName = getFileNameWithoutEnding() + fileEnding;
Map.Entry<String,String> fileContent = new AbstractMap.SimpleEntry<>(fileName, writer.toString());
this.mainTemplateNameWithoutEnding = null;
this.targetCaffe2Language = null;
this.targetLanguage = null;
this.writer = null;
return fileContent;
}
......
package de.monticore.lang.monticar.cnnarch.generator;
import de.monticore.lang.monticar.cnntrain._symboltable.*;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class ConfigurationData {
ConfigurationSymbol configuration;
String instanceName;
public ConfigurationData(ConfigurationSymbol configuration, String instanceName) {
this.configuration = configuration;
this.instanceName = instanceName;
}
public ConfigurationSymbol getConfiguration() {
return configuration;
}
public String getInstanceName() {
return instanceName;
}
public String getNumEpoch() {
return String.valueOf(getConfiguration().getNumEpoch().getValue());
}
public String getBatchSize() {
return String.valueOf(getConfiguration().getBatchSize().getValue());
}
public LoadCheckpointSymbol getLoadCheckpoint() {
return getConfiguration().getLoadCheckpoint();
}
public NormalizeSymbol getNormalize() {
return getConfiguration().getNormalize();
}
public TrainContextSymbol getContext() {
return getConfiguration().getTrainContext();
}
public String getOptimizerName() {
return getConfiguration().getOptimizer().getName();
}
public Map<String, String> getOptimizerParams() {
// get classes for single enum values
List<Class> lrPolicyClasses = new ArrayList<>();
for (LRPolicy enum_value: LRPolicy.values()) {
lrPolicyClasses.add(enum_value.getClass());
}
Map<String, String> mapToStrings = new HashMap<>();
Map<String, OptimizerParamSymbol> optimizerParams = getConfiguration().getOptimizer().getOptimizerParamMap();
for (Map.Entry<String, OptimizerParamSymbol> entry : optimizerParams.entrySet()) {
String paramName = entry.getKey();
String valueAsString = entry.getValue().toString();
Class realClass = entry.getValue().getValue().getValue().getClass();
if (realClass == Boolean.class) {
valueAsString = (Boolean) entry.getValue().getValue().getValue() ? "True" : "False";
}
else if (lrPolicyClasses.contains(realClass)) {
valueAsString = "'" + valueAsString + "'";
}
mapToStrings.put(paramName, valueAsString);
}
return mapToStrings;
}
}
......@@ -18,7 +18,7 @@
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarchcaffe2.generator;
package de.monticore.lang.monticar.cnnarch.generator;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.predefined.Convolution;
......@@ -27,12 +27,12 @@ import de.monticore.lang.monticar.cnnarch.predefined.Pooling;
import java.util.*;
public class LayerNameCreatorCaffe2 {
public class LayerNameCreator {
private Map<ArchitectureElementSymbol, String> elementToName = new HashMap<>();
private Map<String, ArchitectureElementSymbol> nameToElement = new HashMap<>();
public LayerNameCreatorCaffe2(ArchitectureSymbol architecture) {
public LayerNameCreator(ArchitectureSymbol architecture) {
name(architecture.getBody(), 1, new ArrayList<>());
}
......
......@@ -18,10 +18,10 @@
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarchcaffe2.generator;
package de.monticore.lang.monticar.cnnarch.generator;
//can be removed
public enum TargetCaffe2 {
public enum Target {