Commit b5d928ee authored by Nicola Gatto's avatar Nicola Gatto Committed by Evgeny Kusmenko

Make generator classes reusable

parent 8099a9c9
......@@ -8,7 +8,7 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-mxnet-generator</artifactId>
<version>0.2.12</version>
<version>0.2.13-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
......
......@@ -95,8 +95,11 @@ public class CNNArch2MxNet extends CNNArchGenerator {
//check cocos with CNNArchCocos.checkAll(architecture) before calling this method.
public Map<String, String> generateStrings(ArchitectureSymbol architecture){
TemplateConfiguration templateConfiguration = new MxNetTemplateConfiguration();
Map<String, String> fileContentMap = new HashMap<>();
CNNArchTemplateController archTc = new CNNArchTemplateController(architecture);
CNNArch2MxNetTemplateController archTc
= new CNNArch2MxNetTemplateController(architecture, templateConfiguration);
Map.Entry<String, String> temp;
temp = archTc.process("CNNPredictor", Target.CPP);
......
......@@ -19,81 +19,12 @@
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import de.se_rwth.commons.logging.Log;
import org.apache.commons.cli.*;
import java.nio.file.Path;
import java.nio.file.Paths;
import de.monticore.lang.monticar.cnnarch.CNNArchGenerator;
public class CNNArch2MxNetCli {
public static final Option OPTION_MODELS_PATH = Option.builder("m")
.longOpt("models-dir")
.desc("full path to the directory with the CNNArch model")
.hasArg(true)
.required(true)
.build();
public static final Option OPTION_ROOT_MODEL = Option.builder("r")
.longOpt("root-model")
.desc("name of the architecture")
.hasArg(true)
.required(true)
.build();
public static final Option OPTION_OUTPUT_PATH = Option.builder("o")
.longOpt("output-dir")
.desc("full path to output directory for tests")
.hasArg(true)
.required(false)
.build();
private CNNArch2MxNetCli() {
}
public static void main(String[] args) {
Options options = getOptions();
CommandLineParser parser = new DefaultParser();
CommandLine cliArgs = parseArgs(options, parser, args);
if (cliArgs != null) {
runGenerator(cliArgs);
}
}
private static Options getOptions() {
Options options = new Options();
options.addOption(OPTION_MODELS_PATH);
options.addOption(OPTION_ROOT_MODEL);
options.addOption(OPTION_OUTPUT_PATH);
return options;
}
private static CommandLine parseArgs(Options options, CommandLineParser parser, String[] args) {
CommandLine cliArgs;
try {
cliArgs = parser.parse(options, args);
} catch (ParseException e) {
Log.error("argument parsing exception: " + e.getMessage());
quitGeneration();
return null;
}
return cliArgs;
}
private static void quitGeneration(){
Log.error("Code generation is aborted");
System.exit(1);
}
private static void runGenerator(CommandLine cliArgs) {
Path modelsDirPath = Paths.get(cliArgs.getOptionValue(OPTION_MODELS_PATH.getOpt()));
String rootModelName = cliArgs.getOptionValue(OPTION_ROOT_MODEL.getOpt());
String outputPath = cliArgs.getOptionValue(OPTION_OUTPUT_PATH.getOpt());
CNNArch2MxNet generator = new CNNArch2MxNet();
if (outputPath != null){
generator.setGenerationTargetPath(outputPath);
}
generator.generate(modelsDirPath, rootModelName);
CNNArchGenerator generator = new CNNArch2MxNet();
GenericCNNArchCli cli = new GenericCNNArchCli(generator);
cli.run(args);
}
}
}
\ No newline at end of file
package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import java.io.Writer;
/**
*
*/
public class CNNArch2MxNetTemplateController extends CNNArchTemplateController {
public CNNArch2MxNetTemplateController(ArchitectureSymbol architecture,
TemplateConfiguration templateConfiguration) {
super(architecture, templateConfiguration);
}
public void include(IOSymbol ioElement, Writer writer){
ArchitectureElementData previousElement = getCurrentElement();
setCurrentElement(ioElement);
if (ioElement.isAtomic()){
if (ioElement.isInput()){
include(TEMPLATE_ELEMENTS_DIR_PATH, "Input", writer);
} else {
include(TEMPLATE_ELEMENTS_DIR_PATH, "Output", writer);
}
} else {
include(ioElement.getResolvedThis().get(), writer);
}
setCurrentElement(previousElement);
}
public void include(LayerSymbol layer, Writer writer){
ArchitectureElementData previousElement = getCurrentElement();
setCurrentElement(layer);
if (layer.isAtomic()){
ArchitectureElementSymbol nextElement = layer.getOutputElement().get();
if (!isSoftmaxOutput(nextElement) && !isLogisticRegressionOutput(nextElement)){
String templateName = layer.getDeclaration().getName();
include(TEMPLATE_ELEMENTS_DIR_PATH, templateName, writer);
}
} else {
include(layer.getResolvedThis().get(), writer);
}
setCurrentElement(previousElement);
}
public void include(CompositeElementSymbol compositeElement, Writer writer){
ArchitectureElementData previousElement = getCurrentElement();
setCurrentElement(compositeElement);
for (ArchitectureElementSymbol element : compositeElement.getElements()){
include(element, writer);
}
setCurrentElement(previousElement);
}
public void include(ArchitectureElementSymbol architectureElement, Writer writer){
if (architectureElement instanceof CompositeElementSymbol){
include((CompositeElementSymbol) architectureElement, writer);
} else if (architectureElement instanceof LayerSymbol){
include((LayerSymbol) architectureElement, writer);
} else {
include((IOSymbol) architectureElement, writer);
}
}
public void include(ArchitectureElementSymbol architectureElement){
if (getWriter() == null){
throw new IllegalStateException("missing writer");
}
include(architectureElement, getWriter());
}
}
\ No newline at end of file
......@@ -28,13 +28,15 @@ import java.io.StringWriter;
import java.io.Writer;
import java.util.*;
public class CNNArchTemplateController {
public abstract class CNNArchTemplateController {
public static final String FTL_FILE_ENDING = ".ftl";
public static final String TEMPLATE_ELEMENTS_DIR_PATH = "elements/";
public static final String TEMPLATE_CONTROLLER_KEY = "tc";
public static final String ELEMENT_DATA_KEY = "element";
private final TemplateConfiguration templateConfiguration;
private LayerNameCreator nameManager;
private ArchitectureSymbol architecture;
......@@ -44,9 +46,53 @@ public class CNNArchTemplateController {
private Target targetLanguage;
private ArchitectureElementData dataElement;
public CNNArchTemplateController(ArchitectureSymbol architecture) {
protected CNNArchTemplateController(ArchitectureSymbol architecture, TemplateConfiguration templateConfiguration) {
setArchitecture(architecture);
this.templateConfiguration = templateConfiguration;
}
protected TemplateConfiguration getTemplateConfiguration() {
return templateConfiguration;
}
protected LayerNameCreator getNameManager() {
return nameManager;
}
protected void setNameManager(LayerNameCreator nameManager) {
this.nameManager = nameManager;
}
protected Writer getWriter() {
return writer;
}
protected void setWriter(Writer writer) {
this.writer = writer;
}
protected String getMainTemplateNameWithoutEnding() {
return mainTemplateNameWithoutEnding;
}
protected void setMainTemplateNameWithoutEnding(String mainTemplateNameWithoutEnding) {
this.mainTemplateNameWithoutEnding = mainTemplateNameWithoutEnding;
}
protected Target getTargetLanguage() {
return targetLanguage;
}
protected void setTargetLanguage(Target targetLanguage) {
this.targetLanguage = targetLanguage;
}
protected ArchitectureElementData getDataElement() {
return dataElement;
}
protected void setDataElement(ArchitectureElementData dataElement) {
this.dataElement = dataElement;
}
public String getFileNameWithoutEnding() {
......@@ -122,74 +168,12 @@ public class CNNArchTemplateController {
return list;
}
public void include(String relativePath, String templateWithoutFileEnding, Writer writer){
public void include(String relativePath, String templateWithoutFileEnding, Writer writer) {
String templatePath = relativePath + templateWithoutFileEnding + FTL_FILE_ENDING;
Map<String, Object> ftlContext = new HashMap<>();
ftlContext.put(TEMPLATE_CONTROLLER_KEY, this);
ftlContext.put(ELEMENT_DATA_KEY, getCurrentElement());
TemplateConfiguration.processTemplate(ftlContext, templatePath, writer);
}
public void include(IOSymbol ioElement, Writer writer){
ArchitectureElementData previousElement = getCurrentElement();
setCurrentElement(ioElement);
if (ioElement.isAtomic()){
if (ioElement.isInput()){
include(TEMPLATE_ELEMENTS_DIR_PATH, "Input", writer);
} else {
include(TEMPLATE_ELEMENTS_DIR_PATH, "Output", writer);
}
} else {
include(ioElement.getResolvedThis().get(), writer);
}
setCurrentElement(previousElement);
}
public void include(LayerSymbol layer, Writer writer){
ArchitectureElementData previousElement = getCurrentElement();
setCurrentElement(layer);
if (layer.isAtomic()){
ArchitectureElementSymbol nextElement = layer.getOutputElement().get();
if (!isSoftmaxOutput(nextElement) && !isLogisticRegressionOutput(nextElement)){
String templateName = layer.getDeclaration().getName();
include(TEMPLATE_ELEMENTS_DIR_PATH, templateName, writer);
}
} else {
include(layer.getResolvedThis().get(), writer);
}
setCurrentElement(previousElement);
}
public void include(CompositeElementSymbol compositeElement, Writer writer){
ArchitectureElementData previousElement = getCurrentElement();
setCurrentElement(compositeElement);
for (ArchitectureElementSymbol element : compositeElement.getElements()){
include(element, writer);
}
setCurrentElement(previousElement);
}
public void include(ArchitectureElementSymbol architectureElement, Writer writer){
if (architectureElement instanceof CompositeElementSymbol){
include((CompositeElementSymbol) architectureElement, writer);
} else if (architectureElement instanceof LayerSymbol){
include((LayerSymbol) architectureElement, writer);
} else {
include((IOSymbol) architectureElement, writer);
}
}
public void include(ArchitectureElementSymbol architectureElement){
if (writer == null){
throw new IllegalStateException("missing writer");
}
include(architectureElement, writer);
templateConfiguration.processTemplate(ftlContext, templatePath, writer);
}
public Map.Entry<String,String> process(String templateNameWithoutEnding, Target targetLanguage){
......@@ -249,9 +233,7 @@ public class CNNArchTemplateController {
&& architectureElement.getInputElement().isPresent()
&& architectureElement.getInputElement().get() instanceof LayerSymbol){
LayerSymbol inputLayer = (LayerSymbol) architectureElement.getInputElement().get();
if (inputPredefinedLayerClass.isInstance(inputLayer.getDeclaration())){
return true;
}
return inputPredefinedLayerClass.isInstance(inputLayer.getDeclaration());
}
return false;
}
......
......@@ -121,13 +121,13 @@ public class CNNTrain2MxNet implements CNNTrainGenerator {
}
public Map<String, String> generateStrings(ConfigurationSymbol configuration) {
TemplateConfiguration templateConfiguration = new MxNetTemplateConfiguration();
ConfigurationData configData = new ConfigurationData(configuration, getInstanceName());
List<ConfigurationData> configDataList = new ArrayList<>();
configDataList.add(configData);
Map<String, Object> ftlContext = Collections.singletonMap("configurations", configDataList);
String templateContent = TemplateConfiguration.processTemplate(ftlContext, "CNNTrainer.ftl");
String templateContent = templateConfiguration.processTemplate(ftlContext, "CNNTrainer.ftl");
return Collections.singletonMap("CNNTrainer_" + getInstanceName() + ".py", templateContent);
}
}
\ No newline at end of file
package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import de.monticore.lang.monticar.cnnarch.CNNArchGenerator;
import de.se_rwth.commons.logging.Log;
import org.apache.commons.cli.*;
import java.nio.file.Path;
import java.nio.file.Paths;
/**
*
*/
public class GenericCNNArchCli {
private final CNNArchGenerator cnnArchGenerator;
public static final Option OPTION_MODELS_PATH = Option.builder("m")
.longOpt("models-dir")
.desc("full path to the directory with the CNNArch model")
.hasArg(true)
.required(true)
.build();
public static final Option OPTION_ROOT_MODEL = Option.builder("r")
.longOpt("root-model")
.desc("name of the architecture")
.hasArg(true)
.required(true)
.build();
public static final Option OPTION_OUTPUT_PATH = Option.builder("o")
.longOpt("output-dir")
.desc("full path to output directory for tests")
.hasArg(true)
.required(false)
.build();
public GenericCNNArchCli(CNNArchGenerator cnnArchGenerator) {
this.cnnArchGenerator = cnnArchGenerator;
}
public void run(String[] args) {
Options options = getOptions();
CommandLineParser parser = new DefaultParser();
CommandLine cliArgs = parseArgs(options, parser, args);
if (cliArgs != null) {
runGenerator(cliArgs);
}
}
private Options getOptions() {
Options options = new Options();
options.addOption(OPTION_MODELS_PATH);
options.addOption(OPTION_ROOT_MODEL);
options.addOption(OPTION_OUTPUT_PATH);
return options;
}
private CommandLine parseArgs(Options options, CommandLineParser parser, String[] args) {
CommandLine cliArgs;
try {
cliArgs = parser.parse(options, args);
} catch (ParseException e) {
Log.error("argument parsing exception: " + e.getMessage());
quitGeneration();
return null;
}
return cliArgs;
}
private void quitGeneration(){
Log.error("Code generation is aborted");
System.exit(1);
}
private void runGenerator(CommandLine cliArgs) {
Path modelsDirPath = Paths.get(cliArgs.getOptionValue(OPTION_MODELS_PATH.getOpt()));
String rootModelName = cliArgs.getOptionValue(OPTION_ROOT_MODEL.getOpt());
String outputPath = cliArgs.getOptionValue(OPTION_OUTPUT_PATH.getOpt());
if (outputPath != null){
cnnArchGenerator.setGenerationTargetPath(outputPath);
}
cnnArchGenerator.generate(modelsDirPath, rootModelName);
}
}
\ No newline at end of file
package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import static de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers.*;
import java.util.ArrayList;
import java.util.List;
public class LayerSupportChecker {
private List<String> unsupportedLayerList = new ArrayList();
private List<String> unsupportedLayerList = new ArrayList<>();
public LayerSupportChecker() {
//Set the unsupported layers for the backend
......
package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import freemarker.template.Configuration;
/**
*
*/
public class MxNetTemplateConfiguration extends TemplateConfiguration {
private static Configuration configuration;
public MxNetTemplateConfiguration() {
super();
if (configuration == null) {
configuration = super.createConfiguration();
}
}
@Override
protected String getBaseTemplatePackagePath() {
return "/templates/mxnet/";
}
@Override
public Configuration getConfiguration() {
return configuration;
}
}
\ No newline at end of file
......@@ -33,5 +33,5 @@ public enum Target {
public String toString() {
return ".h";
}
};
}
}
......@@ -31,37 +31,30 @@ import java.io.StringWriter;
import java.io.Writer;
import java.util.Map;
public class TemplateConfiguration {
public abstract class TemplateConfiguration {
abstract protected String getBaseTemplatePackagePath();
abstract public Configuration getConfiguration();
private static TemplateConfiguration instance;
private Configuration configuration;
public TemplateConfiguration() {
private TemplateConfiguration() {
configuration = new Configuration(Configuration.VERSION_2_3_23);
configuration.setClassForTemplateLoading(TemplateConfiguration.class, "/templates/mxnet/");
}
protected Configuration createConfiguration() {
Configuration configuration = new Configuration(Configuration.VERSION_2_3_23);
configuration.setClassForTemplateLoading(TemplateConfiguration.class, getBaseTemplatePackagePath());
configuration.setDefaultEncoding("UTF-8");
configuration.setTemplateExceptionHandler(TemplateExceptionHandler.RETHROW_HANDLER);
return configuration;
}
private static void quitGeneration(){
private void quitGeneration(){
Log.error("Code generation is aborted");
System.exit(1);
}
public Configuration getConfiguration() {
return configuration;
}
public static Configuration get(){
if (instance == null){
instance = new TemplateConfiguration();
}
return instance.getConfiguration();
}
public static void processTemplate(Map<String, Object> ftlContext, String templatePath, Writer writer){
public void processTemplate(Map<String, Object> ftlContext, String templatePath, Writer writer){
try{
Template template = TemplateConfiguration.get().getTemplate(templatePath);
Template template = getConfiguration().getTemplate(templatePath);
template.process(ftlContext, writer);
} catch (IOException e) {
Log.error("Freemarker could not find template " + templatePath + " :\n" + e.getMessage());
......@@ -72,10 +65,9 @@ public class TemplateConfiguration {
}
}
public static String processTemplate(Map<String, Object> ftlContext, String templatePath){
public String processTemplate(Map<String, Object> ftlContext, String templatePath){
StringWriter writer = new StringWriter();
processTemplate(ftlContext, templatePath, writer);
return writer.toString();
}
}
......@@ -8,7 +8,7 @@ import java.util.List;
public class TrainParamSupportChecker implements CNNTrainVisitor {
private List<String> unsupportedElemList = new ArrayList();
private List<String> unsupportedElemList = new ArrayList<>();
private void printUnsupportedEntryParam(String nodeName){
Log.warn("Unsupported training parameter " + "'" + nodeName + "'" + " for the backend MXNet. It will be ignored.");
......
......@@ -55,9 +55,8 @@ public class AbstractSymtabTest {
for (String m : modelPath) {
mp.addEntry(Paths.get(m));
}
GlobalScope scope = new GlobalScope(mp, fam);
return scope;
return new GlobalScope(mp, fam);
}
protected static CNNArchCompilationUnitSymbol getCompilationUnitSymbol(String modelPath, String model) {
......
......@@ -101,7 +101,7 @@ public class GenerationTest extends AbstractSymtabTest{
@Test
public void testResNeXtGeneration() throws IOException, TemplateException {
Log.getFindings().clear();;
Log.getFindings().clear();
String[] args = {"-m", "src/test/resources/architectures", "-r", "ResNeXt50"};
CNNArch2MxNetCli.main(args);
assertTrue(Log.getFindings().isEmpty());
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment