Commit b5d928ee authored by Nicola Gatto's avatar Nicola Gatto Committed by Evgeny Kusmenko
Browse files

Make generator classes reusable

parent 8099a9c9
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
<groupId>de.monticore.lang.monticar</groupId> <groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-mxnet-generator</artifactId> <artifactId>cnnarch-mxnet-generator</artifactId>
<version>0.2.12</version> <version>0.2.13-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= --> <!-- == PROJECT DEPENDENCIES ============================================= -->
......
...@@ -95,8 +95,11 @@ public class CNNArch2MxNet extends CNNArchGenerator { ...@@ -95,8 +95,11 @@ public class CNNArch2MxNet extends CNNArchGenerator {
//check cocos with CNNArchCocos.checkAll(architecture) before calling this method. //check cocos with CNNArchCocos.checkAll(architecture) before calling this method.
public Map<String, String> generateStrings(ArchitectureSymbol architecture){ public Map<String, String> generateStrings(ArchitectureSymbol architecture){
TemplateConfiguration templateConfiguration = new MxNetTemplateConfiguration();
Map<String, String> fileContentMap = new HashMap<>(); Map<String, String> fileContentMap = new HashMap<>();
CNNArchTemplateController archTc = new CNNArchTemplateController(architecture); CNNArch2MxNetTemplateController archTc
= new CNNArch2MxNetTemplateController(architecture, templateConfiguration);
Map.Entry<String, String> temp; Map.Entry<String, String> temp;
temp = archTc.process("CNNPredictor", Target.CPP); temp = archTc.process("CNNPredictor", Target.CPP);
......
...@@ -19,81 +19,12 @@ ...@@ -19,81 +19,12 @@
* ******************************************************************************* * *******************************************************************************
*/ */
package de.monticore.lang.monticar.cnnarch.mxnetgenerator; package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import de.se_rwth.commons.logging.Log; import de.monticore.lang.monticar.cnnarch.CNNArchGenerator;
import org.apache.commons.cli.*;
import java.nio.file.Path;
import java.nio.file.Paths;
public class CNNArch2MxNetCli { public class CNNArch2MxNetCli {
public static final Option OPTION_MODELS_PATH = Option.builder("m")
.longOpt("models-dir")
.desc("full path to the directory with the CNNArch model")
.hasArg(true)
.required(true)
.build();
public static final Option OPTION_ROOT_MODEL = Option.builder("r")
.longOpt("root-model")
.desc("name of the architecture")
.hasArg(true)
.required(true)
.build();
public static final Option OPTION_OUTPUT_PATH = Option.builder("o")
.longOpt("output-dir")
.desc("full path to output directory for tests")
.hasArg(true)
.required(false)
.build();
private CNNArch2MxNetCli() {
}
public static void main(String[] args) { public static void main(String[] args) {
Options options = getOptions(); CNNArchGenerator generator = new CNNArch2MxNet();
CommandLineParser parser = new DefaultParser(); GenericCNNArchCli cli = new GenericCNNArchCli(generator);
CommandLine cliArgs = parseArgs(options, parser, args); cli.run(args);
if (cliArgs != null) {
runGenerator(cliArgs);
}
}
private static Options getOptions() {
Options options = new Options();
options.addOption(OPTION_MODELS_PATH);
options.addOption(OPTION_ROOT_MODEL);
options.addOption(OPTION_OUTPUT_PATH);
return options;
}
private static CommandLine parseArgs(Options options, CommandLineParser parser, String[] args) {
CommandLine cliArgs;
try {
cliArgs = parser.parse(options, args);
} catch (ParseException e) {
Log.error("argument parsing exception: " + e.getMessage());
quitGeneration();
return null;
}
return cliArgs;
}
private static void quitGeneration(){
Log.error("Code generation is aborted");
System.exit(1);
}
private static void runGenerator(CommandLine cliArgs) {
Path modelsDirPath = Paths.get(cliArgs.getOptionValue(OPTION_MODELS_PATH.getOpt()));
String rootModelName = cliArgs.getOptionValue(OPTION_ROOT_MODEL.getOpt());
String outputPath = cliArgs.getOptionValue(OPTION_OUTPUT_PATH.getOpt());
CNNArch2MxNet generator = new CNNArch2MxNet();
if (outputPath != null){
generator.setGenerationTargetPath(outputPath);
}
generator.generate(modelsDirPath, rootModelName);
} }
} }
\ No newline at end of file
package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import java.io.Writer;
/**
*
*/
public class CNNArch2MxNetTemplateController extends CNNArchTemplateController {
public CNNArch2MxNetTemplateController(ArchitectureSymbol architecture,
TemplateConfiguration templateConfiguration) {
super(architecture, templateConfiguration);
}
public void include(IOSymbol ioElement, Writer writer){
ArchitectureElementData previousElement = getCurrentElement();
setCurrentElement(ioElement);
if (ioElement.isAtomic()){
if (ioElement.isInput()){
include(TEMPLATE_ELEMENTS_DIR_PATH, "Input", writer);
} else {
include(TEMPLATE_ELEMENTS_DIR_PATH, "Output", writer);
}
} else {
include(ioElement.getResolvedThis().get(), writer);
}
setCurrentElement(previousElement);
}
public void include(LayerSymbol layer, Writer writer){
ArchitectureElementData previousElement = getCurrentElement();
setCurrentElement(layer);
if (layer.isAtomic()){
ArchitectureElementSymbol nextElement = layer.getOutputElement().get();
if (!isSoftmaxOutput(nextElement) && !isLogisticRegressionOutput(nextElement)){
String templateName = layer.getDeclaration().getName();
include(TEMPLATE_ELEMENTS_DIR_PATH, templateName, writer);
}
} else {
include(layer.getResolvedThis().get(), writer);
}
setCurrentElement(previousElement);
}
public void include(CompositeElementSymbol compositeElement, Writer writer){
ArchitectureElementData previousElement = getCurrentElement();
setCurrentElement(compositeElement);
for (ArchitectureElementSymbol element : compositeElement.getElements()){
include(element, writer);
}
setCurrentElement(previousElement);
}
public void include(ArchitectureElementSymbol architectureElement, Writer writer){
if (architectureElement instanceof CompositeElementSymbol){
include((CompositeElementSymbol) architectureElement, writer);
} else if (architectureElement instanceof LayerSymbol){
include((LayerSymbol) architectureElement, writer);
} else {
include((IOSymbol) architectureElement, writer);
}
}
public void include(ArchitectureElementSymbol architectureElement){
if (getWriter() == null){
throw new IllegalStateException("missing writer");
}
include(architectureElement, getWriter());
}
}
\ No newline at end of file
...@@ -28,13 +28,15 @@ import java.io.StringWriter; ...@@ -28,13 +28,15 @@ import java.io.StringWriter;
import java.io.Writer; import java.io.Writer;
import java.util.*; import java.util.*;
public class CNNArchTemplateController { public abstract class CNNArchTemplateController {
public static final String FTL_FILE_ENDING = ".ftl"; public static final String FTL_FILE_ENDING = ".ftl";
public static final String TEMPLATE_ELEMENTS_DIR_PATH = "elements/"; public static final String TEMPLATE_ELEMENTS_DIR_PATH = "elements/";
public static final String TEMPLATE_CONTROLLER_KEY = "tc"; public static final String TEMPLATE_CONTROLLER_KEY = "tc";
public static final String ELEMENT_DATA_KEY = "element"; public static final String ELEMENT_DATA_KEY = "element";
private final TemplateConfiguration templateConfiguration;
private LayerNameCreator nameManager; private LayerNameCreator nameManager;
private ArchitectureSymbol architecture; private ArchitectureSymbol architecture;
...@@ -44,9 +46,53 @@ public class CNNArchTemplateController { ...@@ -44,9 +46,53 @@ public class CNNArchTemplateController {
private Target targetLanguage; private Target targetLanguage;
private ArchitectureElementData dataElement; private ArchitectureElementData dataElement;
protected CNNArchTemplateController(ArchitectureSymbol architecture, TemplateConfiguration templateConfiguration) {
public CNNArchTemplateController(ArchitectureSymbol architecture) {
setArchitecture(architecture); setArchitecture(architecture);
this.templateConfiguration = templateConfiguration;
}
protected TemplateConfiguration getTemplateConfiguration() {
return templateConfiguration;
}
protected LayerNameCreator getNameManager() {
return nameManager;
}
protected void setNameManager(LayerNameCreator nameManager) {
this.nameManager = nameManager;
}
protected Writer getWriter() {
return writer;
}
protected void setWriter(Writer writer) {
this.writer = writer;
}
protected String getMainTemplateNameWithoutEnding() {
return mainTemplateNameWithoutEnding;
}
protected void setMainTemplateNameWithoutEnding(String mainTemplateNameWithoutEnding) {
this.mainTemplateNameWithoutEnding = mainTemplateNameWithoutEnding;
}
protected Target getTargetLanguage() {
return targetLanguage;
}
protected void setTargetLanguage(Target targetLanguage) {
this.targetLanguage = targetLanguage;
}
protected ArchitectureElementData getDataElement() {
return dataElement;
}
protected void setDataElement(ArchitectureElementData dataElement) {
this.dataElement = dataElement;
} }
public String getFileNameWithoutEnding() { public String getFileNameWithoutEnding() {
...@@ -122,74 +168,12 @@ public class CNNArchTemplateController { ...@@ -122,74 +168,12 @@ public class CNNArchTemplateController {
return list; return list;
} }
public void include(String relativePath, String templateWithoutFileEnding, Writer writer){ public void include(String relativePath, String templateWithoutFileEnding, Writer writer) {
String templatePath = relativePath + templateWithoutFileEnding + FTL_FILE_ENDING; String templatePath = relativePath + templateWithoutFileEnding + FTL_FILE_ENDING;
Map<String, Object> ftlContext = new HashMap<>(); Map<String, Object> ftlContext = new HashMap<>();
ftlContext.put(TEMPLATE_CONTROLLER_KEY, this); ftlContext.put(TEMPLATE_CONTROLLER_KEY, this);
ftlContext.put(ELEMENT_DATA_KEY, getCurrentElement()); ftlContext.put(ELEMENT_DATA_KEY, getCurrentElement());
TemplateConfiguration.processTemplate(ftlContext, templatePath, writer); templateConfiguration.processTemplate(ftlContext, templatePath, writer);
}
public void include(IOSymbol ioElement, Writer writer){
ArchitectureElementData previousElement = getCurrentElement();
setCurrentElement(ioElement);
if (ioElement.isAtomic()){
if (ioElement.isInput()){
include(TEMPLATE_ELEMENTS_DIR_PATH, "Input", writer);
} else {
include(TEMPLATE_ELEMENTS_DIR_PATH, "Output", writer);
}
} else {
include(ioElement.getResolvedThis().get(), writer);
}
setCurrentElement(previousElement);
}
public void include(LayerSymbol layer, Writer writer){
ArchitectureElementData previousElement = getCurrentElement();
setCurrentElement(layer);
if (layer.isAtomic()){
ArchitectureElementSymbol nextElement = layer.getOutputElement().get();
if (!isSoftmaxOutput(nextElement) && !isLogisticRegressionOutput(nextElement)){
String templateName = layer.getDeclaration().getName();
include(TEMPLATE_ELEMENTS_DIR_PATH, templateName, writer);
}
} else {
include(layer.getResolvedThis().get(), writer);
}
setCurrentElement(previousElement);
}
public void include(CompositeElementSymbol compositeElement, Writer writer){
ArchitectureElementData previousElement = getCurrentElement();
setCurrentElement(compositeElement);
for (ArchitectureElementSymbol element : compositeElement.getElements()){
include(element, writer);
}
setCurrentElement(previousElement);
}
public void include(ArchitectureElementSymbol architectureElement, Writer writer){
if (architectureElement instanceof CompositeElementSymbol){
include((CompositeElementSymbol) architectureElement, writer);
} else if (architectureElement instanceof LayerSymbol){
include((LayerSymbol) architectureElement, writer);
} else {
include((IOSymbol) architectureElement, writer);
}
}
public void include(ArchitectureElementSymbol architectureElement){
if (writer == null){
throw new IllegalStateException("missing writer");
}
include(architectureElement, writer);
} }
public Map.Entry<String,String> process(String templateNameWithoutEnding, Target targetLanguage){ public Map.Entry<String,String> process(String templateNameWithoutEnding, Target targetLanguage){
...@@ -249,9 +233,7 @@ public class CNNArchTemplateController { ...@@ -249,9 +233,7 @@ public class CNNArchTemplateController {
&& architectureElement.getInputElement().isPresent() && architectureElement.getInputElement().isPresent()
&& architectureElement.getInputElement().get() instanceof LayerSymbol){ && architectureElement.getInputElement().get() instanceof LayerSymbol){
LayerSymbol inputLayer = (LayerSymbol) architectureElement.getInputElement().get(); LayerSymbol inputLayer = (LayerSymbol) architectureElement.getInputElement().get();
if (inputPredefinedLayerClass.isInstance(inputLayer.getDeclaration())){ return inputPredefinedLayerClass.isInstance(inputLayer.getDeclaration());
return true;
}
} }
return false; return false;
} }
......
...@@ -121,13 +121,13 @@ public class CNNTrain2MxNet implements CNNTrainGenerator { ...@@ -121,13 +121,13 @@ public class CNNTrain2MxNet implements CNNTrainGenerator {
} }
public Map<String, String> generateStrings(ConfigurationSymbol configuration) { public Map<String, String> generateStrings(ConfigurationSymbol configuration) {
TemplateConfiguration templateConfiguration = new MxNetTemplateConfiguration();
ConfigurationData configData = new ConfigurationData(configuration, getInstanceName()); ConfigurationData configData = new ConfigurationData(configuration, getInstanceName());
List<ConfigurationData> configDataList = new ArrayList<>(); List<ConfigurationData> configDataList = new ArrayList<>();
configDataList.add(configData); configDataList.add(configData);
Map<String, Object> ftlContext = Collections.singletonMap("configurations", configDataList); Map<String, Object> ftlContext = Collections.singletonMap("configurations", configDataList);
String templateContent = TemplateConfiguration.processTemplate(ftlContext, "CNNTrainer.ftl"); String templateContent = templateConfiguration.processTemplate(ftlContext, "CNNTrainer.ftl");
return Collections.singletonMap("CNNTrainer_" + getInstanceName() + ".py", templateContent); return Collections.singletonMap("CNNTrainer_" + getInstanceName() + ".py", templateContent);
} }
} }
\ No newline at end of file
package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import de.monticore.lang.monticar.cnnarch.CNNArchGenerator;
import de.se_rwth.commons.logging.Log;
import org.apache.commons.cli.*;
import java.nio.file.Path;
import java.nio.file.Paths;
/**
*
*/
public class GenericCNNArchCli {
private final CNNArchGenerator cnnArchGenerator;
public static final Option OPTION_MODELS_PATH = Option.builder("m")
.longOpt("models-dir")
.desc("full path to the directory with the CNNArch model")
.hasArg(true)
.required(true)
.build();
public static final Option OPTION_ROOT_MODEL = Option.builder("r")
.longOpt("root-model")
.desc("name of the architecture")
.hasArg(true)
.required(true)
.build();
public static final Option OPTION_OUTPUT_PATH = Option.builder("o")
.longOpt("output-dir")
.desc("full path to output directory for tests")
.hasArg(true)
.required(false)
.build();
public GenericCNNArchCli(CNNArchGenerator cnnArchGenerator) {
this.cnnArchGenerator = cnnArchGenerator;
}
public void run(String[] args) {
Options options = getOptions();
CommandLineParser parser = new DefaultParser();
CommandLine cliArgs = parseArgs(options, parser, args);
if (cliArgs != null) {
runGenerator(cliArgs);
}
}
private Options getOptions() {
Options options = new Options();
options.addOption(OPTION_MODELS_PATH);
options.addOption(OPTION_ROOT_MODEL);
options.addOption(OPTION_OUTPUT_PATH);
return options;
}
private CommandLine parseArgs(Options options, CommandLineParser parser, String[] args) {
CommandLine cliArgs;
try {
cliArgs = parser.parse(options, args);
} catch (ParseException e) {
Log.error("argument parsing exception: " + e.getMessage());
quitGeneration();
return null;
}
return cliArgs;
}
private void quitGeneration(){
Log.error("Code generation is aborted");
System.exit(1);
}
private void runGenerator(CommandLine cliArgs) {
Path modelsDirPath = Paths.get(cliArgs.getOptionValue(OPTION_MODELS_PATH.getOpt()));
String rootModelName = cliArgs.getOptionValue(OPTION_ROOT_MODEL.getOpt());
String outputPath = cliArgs.getOptionValue(OPTION_OUTPUT_PATH.getOpt());
if (outputPath != null){
cnnArchGenerator.setGenerationTargetPath(outputPath);
}
cnnArchGenerator.generate(modelsDirPath, rootModelName);
}
}
\ No newline at end of file
package de.monticore.lang.monticar.cnnarch.mxnetgenerator; package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import static de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers.*;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
public class LayerSupportChecker { public class LayerSupportChecker {
private List<String> unsupportedLayerList = new ArrayList(); private List<String> unsupportedLayerList = new ArrayList<>();
public LayerSupportChecker() { public LayerSupportChecker() {
//Set the unsupported layers for the backend //Set the unsupported layers for the backend
......
package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import freemarker.template.Configuration;
/**
*
*/
public class MxNetTemplateConfiguration extends TemplateConfiguration {
private static Configuration configuration;
public MxNetTemplateConfiguration() {
super();
if (configuration == null) {
configuration = super.createConfiguration();
}
}
@Override
protected String getBaseTemplatePackagePath() {
return "/templates/mxnet/";
}
@Override
public Configuration getConfiguration() {
return configuration;
}
}
\ No newline at end of file
...@@ -33,5 +33,5 @@ public enum Target { ...@@ -33,5 +33,5 @@ public enum Target {
public String toString() { public String toString() {
return ".h"; return ".h";
} }