From 531d89e87e753844e0a92c9ffa3175cf695e3aee Mon Sep 17 00:00:00 2001 From: Nicola Gatto <nicola.gatto@rwth-aachen.de> Date: Mon, 25 Feb 2019 16:43:27 +0100 Subject: [PATCH] Refactor and Inherit from CNNArch2MxNet --- pom.xml | 9 +- .../ArchitectureElementData.java | 195 ------------- .../cnnarch/gluongenerator/CNNArch2Gluon.java | 102 +------ .../gluongenerator/CNNArch2GluonCli.java | 77 +---- .../CNNArch2GluonTemplateController.java | 119 ++++++++ .../CNNArchTemplateController.java | 272 ------------------ .../gluongenerator/CNNTrain2Gluon.java | 124 +------- .../gluongenerator/ConfigurationData.java | 99 ------- .../GluonTemplateConfiguration.java | 28 ++ .../gluongenerator/LayerNameCreator.java | 146 ---------- .../gluongenerator/LayerSupportChecker.java | 19 -- .../cnnarch/gluongenerator/Target.java | 37 --- .../gluongenerator/TemplateConfiguration.java | 81 ------ .../TrainParamSupportChecker.java | 94 ------ .../gluongenerator/AbstractSymtabTest.java | 3 +- .../gluongenerator/GenerationTest.java | 4 +- 16 files changed, 182 insertions(+), 1227 deletions(-) delete mode 100644 src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/ArchitectureElementData.java create mode 100644 src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonTemplateController.java delete mode 100644 src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArchTemplateController.java delete mode 100644 src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/ConfigurationData.java create mode 100644 src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/GluonTemplateConfiguration.java delete mode 100644 src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/LayerNameCreator.java delete mode 100644 src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/LayerSupportChecker.java delete mode 100644 src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/Target.java delete mode 100644 src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/TemplateConfiguration.java delete mode 100644 src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/TrainParamSupportChecker.java diff --git a/pom.xml b/pom.xml index b924852f..42a1dc13 100644 --- a/pom.xml +++ b/pom.xml @@ -8,7 +8,7 @@ <groupId>de.monticore.lang.monticar</groupId> <artifactId>cnnarch-gluon-generator</artifactId> - <version>0.1.0-SNAPSHOT</version> + <version>0.1.2-SNAPSHOT</version> <!-- == PROJECT DEPENDENCIES ============================================= --> @@ -17,6 +17,7 @@ <!-- .. SE-Libraries .................................................. --> <CNNArch.version>0.2.9</CNNArch.version> <CNNTrain.version>0.2.6</CNNTrain.version> + <CNNArch2MXNet.version>0.2.13-SNAPSHOT</CNNArch2MXNet.version> <embedded-montiarc-math-opt-generator>0.1.4</embedded-montiarc-math-opt-generator> <!-- .. Libraries .................................................. --> @@ -59,6 +60,12 @@ <!-- MontiCore Dependencies --> + <dependency> + <groupId>de.monticore.lang.monticar</groupId> + <artifactId>cnnarch-mxnet-generator</artifactId> + <version>${CNNArch2MXNet.version}</version> + </dependency> + <dependency> <groupId>de.monticore.lang.monticar</groupId> <artifactId>cnn-arch</artifactId> diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/ArchitectureElementData.java b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/ArchitectureElementData.java deleted file mode 100644 index 5baf58d2..00000000 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/ArchitectureElementData.java +++ /dev/null @@ -1,195 +0,0 @@ -/** - * - * ****************************************************************************** - * MontiCAR Modeling Family, www.se-rwth.de - * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, - * All rights reserved. - * - * This project is free software; you can redistribute it and/or - * modify it under the terms of the GNU Lesser General Public - * License as published by the Free Software Foundation; either - * version 3.0 of the License, or (at your option) any later version. - * This library is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU - * Lesser General Public License for more details. - * - * You should have received a copy of the GNU Lesser General Public - * License along with this project. If not, see <http://www.gnu.org/licenses/>. - * ******************************************************************************* - */ -package de.monticore.lang.monticar.cnnarch.gluongenerator; - -import de.monticore.lang.monticar.cnnarch._symboltable.ArchTypeSymbol; -import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol; -import de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol; -import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers; -import de.se_rwth.commons.logging.Log; - -import javax.annotation.Nullable; -import java.util.Arrays; -import java.util.List; - -public class ArchitectureElementData { - - private String name; - private ArchitectureElementSymbol element; - private CNNArchTemplateController templateController; - - public ArchitectureElementData(String name, ArchitectureElementSymbol element, CNNArchTemplateController templateController) { - this.name = name; - this.element = element; - this.templateController = templateController; - } - - public String getName() { - return name; - } - - public void setName(String name) { - this.name = name; - } - - public ArchitectureElementSymbol getElement() { - return element; - } - - public void setElement(ArchitectureElementSymbol element) { - this.element = element; - } - - public CNNArchTemplateController getTemplateController() { - return templateController; - } - - public void setTemplateController(CNNArchTemplateController templateController) { - this.templateController = templateController; - } - - public List<String> getInputs(){ - return getTemplateController().getLayerInputs(getElement()); - } - - public boolean isLogisticRegressionOutput(){ - return getTemplateController().isLogisticRegressionOutput(getElement()); - } - - - public boolean isLinearRegressionOutput(){ - boolean result = getTemplateController().isLinearRegressionOutput(getElement()); - if (result){ - Log.warn("The Output '" + getElement().getName() + "' is a linear regression output (squared loss) during training" + - " because the previous architecture element is not a softmax (cross-entropy loss) or sigmoid (logistic regression loss) activation. " + - "Other loss functions are currently not supported. " - , getElement().getSourcePosition()); - } - return result; - } - - public boolean isSoftmaxOutput(){ - return getTemplateController().isSoftmaxOutput(getElement()); - } - - - - - public List<Integer> getKernel(){ - return ((LayerSymbol) getElement()) - .getIntTupleValue(AllPredefinedLayers.KERNEL_NAME).get(); - } - - public int getChannels(){ - return ((LayerSymbol) getElement()) - .getIntValue(AllPredefinedLayers.CHANNELS_NAME).get(); - } - - public List<Integer> getStride(){ - return ((LayerSymbol) getElement()) - .getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get(); - } - - public int getUnits(){ - return ((LayerSymbol) getElement()) - .getIntValue(AllPredefinedLayers.UNITS_NAME).get(); - } - - public boolean getNoBias(){ - return ((LayerSymbol) getElement()) - .getBooleanValue(AllPredefinedLayers.NOBIAS_NAME).get(); - } - - public double getP(){ - return ((LayerSymbol) getElement()) - .getDoubleValue(AllPredefinedLayers.P_NAME).get(); - } - - public int getIndex(){ - return ((LayerSymbol) getElement()) - .getIntValue(AllPredefinedLayers.INDEX_NAME).get(); - } - - public int getNumOutputs(){ - return ((LayerSymbol) getElement()) - .getIntValue(AllPredefinedLayers.NUM_SPLITS_NAME).get(); - } - - public boolean getFixGamma(){ - return ((LayerSymbol) getElement()) - .getBooleanValue(AllPredefinedLayers.FIX_GAMMA_NAME).get(); - } - - public int getNsize(){ - return ((LayerSymbol) getElement()) - .getIntValue(AllPredefinedLayers.NSIZE_NAME).get(); - } - - public double getKnorm(){ - return ((LayerSymbol) getElement()) - .getDoubleValue(AllPredefinedLayers.KNORM_NAME).get(); - } - - public double getAlpha(){ - return ((LayerSymbol) getElement()) - .getDoubleValue(AllPredefinedLayers.ALPHA_NAME).get(); - } - - public double getBeta(){ - return ((LayerSymbol) getElement()) - .getDoubleValue(AllPredefinedLayers.BETA_NAME).get(); - } - - @Nullable - public String getPoolType(){ - return ((LayerSymbol) getElement()) - .getStringValue(AllPredefinedLayers.POOL_TYPE_NAME).get(); - } - - @Nullable - public List<Integer> getPadding(){ - return getPadding((LayerSymbol) getElement()); - } - - @Nullable - public List<Integer> getPadding(LayerSymbol layer){ - List<Integer> kernel = layer.getIntTupleValue(AllPredefinedLayers.KERNEL_NAME).get(); - List<Integer> stride = layer.getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get(); - ArchTypeSymbol inputType = layer.getInputTypes().get(0); - ArchTypeSymbol outputType = layer.getOutputTypes().get(0); - - int heightWithPad = kernel.get(0) + stride.get(0)*(outputType.getHeight() - 1); - int widthWithPad = kernel.get(1) + stride.get(1)*(outputType.getWidth() - 1); - int heightPad = Math.max(0, heightWithPad - inputType.getHeight()); - int widthPad = Math.max(0, widthWithPad - inputType.getWidth()); - - int topPad = (int)Math.ceil(heightPad / 2.0); - int bottomPad = (int)Math.floor(heightPad / 2.0); - int leftPad = (int)Math.ceil(widthPad / 2.0); - int rightPad = (int)Math.floor(widthPad / 2.0); - - if (topPad == 0 && bottomPad == 0 && leftPad == 0 && rightPad == 0){ - return null; - } - - return Arrays.asList(0,0,0,0,topPad,bottomPad,leftPad,rightPad); - } -} diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2Gluon.java b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2Gluon.java index f8aad2ed..652e09db 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2Gluon.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2Gluon.java @@ -20,83 +20,25 @@ */ package de.monticore.lang.monticar.cnnarch.gluongenerator; -import de.monticore.lang.monticar.cnnarch.CNNArchGenerator; -import de.monticore.lang.monticar.cnnarch._cocos.CNNArchCocos; +import de.monticore.lang.monticar.cnnarch.mxnetgenerator.CNNArch2MxNet; +import de.monticore.lang.monticar.cnnarch.mxnetgenerator.Target; + import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol; -import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol; -import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol; -import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchCompilationUnitSymbol; -import de.monticore.lang.monticar.generator.FileContent; -import de.monticore.lang.monticar.generator.cmake.CMakeConfig; -import de.monticore.lang.monticar.generator.cmake.CMakeFindModule; -import de.monticore.lang.monticar.generator.cpp.GeneratorCPP; -import de.monticore.symboltable.Scope; -import de.se_rwth.commons.logging.Log; +import de.monticore.lang.monticar.cnnarch.mxnetgenerator.TemplateConfiguration; -import java.io.IOException; import java.util.HashMap; import java.util.Map; -import java.util.Optional; -import java.util.List; - -public class CNNArch2Gluon extends CNNArchGenerator { - - private boolean isSupportedLayer(ArchitectureElementSymbol element, LayerSupportChecker layerChecker){ - List<ArchitectureElementSymbol> constructLayerElemList; - - if (element.getResolvedThis().get() instanceof CompositeElementSymbol) { - constructLayerElemList = ((CompositeElementSymbol)element.getResolvedThis().get()).getElements(); - for (ArchitectureElementSymbol constructedLayerElement : constructLayerElemList) { - if (!isSupportedLayer(constructedLayerElement, layerChecker)) { - return false; - } - } - } - if (!layerChecker.isSupported(element.toString())) { - Log.error("Unsupported layer " + "'" + element.getName() + "'" + " for the backend MXNET."); - return false; - } else { - return true; - } - } - - private boolean supportCheck(ArchitectureSymbol architecture){ - LayerSupportChecker layerChecker = new LayerSupportChecker(); - for (ArchitectureElementSymbol element : ((CompositeElementSymbol)architecture.getBody()).getElements()){ - if(!isSupportedLayer(element, layerChecker)) { - return false; - } - } - return true; - } - - public CNNArch2Gluon() { - setGenerationTargetPath("./target/generated-sources-cnnarch/"); - } - - public void generate(Scope scope, String rootModelName){ - Optional<CNNArchCompilationUnitSymbol> compilationUnit = scope.resolve(rootModelName, CNNArchCompilationUnitSymbol.KIND); - if (!compilationUnit.isPresent()){ - Log.error("could not resolve architecture " + rootModelName); - quitGeneration(); - } - CNNArchCocos.checkAll(compilationUnit.get()); - if (!supportCheck(compilationUnit.get().getArchitecture())){ - quitGeneration(); - } - - try{ - generateFiles(compilationUnit.get().getArchitecture()); - } catch (IOException e){ - Log.error(e.toString()); - } - } +public class CNNArch2Gluon extends CNNArch2MxNet { //check cocos with CNNArchCocos.checkAll(architecture) before calling this method. + @Override public Map<String, String> generateStrings(ArchitectureSymbol architecture){ + TemplateConfiguration templateConfiguration = new GluonTemplateConfiguration(); + Map<String, String> fileContentMap = new HashMap<>(); - CNNArchTemplateController archTc = new CNNArchTemplateController(architecture); + CNNArch2GluonTemplateController archTc = new CNNArch2GluonTemplateController( + architecture, templateConfiguration); Map.Entry<String, String> temp; temp = archTc.process("CNNPredictor", Target.CPP); @@ -118,28 +60,4 @@ public class CNNArch2Gluon extends CNNArchGenerator { return fileContentMap; } - - public void generateFromFilecontentsMap(Map<String, String> fileContentMap) throws IOException { - GeneratorCPP genCPP = new GeneratorCPP(); - genCPP.setGenerationTargetPath(getGenerationTargetPath()); - for (String fileName : fileContentMap.keySet()){ - genCPP.generateFile(new FileContent(fileContentMap.get(fileName), fileName)); - } - } - - public Map<String, String> generateCMakeContent(String rootModelName) { - // model name should start with a lower case letter. If it is a component, replace dot . by _ - rootModelName = rootModelName.replace('.', '_').replace('[', '_').replace(']', '_'); - rootModelName = rootModelName.substring(0, 1).toLowerCase() + rootModelName.substring(1); - - CMakeConfig cMakeConfig = new CMakeConfig(rootModelName); - cMakeConfig.addModuleDependency(new CMakeFindModule("Armadillo", true)); - cMakeConfig.addCMakeCommand("set(LIBS ${LIBS} mxnet)"); - - Map<String,String> fileContentMap = new HashMap<>(); - for (FileContent fileContent : cMakeConfig.generateCMakeFiles()){ - fileContentMap.put(fileContent.getFileName(), fileContent.getFileContent()); - } - return fileContentMap; - } } diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonCli.java b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonCli.java index 582053ff..402c46b7 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonCli.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonCli.java @@ -19,81 +19,14 @@ * ******************************************************************************* */ package de.monticore.lang.monticar.cnnarch.gluongenerator; -import de.se_rwth.commons.logging.Log; -import org.apache.commons.cli.*; - -import java.nio.file.Path; -import java.nio.file.Paths; +import de.monticore.lang.monticar.cnnarch.CNNArchGenerator; +import de.monticore.lang.monticar.cnnarch.mxnetgenerator.GenericCNNArchCli; public class CNNArch2GluonCli { - - public static final Option OPTION_MODELS_PATH = Option.builder("m") - .longOpt("models-dir") - .desc("full path to the directory with the CNNArch model") - .hasArg(true) - .required(true) - .build(); - - public static final Option OPTION_ROOT_MODEL = Option.builder("r") - .longOpt("root-model") - .desc("name of the architecture") - .hasArg(true) - .required(true) - .build(); - - public static final Option OPTION_OUTPUT_PATH = Option.builder("o") - .longOpt("output-dir") - .desc("full path to output directory for tests") - .hasArg(true) - .required(false) - .build(); - - private CNNArch2GluonCli() { - } - public static void main(String[] args) { - Options options = getOptions(); - CommandLineParser parser = new DefaultParser(); - CommandLine cliArgs = parseArgs(options, parser, args); - if (cliArgs != null) { - runGenerator(cliArgs); - } - } - - private static Options getOptions() { - Options options = new Options(); - options.addOption(OPTION_MODELS_PATH); - options.addOption(OPTION_ROOT_MODEL); - options.addOption(OPTION_OUTPUT_PATH); - return options; - } - - private static CommandLine parseArgs(Options options, CommandLineParser parser, String[] args) { - CommandLine cliArgs; - try { - cliArgs = parser.parse(options, args); - } catch (ParseException e) { - Log.error("argument parsing exception: " + e.getMessage()); - quitGeneration(); - return null; - } - return cliArgs; - } - - private static void quitGeneration(){ - Log.error("Code generation is aborted"); - System.exit(1); - } - - private static void runGenerator(CommandLine cliArgs) { - Path modelsDirPath = Paths.get(cliArgs.getOptionValue(OPTION_MODELS_PATH.getOpt())); - String rootModelName = cliArgs.getOptionValue(OPTION_ROOT_MODEL.getOpt()); - String outputPath = cliArgs.getOptionValue(OPTION_OUTPUT_PATH.getOpt()); - CNNArch2Gluon generator = new CNNArch2Gluon(); - if (outputPath != null){ - generator.setGenerationTargetPath(outputPath); - } - generator.generate(modelsDirPath, rootModelName); + CNNArchGenerator generator = new CNNArch2Gluon(); + GenericCNNArchCli cli = new GenericCNNArchCli(generator); + cli.run(args); } } diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonTemplateController.java b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonTemplateController.java new file mode 100644 index 00000000..0957b3c9 --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArch2GluonTemplateController.java @@ -0,0 +1,119 @@ +/** + * + * ****************************************************************************** + * MontiCAR Modeling Family, www.se-rwth.de + * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, + * All rights reserved. + * + * This project is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 3.0 of the License, or (at your option) any later version. + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this project. If not, see <http://www.gnu.org/licenses/>. + * ******************************************************************************* + */ +package de.monticore.lang.monticar.cnnarch.gluongenerator; + +import de.monticore.lang.monticar.cnnarch.mxnetgenerator.ArchitectureElementData; +import de.monticore.lang.monticar.cnnarch.mxnetgenerator.CNNArchTemplateController; + +import de.monticore.lang.monticar.cnnarch._symboltable.*; +import de.monticore.lang.monticar.cnnarch.mxnetgenerator.TemplateConfiguration; + +import java.io.Writer; +import java.util.*; + +public class CNNArch2GluonTemplateController extends CNNArchTemplateController { + public static final String NET_DEFINITION_MODE_KEY = "definition_mode"; + + public CNNArch2GluonTemplateController(ArchitectureSymbol architecture, + TemplateConfiguration templateConfiguration) { + super(architecture, templateConfiguration); + } + + public void include(String relativePath, String templateWithoutFileEnding, Writer writer, NetDefinitionMode netDefinitionMode){ + String templatePath = relativePath + templateWithoutFileEnding + FTL_FILE_ENDING; + Map<String, Object> ftlContext = new HashMap<>(); + ftlContext.put(TEMPLATE_CONTROLLER_KEY, this); + ftlContext.put(ELEMENT_DATA_KEY, getCurrentElement()); + ftlContext.put(NET_DEFINITION_MODE_KEY, netDefinitionMode); + getTemplateConfiguration().processTemplate(ftlContext, templatePath, writer); + } + + public void include(IOSymbol ioElement, Writer writer, NetDefinitionMode netDefinitionMode){ + ArchitectureElementData previousElement = getCurrentElement(); + setCurrentElement(ioElement); + + if (ioElement.isAtomic()){ + if (ioElement.isInput()){ + include(TEMPLATE_ELEMENTS_DIR_PATH, "Input", writer, netDefinitionMode); + } + else { + include(TEMPLATE_ELEMENTS_DIR_PATH, "Output", writer, netDefinitionMode); + } + } + else { + include(ioElement.getResolvedThis().get(), writer, netDefinitionMode); + } + + setCurrentElement(previousElement); + } + + public void include(LayerSymbol layer, Writer writer, NetDefinitionMode netDefinitionMode){ + ArchitectureElementData previousElement = getCurrentElement(); + setCurrentElement(layer); + + if (layer.isAtomic()){ + ArchitectureElementSymbol nextElement = layer.getOutputElement().get(); + if (!isSoftmaxOutput(nextElement) && !isLogisticRegressionOutput(nextElement)){ + String templateName = layer.getDeclaration().getName(); + include(TEMPLATE_ELEMENTS_DIR_PATH, templateName, writer, netDefinitionMode); + } + } + else { + include(layer.getResolvedThis().get(), writer, netDefinitionMode); + } + + setCurrentElement(previousElement); + } + + public void include(CompositeElementSymbol compositeElement, Writer writer, NetDefinitionMode netDefinitionMode){ + ArchitectureElementData previousElement = getCurrentElement(); + setCurrentElement(compositeElement); + + for (ArchitectureElementSymbol element : compositeElement.getElements()){ + include(element, writer, netDefinitionMode); + } + + setCurrentElement(previousElement); + } + + public void include(ArchitectureElementSymbol architectureElement, Writer writer, NetDefinitionMode netDefinitionMode){ + if (architectureElement instanceof CompositeElementSymbol){ + include((CompositeElementSymbol) architectureElement, writer, netDefinitionMode); + } + else if (architectureElement instanceof LayerSymbol){ + include((LayerSymbol) architectureElement, writer, netDefinitionMode); + } + else { + include((IOSymbol) architectureElement, writer, netDefinitionMode); + } + } + + public void include(ArchitectureElementSymbol architectureElementSymbol, String netDefinitionMode) { + include(architectureElementSymbol, NetDefinitionMode.fromString(netDefinitionMode)); + } + + public void include(ArchitectureElementSymbol architectureElement, NetDefinitionMode netDefinitionMode){ + if (getWriter() == null){ + throw new IllegalStateException("missing writer"); + } + include(architectureElement, getWriter(), netDefinitionMode); + } +} diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArchTemplateController.java b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArchTemplateController.java deleted file mode 100644 index c5d7a558..00000000 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNArchTemplateController.java +++ /dev/null @@ -1,272 +0,0 @@ -/** - * - * ****************************************************************************** - * MontiCAR Modeling Family, www.se-rwth.de - * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, - * All rights reserved. - * - * This project is free software; you can redistribute it and/or - * modify it under the terms of the GNU Lesser General Public - * License as published by the Free Software Foundation; either - * version 3.0 of the License, or (at your option) any later version. - * This library is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU - * Lesser General Public License for more details. - * - * You should have received a copy of the GNU Lesser General Public - * License along with this project. If not, see <http://www.gnu.org/licenses/>. - * ******************************************************************************* - */ -package de.monticore.lang.monticar.cnnarch.gluongenerator; - -import de.monticore.lang.monticar.cnnarch._symboltable.*; -import de.monticore.lang.monticar.cnnarch.predefined.Sigmoid; -import de.monticore.lang.monticar.cnnarch.predefined.Softmax; - -import java.io.StringWriter; -import java.io.Writer; -import java.util.*; - -public class CNNArchTemplateController { - - public static final String FTL_FILE_ENDING = ".ftl"; - public static final String TEMPLATE_ELEMENTS_DIR_PATH = "elements/"; - public static final String TEMPLATE_CONTROLLER_KEY = "tc"; - public static final String ELEMENT_DATA_KEY = "element"; - public static final String NET_DEFINITION_MODE_KEY = "definition_mode"; - - private LayerNameCreator nameManager; - private ArchitectureSymbol architecture; - - //temporary attributes. They are set after calling process() - private Writer writer; - private String mainTemplateNameWithoutEnding; - private Target targetLanguage; - private ArchitectureElementData dataElement; - - - public CNNArchTemplateController(ArchitectureSymbol architecture) { - setArchitecture(architecture); - } - - public String getFileNameWithoutEnding() { - return mainTemplateNameWithoutEnding + "_" + getFullArchitectureName(); - } - - public ArchitectureElementData getCurrentElement() { - return dataElement; - } - - public void setCurrentElement(ArchitectureElementSymbol layer) { - this.dataElement = new ArchitectureElementData(getName(layer), layer, this); - } - - public void setCurrentElement(ArchitectureElementData dataElement) { - this.dataElement = dataElement; - } - - public ArchitectureSymbol getArchitecture() { - return architecture; - } - - public void setArchitecture(ArchitectureSymbol architecture) { - this.architecture = architecture; - this.nameManager = new LayerNameCreator(architecture); - } - - public String getName(ArchitectureElementSymbol layer){ - return nameManager.getName(layer); - } - - public String getArchitectureName(){ - return getArchitecture().getEnclosingScope().getSpanningSymbol().get().getName().replaceAll("\\.","_"); - } - - public String getFullArchitectureName(){ - return getArchitecture().getEnclosingScope().getSpanningSymbol().get().getFullName().replaceAll("\\.","_"); - } - - public List<String> getLayerInputs(ArchitectureElementSymbol layer){ - List<String> inputNames = new ArrayList<>(); - - if (isSoftmaxOutput(layer) || isLogisticRegressionOutput(layer)){ - inputNames = getLayerInputs(layer.getInputElement().get()); - } else { - for (ArchitectureElementSymbol input : layer.getPrevious()) { - if (input.getOutputTypes().size() == 1) { - inputNames.add(getName(input)); - } else { - for (int i = 0; i < input.getOutputTypes().size(); i++) { - inputNames.add(getName(input) + "[" + i + "]"); - } - } - } - } - return inputNames; - - } - - public List<String> getArchitectureInputs(){ - List<String> list = new ArrayList<>(); - for (IOSymbol ioElement : getArchitecture().getInputs()){ - list.add(nameManager.getName(ioElement)); - } - return list; - } - - public List<String> getArchitectureOutputs(){ - List<String> list = new ArrayList<>(); - for (IOSymbol ioElement : getArchitecture().getOutputs()){ - list.add(nameManager.getName(ioElement)); - } - return list; - } - - public void include(String relativePath, String templateWithoutFileEnding, Writer writer, NetDefinitionMode netDefinitionMode){ - String templatePath = relativePath + templateWithoutFileEnding + FTL_FILE_ENDING; - Map<String, Object> ftlContext = new HashMap<>(); - ftlContext.put(TEMPLATE_CONTROLLER_KEY, this); - ftlContext.put(ELEMENT_DATA_KEY, getCurrentElement()); - ftlContext.put(NET_DEFINITION_MODE_KEY, netDefinitionMode); - TemplateConfiguration.processTemplate(ftlContext, templatePath, writer); - } - - public void include(String relativePath, String templateWithoutFileEnding, Writer writer) { - String templatePath = relativePath + templateWithoutFileEnding + FTL_FILE_ENDING; - Map<String, Object> ftlContext = new HashMap<>(); - ftlContext.put(TEMPLATE_CONTROLLER_KEY, this); - ftlContext.put(ELEMENT_DATA_KEY, getCurrentElement()); - TemplateConfiguration.processTemplate(ftlContext, templatePath, writer); - } - - public void include(IOSymbol ioElement, Writer writer, NetDefinitionMode netDefinitionMode){ - ArchitectureElementData previousElement = getCurrentElement(); - setCurrentElement(ioElement); - - if (ioElement.isAtomic()){ - if (ioElement.isInput()){ - include(TEMPLATE_ELEMENTS_DIR_PATH, "Input", writer, netDefinitionMode); - } else { - include(TEMPLATE_ELEMENTS_DIR_PATH, "Output", writer, netDefinitionMode); - } - } else { - include(ioElement.getResolvedThis().get(), writer, netDefinitionMode); - } - - setCurrentElement(previousElement); - } - - public void include(LayerSymbol layer, Writer writer, NetDefinitionMode netDefinitionMode){ - ArchitectureElementData previousElement = getCurrentElement(); - setCurrentElement(layer); - - if (layer.isAtomic()){ - ArchitectureElementSymbol nextElement = layer.getOutputElement().get(); - if (!isSoftmaxOutput(nextElement) && !isLogisticRegressionOutput(nextElement)){ - String templateName = layer.getDeclaration().getName(); - include(TEMPLATE_ELEMENTS_DIR_PATH, templateName, writer, netDefinitionMode); - } - } else { - include(layer.getResolvedThis().get(), writer, netDefinitionMode); - } - - setCurrentElement(previousElement); - } - - public void include(CompositeElementSymbol compositeElement, Writer writer, NetDefinitionMode netDefinitionMode){ - ArchitectureElementData previousElement = getCurrentElement(); - setCurrentElement(compositeElement); - - for (ArchitectureElementSymbol element : compositeElement.getElements()){ - include(element, writer, netDefinitionMode); - } - - setCurrentElement(previousElement); - } - - public void include(ArchitectureElementSymbol architectureElement, Writer writer, NetDefinitionMode netDefinitionMode){ - if (architectureElement instanceof CompositeElementSymbol){ - include((CompositeElementSymbol) architectureElement, writer, netDefinitionMode); - } else if (architectureElement instanceof LayerSymbol){ - include((LayerSymbol) architectureElement, writer, netDefinitionMode); - } else { - include((IOSymbol) architectureElement, writer, netDefinitionMode); - } - } - - public void include(ArchitectureElementSymbol architectureElementSymbol, String netDefinitionMode) { - include(architectureElementSymbol, NetDefinitionMode.fromString(netDefinitionMode)); - } - - public void include(ArchitectureElementSymbol architectureElement, NetDefinitionMode netDefinitionMode){ - if (writer == null){ - throw new IllegalStateException("missing writer"); - } - include(architectureElement, writer, netDefinitionMode); - } - - public Map.Entry<String,String> process(String templateNameWithoutEnding, Target targetLanguage){ - StringWriter newWriter = new StringWriter(); - this.mainTemplateNameWithoutEnding = templateNameWithoutEnding; - this.targetLanguage = targetLanguage; - this.writer = newWriter; - - include("", templateNameWithoutEnding, newWriter); - String fileEnding = targetLanguage.toString(); - String fileName = getFileNameWithoutEnding() + fileEnding; - Map.Entry<String,String> fileContent = new AbstractMap.SimpleEntry<>(fileName, newWriter.toString()); - - this.mainTemplateNameWithoutEnding = null; - this.targetLanguage = null; - this.writer = null; - return fileContent; - } - - public String join(Iterable iterable, String separator){ - return join(iterable, separator, "", ""); - } - - public String join(Iterable iterable, String separator, String elementPrefix, String elementPostfix){ - StringBuilder stringBuilder = new StringBuilder(); - boolean isFirst = true; - for (Object element : iterable){ - if (!isFirst){ - stringBuilder.append(separator); - } - stringBuilder.append(elementPrefix); - stringBuilder.append(element.toString()); - stringBuilder.append(elementPostfix); - isFirst = false; - } - return stringBuilder.toString(); - } - - - public boolean isLogisticRegressionOutput(ArchitectureElementSymbol architectureElement){ - return isTOutput(Sigmoid.class, architectureElement); - } - - public boolean isLinearRegressionOutput(ArchitectureElementSymbol architectureElement){ - return architectureElement.isOutput() - && !isLogisticRegressionOutput(architectureElement) - && !isSoftmaxOutput(architectureElement); - } - - - public boolean isSoftmaxOutput(ArchitectureElementSymbol architectureElement){ - return isTOutput(Softmax.class, architectureElement); - } - - private boolean isTOutput(Class inputPredefinedLayerClass, ArchitectureElementSymbol architectureElement){ - if (architectureElement.isOutput() - && architectureElement.getInputElement().isPresent() - && architectureElement.getInputElement().get() instanceof LayerSymbol){ - LayerSymbol inputLayer = (LayerSymbol) architectureElement.getInputElement().get(); - if (inputPredefinedLayerClass.isInstance(inputLayer.getDeclaration())){ - return true; - } - } - return false; - } -} diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNTrain2Gluon.java b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNTrain2Gluon.java index 845f6b47..8be4e1f5 100644 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNTrain2Gluon.java +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/CNNTrain2Gluon.java @@ -1,133 +1,27 @@ package de.monticore.lang.monticar.cnnarch.gluongenerator; -import de.monticore.io.paths.ModelPath; -import de.monticore.lang.monticar.cnntrain.CNNTrainGenerator; -import de.monticore.lang.monticar.cnntrain._ast.ASTCNNTrainNode; -import de.monticore.lang.monticar.cnntrain._ast.ASTOptimizerEntry; -import de.monticore.lang.monticar.cnntrain._cocos.CNNTrainCocos; -import de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainCompilationUnitSymbol; -import de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainLanguage; +import de.monticore.lang.monticar.cnnarch.mxnetgenerator.ConfigurationData; + +import de.monticore.lang.monticar.cnnarch.mxnetgenerator.CNNTrain2MxNet; +import de.monticore.lang.monticar.cnnarch.mxnetgenerator.TemplateConfiguration; import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol; -import de.monticore.lang.monticar.generator.FileContent; -import de.monticore.lang.monticar.generator.cpp.GeneratorCPP; -import de.monticore.symboltable.GlobalScope; -import de.se_rwth.commons.logging.Log; -import java.io.IOException; -import java.nio.file.Path; import java.util.*; -public class CNNTrain2Gluon implements CNNTrainGenerator { - private String generationTargetPath; - private String instanceName; - - private void supportCheck(ConfigurationSymbol configuration){ - checkEntryParams(configuration); - checkOptimizerParams(configuration); - } - - private void checkEntryParams(ConfigurationSymbol configuration){ - TrainParamSupportChecker funcChecker = new TrainParamSupportChecker(); - Iterator it = configuration.getEntryMap().keySet().iterator(); - while (it.hasNext()) { - String key = it.next().toString(); - ASTCNNTrainNode astTrainEntryNode = (ASTCNNTrainNode) configuration.getEntryMap().get(key).getAstNode().get(); - astTrainEntryNode.accept(funcChecker); - } - it = configuration.getEntryMap().keySet().iterator(); - while (it.hasNext()) { - String key = it.next().toString(); - if (funcChecker.getUnsupportedElemList().contains(key)) { - it.remove(); - } - } - } - - private void checkOptimizerParams(ConfigurationSymbol configuration){ - TrainParamSupportChecker funcChecker = new TrainParamSupportChecker(); - if (configuration.getOptimizer() != null) { - ASTOptimizerEntry astOptimizer = (ASTOptimizerEntry) configuration.getOptimizer().getAstNode().get(); - astOptimizer.accept(funcChecker); - if (funcChecker.getUnsupportedElemList().contains(funcChecker.unsupportedOptFlag)) { - configuration.setOptimizer(null); - }else { - Iterator it = configuration.getOptimizer().getOptimizerParamMap().keySet().iterator(); - while (it.hasNext()) { - String key = it.next().toString(); - if (funcChecker.getUnsupportedElemList().contains(key)) { - it.remove(); - } - } - } - } - } - - private static void quitGeneration(){ - Log.error("Code generation is aborted"); - System.exit(1); - } - +public class CNNTrain2Gluon extends CNNTrain2MxNet { public CNNTrain2Gluon() { - setGenerationTargetPath("./target/generated-sources-cnnarch/"); - } - - public String getInstanceName() { - String parsedInstanceName = this.instanceName.replace('.', '_').replace('[', '_').replace(']', '_'); - parsedInstanceName = parsedInstanceName.substring(0, 1).toLowerCase() + parsedInstanceName.substring(1); - return parsedInstanceName; - } - - public void setInstanceName(String instanceName) { - this.instanceName = instanceName; - } - - public String getGenerationTargetPath() { - if (generationTargetPath.charAt(generationTargetPath.length() - 1) != '/') { - this.generationTargetPath = generationTargetPath + "/"; - } - return generationTargetPath; - } - - public void setGenerationTargetPath(String generationTargetPath) { - this.generationTargetPath = generationTargetPath; - } - - public ConfigurationSymbol getConfigurationSymbol(Path modelsDirPath, String rootModelName) { - final ModelPath mp = new ModelPath(modelsDirPath); - GlobalScope scope = new GlobalScope(mp, new CNNTrainLanguage()); - Optional<CNNTrainCompilationUnitSymbol> compilationUnit = scope.resolve(rootModelName, CNNTrainCompilationUnitSymbol.KIND); - if (!compilationUnit.isPresent()) { - Log.error("could not resolve training configuration " + rootModelName); - quitGeneration(); - } - setInstanceName(compilationUnit.get().getFullName()); - CNNTrainCocos.checkAll(compilationUnit.get()); - supportCheck(compilationUnit.get().getConfiguration()); - return compilationUnit.get().getConfiguration(); - } - - public void generate(Path modelsDirPath, String rootModelName) { - ConfigurationSymbol configuration = getConfigurationSymbol(modelsDirPath, rootModelName); - Map<String, String> fileContents = generateStrings(configuration); - GeneratorCPP genCPP = new GeneratorCPP(); - genCPP.setGenerationTargetPath(getGenerationTargetPath()); - try { - for (String fileName : fileContents.keySet()){ - genCPP.generateFile(new FileContent(fileContents.get(fileName), fileName)); - } - } catch (IOException e) { - Log.error("CNNTrainer file could not be generated" + e.getMessage()); - } + super(); } + @Override public Map<String, String> generateStrings(ConfigurationSymbol configuration) { + TemplateConfiguration templateConfiguration = new GluonTemplateConfiguration(); ConfigurationData configData = new ConfigurationData(configuration, getInstanceName()); List<ConfigurationData> configDataList = new ArrayList<>(); configDataList.add(configData); Map<String, Object> ftlContext = Collections.singletonMap("configurations", configDataList); - String templateContent = TemplateConfiguration.processTemplate(ftlContext, "CNNTrainer.ftl"); + String templateContent = templateConfiguration.processTemplate(ftlContext, "CNNTrainer.ftl"); return Collections.singletonMap("CNNTrainer_" + getInstanceName() + ".py", templateContent); } - } \ No newline at end of file diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/ConfigurationData.java b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/ConfigurationData.java deleted file mode 100644 index 10cf8dd7..00000000 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/ConfigurationData.java +++ /dev/null @@ -1,99 +0,0 @@ -package de.monticore.lang.monticar.cnnarch.gluongenerator; - -import de.monticore.lang.monticar.cnntrain._symboltable.*; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -public class ConfigurationData { - - ConfigurationSymbol configuration; - String instanceName; - - public ConfigurationData(ConfigurationSymbol configuration, String instanceName) { - this.configuration = configuration; - this.instanceName = instanceName; - } - - public ConfigurationSymbol getConfiguration() { - return configuration; - } - - public String getInstanceName() { - return instanceName; - } - - public String getNumEpoch() { - if (!getConfiguration().getEntryMap().containsKey("num_epoch")) { - return null; - } - return String.valueOf(getConfiguration().getEntry("num_epoch").getValue()); - } - - public String getBatchSize() { - if (!getConfiguration().getEntryMap().containsKey("batch_size")) { - return null; - } - return String.valueOf(getConfiguration().getEntry("batch_size") .getValue()); - } - - public Boolean getLoadCheckpoint() { - if (!getConfiguration().getEntryMap().containsKey("load_checkpoint")) { - return null; - } - return (Boolean) getConfiguration().getEntry("load_checkpoint").getValue().getValue(); - } - - public Boolean getNormalize() { - if (!getConfiguration().getEntryMap().containsKey("normalize")) { - return null; - } - return (Boolean) getConfiguration().getEntry("normalize").getValue().getValue(); - } - - public String getContext() { - if (!getConfiguration().getEntryMap().containsKey("context")) { - return null; - } - return getConfiguration().getEntry("context").getValue().toString(); - } - - public String getEvalMetric() { - if (!getConfiguration().getEntryMap().containsKey("eval_metric")) { - return null; - } - return getConfiguration().getEntry("eval_metric").getValue().toString(); - } - - public String getOptimizerName() { - if (getConfiguration().getOptimizer() == null) { - return null; - } - return getConfiguration().getOptimizer().getName(); - } - - public Map<String, String> getOptimizerParams() { - // get classes for single enum values - List<Class> lrPolicyClasses = new ArrayList<>(); - for (LRPolicy enum_value: LRPolicy.values()) { - lrPolicyClasses.add(enum_value.getClass()); - } - - Map<String, String> mapToStrings = new HashMap<>(); - Map<String, OptimizerParamSymbol> optimizerParams = getConfiguration().getOptimizer().getOptimizerParamMap(); - for (Map.Entry<String, OptimizerParamSymbol> entry : optimizerParams.entrySet()) { - String paramName = entry.getKey(); - String valueAsString = entry.getValue().toString(); - Class realClass = entry.getValue().getValue().getValue().getClass(); - if (realClass == Boolean.class) { - valueAsString = (Boolean) entry.getValue().getValue().getValue() ? "True" : "False"; - } else if (lrPolicyClasses.contains(realClass)) { - valueAsString = "'" + valueAsString + "'"; - } - mapToStrings.put(paramName, valueAsString); - } - return mapToStrings; - } -} diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/GluonTemplateConfiguration.java b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/GluonTemplateConfiguration.java new file mode 100644 index 00000000..a0ecf88b --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/GluonTemplateConfiguration.java @@ -0,0 +1,28 @@ +package de.monticore.lang.monticar.cnnarch.gluongenerator; + +import de.monticore.lang.monticar.cnnarch.mxnetgenerator.TemplateConfiguration; +import freemarker.template.Configuration; + +/** + * + */ +public class GluonTemplateConfiguration extends TemplateConfiguration { + private static Configuration configuration; + + public GluonTemplateConfiguration() { + super(); + if (configuration == null) { + configuration = super.createConfiguration(); + } + } + + @Override + protected String getBaseTemplatePackagePath() { + return "/templates/gluon/"; + } + + @Override + public Configuration getConfiguration() { + return configuration; + } +} diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/LayerNameCreator.java b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/LayerNameCreator.java deleted file mode 100644 index 0b91f102..00000000 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/LayerNameCreator.java +++ /dev/null @@ -1,146 +0,0 @@ -/** - * - * ****************************************************************************** - * MontiCAR Modeling Family, www.se-rwth.de - * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, - * All rights reserved. - * - * This project is free software; you can redistribute it and/or - * modify it under the terms of the GNU Lesser General Public - * License as published by the Free Software Foundation; either - * version 3.0 of the License, or (at your option) any later version. - * This library is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU - * Lesser General Public License for more details. - * - * You should have received a copy of the GNU Lesser General Public - * License along with this project. If not, see <http://www.gnu.org/licenses/>. - * ******************************************************************************* - */ -package de.monticore.lang.monticar.cnnarch.gluongenerator; - -import de.monticore.lang.monticar.cnnarch._symboltable.*; -import de.monticore.lang.monticar.cnnarch.predefined.Convolution; -import de.monticore.lang.monticar.cnnarch.predefined.FullyConnected; -import de.monticore.lang.monticar.cnnarch.predefined.Pooling; - -import java.util.*; - -public class LayerNameCreator { - - private Map<ArchitectureElementSymbol, String> elementToName = new HashMap<>(); - private Map<String, ArchitectureElementSymbol> nameToElement = new HashMap<>(); - - public LayerNameCreator(ArchitectureSymbol architecture) { - name(architecture.getBody(), 1, new ArrayList<>()); - } - - public ArchitectureElementSymbol getArchitectureElement(String name){ - return nameToElement.get(name); - } - - public String getName(ArchitectureElementSymbol architectureElement){ - return elementToName.get(architectureElement); - } - - protected int name(ArchitectureElementSymbol architectureElement, int stage, List<Integer> streamIndices){ - if (architectureElement instanceof CompositeElementSymbol){ - return nameComposite((CompositeElementSymbol) architectureElement, stage, streamIndices); - } else{ - if (architectureElement.isAtomic()){ - if (architectureElement.getMaxSerialLength().get() > 0){ - return add(architectureElement, stage, streamIndices); - } else { - return stage; - } - } else { - ArchitectureElementSymbol resolvedElement = architectureElement.getResolvedThis().get(); - return name(resolvedElement, stage, streamIndices); - } - } - } - - protected int nameComposite(CompositeElementSymbol compositeElement, int stage, List<Integer> streamIndices){ - if (compositeElement.isParallel()){ - int startStage = stage + 1; - streamIndices.add(1); - int lastIndex = streamIndices.size() - 1; - - List<Integer> endStages = new ArrayList<>(); - for (ArchitectureElementSymbol subElement : compositeElement.getElements()){ - endStages.add(name(subElement, startStage, streamIndices)); - streamIndices.set(lastIndex, streamIndices.get(lastIndex) + 1); - } - - streamIndices.remove(lastIndex); - return Collections.max(endStages) + 1; - } else { - int endStage = stage; - for (ArchitectureElementSymbol subElement : compositeElement.getElements()){ - endStage = name(subElement, endStage, streamIndices); - } - return endStage; - } - } - - protected int add(ArchitectureElementSymbol architectureElement, int stage, List<Integer> streamIndices){ - int endStage = stage; - if (!elementToName.containsKey(architectureElement)) { - String name = createName(architectureElement, endStage, streamIndices); - - while (nameToElement.containsKey(name)) { - endStage++; - name = createName(architectureElement, endStage, streamIndices); - } - - elementToName.put(architectureElement, name); - nameToElement.put(name, architectureElement); - } - return endStage; - } - - protected String createName(ArchitectureElementSymbol architectureElement, int stage, List<Integer> streamIndices){ - if (architectureElement instanceof IOSymbol){ - String name = createBaseName(architectureElement); - IOSymbol ioElement = (IOSymbol) architectureElement; - if (ioElement.getArrayAccess().isPresent()){ - int arrayAccess = ioElement.getArrayAccess().get().getIntValue().get(); - name = name + "_" + arrayAccess + "_"; - } - return name; - } else { - return createBaseName(architectureElement) + stage + createStreamPostfix(streamIndices) + "_"; - } - } - - - protected String createBaseName(ArchitectureElementSymbol architectureElement){ - if (architectureElement instanceof LayerSymbol) { - LayerDeclarationSymbol layerDeclaration = ((LayerSymbol) architectureElement).getDeclaration(); - if (layerDeclaration instanceof Convolution) { - return "conv"; - } else if (layerDeclaration instanceof FullyConnected) { - return "fc"; - } else if (layerDeclaration instanceof Pooling) { - return "pool"; - } else { - return layerDeclaration.getName().toLowerCase(); - } - } else if (architectureElement instanceof CompositeElementSymbol){ - return "group"; - } else { - return architectureElement.getName(); - } - } - - protected String createStreamPostfix(List<Integer> streamIndices){ - StringBuilder stringBuilder = new StringBuilder(); - for (int streamIndex : streamIndices){ - stringBuilder.append("_"); - stringBuilder.append(streamIndex); - } - return stringBuilder.toString(); - } -} - diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/LayerSupportChecker.java b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/LayerSupportChecker.java deleted file mode 100644 index 3e639962..00000000 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/LayerSupportChecker.java +++ /dev/null @@ -1,19 +0,0 @@ -package de.monticore.lang.monticar.cnnarch.gluongenerator; - -import java.util.ArrayList; -import java.util.List; - - -public class LayerSupportChecker { - - private List<String> unsupportedLayerList = new ArrayList(); - - public LayerSupportChecker() { - //Set the unsupported layers for the backend - //this.unsupportedLayerList.add(PREDEFINED_LAYER_NAME); - } - - public boolean isSupported(String element) { - return !this.unsupportedLayerList.contains(element); - } -} diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/Target.java b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/Target.java deleted file mode 100644 index 07c5e73a..00000000 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/Target.java +++ /dev/null @@ -1,37 +0,0 @@ -/** - * - * ****************************************************************************** - * MontiCAR Modeling Family, www.se-rwth.de - * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, - * All rights reserved. - * - * This project is free software; you can redistribute it and/or - * modify it under the terms of the GNU Lesser General Public - * License as published by the Free Software Foundation; either - * version 3.0 of the License, or (at your option) any later version. - * This library is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU - * Lesser General Public License for more details. - * - * You should have received a copy of the GNU Lesser General Public - * License along with this project. If not, see <http://www.gnu.org/licenses/>. - * ******************************************************************************* - */ -package de.monticore.lang.monticar.cnnarch.gluongenerator; - -//can be removed -public enum Target { - PYTHON{ - @Override - public String toString() { - return ".py"; - } - }, - CPP{ - @Override - public String toString() { - return ".h"; - } - }; -} diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/TemplateConfiguration.java b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/TemplateConfiguration.java deleted file mode 100644 index d79e9076..00000000 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/TemplateConfiguration.java +++ /dev/null @@ -1,81 +0,0 @@ -/** - * - * ****************************************************************************** - * MontiCAR Modeling Family, www.se-rwth.de - * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, - * All rights reserved. - * - * This project is free software; you can redistribute it and/or - * modify it under the terms of the GNU Lesser General Public - * License as published by the Free Software Foundation; either - * version 3.0 of the License, or (at your option) any later version. - * This library is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU - * Lesser General Public License for more details. - * - * You should have received a copy of the GNU Lesser General Public - * License along with this project. If not, see <http://www.gnu.org/licenses/>. - * ******************************************************************************* - */ -package de.monticore.lang.monticar.cnnarch.gluongenerator; - -import de.se_rwth.commons.logging.Log; -import freemarker.template.Configuration; -import freemarker.template.Template; -import freemarker.template.TemplateException; -import freemarker.template.TemplateExceptionHandler; - -import java.io.IOException; -import java.io.StringWriter; -import java.io.Writer; -import java.util.Map; - -public class TemplateConfiguration { - - private static TemplateConfiguration instance; - private Configuration configuration; - - private TemplateConfiguration() { - configuration = new Configuration(Configuration.VERSION_2_3_23); - configuration.setClassForTemplateLoading(TemplateConfiguration.class, "/templates/gluon/"); - configuration.setDefaultEncoding("UTF-8"); - configuration.setTemplateExceptionHandler(TemplateExceptionHandler.RETHROW_HANDLER); - } - - private static void quitGeneration(){ - Log.error("Code generation is aborted"); - System.exit(1); - } - - public Configuration getConfiguration() { - return configuration; - } - - public static Configuration get(){ - if (instance == null){ - instance = new TemplateConfiguration(); - } - return instance.getConfiguration(); - } - - public static void processTemplate(Map<String, Object> ftlContext, String templatePath, Writer writer){ - try{ - Template template = TemplateConfiguration.get().getTemplate(templatePath); - template.process(ftlContext, writer); - } catch (IOException e) { - Log.error("Freemarker could not find template " + templatePath + " :\n" + e.getMessage()); - quitGeneration(); - } catch (TemplateException e){ - Log.error("An exception occured in template " + templatePath + " :\n" + e.getMessage()); - quitGeneration(); - } - } - - public static String processTemplate(Map<String, Object> ftlContext, String templatePath){ - StringWriter writer = new StringWriter(); - processTemplate(ftlContext, templatePath, writer); - return writer.toString(); - } - -} diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/TrainParamSupportChecker.java b/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/TrainParamSupportChecker.java deleted file mode 100644 index d145e6ff..00000000 --- a/src/main/java/de/monticore/lang/monticar/cnnarch/gluongenerator/TrainParamSupportChecker.java +++ /dev/null @@ -1,94 +0,0 @@ -package de.monticore.lang.monticar.cnnarch.gluongenerator; - -import de.monticore.lang.monticar.cnntrain._ast.*; -import de.monticore.lang.monticar.cnntrain._visitor.CNNTrainVisitor; -import de.se_rwth.commons.logging.Log; -import java.util.ArrayList; -import java.util.List; - -public class TrainParamSupportChecker implements CNNTrainVisitor { - - private List<String> unsupportedElemList = new ArrayList(); - - private void printUnsupportedEntryParam(String nodeName){ - Log.warn("Unsupported training parameter " + "'" + nodeName + "'" + " for the backend MXNet. It will be ignored."); - } - - private void printUnsupportedOptimizer(String nodeName){ - Log.warn("Unsupported optimizer parameter " + "'" + nodeName + "'" + " for the backend MXNet. It will be ignored."); - } - - private void printUnsupportedOptimizerParam(String nodeName){ - Log.warn("Unsupported training optimizer parameter " + "'" + nodeName + "'" + " for the backend MXNet. It will be ignored."); - } - - public TrainParamSupportChecker() { - } - - public static final String unsupportedOptFlag = "unsupported_optimizer"; - - public List getUnsupportedElemList(){ - return this.unsupportedElemList; - } - - //Empty visit method denotes that the corresponding training parameter is supported. - //To set a training parameter as unsupported, add the corresponding node to the unsupportedElemList - public void visit(ASTNumEpochEntry node){} - - public void visit(ASTBatchSizeEntry node){} - - public void visit(ASTLoadCheckpointEntry node){} - - public void visit(ASTNormalizeEntry node){} - - public void visit(ASTTrainContextEntry node){} - - public void visit(ASTEvalMetricEntry node){} - - public void visit(ASTSGDOptimizer node){} - - public void visit(ASTAdamOptimizer node){} - - public void visit(ASTRmsPropOptimizer node){} - - public void visit(ASTAdaGradOptimizer node){} - - public void visit(ASTNesterovOptimizer node){} - - public void visit(ASTAdaDeltaOptimizer node){} - - public void visit(ASTLearningRateEntry node){} - - public void visit(ASTMinimumLearningRateEntry node){} - - public void visit(ASTWeightDecayEntry node){} - - public void visit(ASTLRDecayEntry node){} - - public void visit(ASTLRPolicyEntry node){} - - public void visit(ASTRescaleGradEntry node){} - - public void visit(ASTClipGradEntry node){} - - public void visit(ASTStepSizeEntry node){} - - public void visit(ASTMomentumEntry node){} - - public void visit(ASTBeta1Entry node){} - - public void visit(ASTBeta2Entry node){} - - public void visit(ASTEpsilonEntry node){} - - public void visit(ASTGamma1Entry node){} - - public void visit(ASTGamma2Entry node){} - - public void visit(ASTCenteredEntry node){} - - public void visit(ASTClipWeightsEntry node){} - - public void visit(ASTRhoEntry node){} - -} diff --git a/src/test/java/de/monticore/lang/monticar/cnnarch/gluongenerator/AbstractSymtabTest.java b/src/test/java/de/monticore/lang/monticar/cnnarch/gluongenerator/AbstractSymtabTest.java index 61a0c1c3..200894a5 100644 --- a/src/test/java/de/monticore/lang/monticar/cnnarch/gluongenerator/AbstractSymtabTest.java +++ b/src/test/java/de/monticore/lang/monticar/cnnarch/gluongenerator/AbstractSymtabTest.java @@ -55,9 +55,8 @@ public class AbstractSymtabTest { for (String m : modelPath) { mp.addEntry(Paths.get(m)); } - GlobalScope scope = new GlobalScope(mp, fam); - return scope; + return new GlobalScope(mp, fam); } protected static CNNArchCompilationUnitSymbol getCompilationUnitSymbol(String modelPath, String model) { diff --git a/src/test/java/de/monticore/lang/monticar/cnnarch/gluongenerator/GenerationTest.java b/src/test/java/de/monticore/lang/monticar/cnnarch/gluongenerator/GenerationTest.java index d397cb35..30e242da 100644 --- a/src/test/java/de/monticore/lang/monticar/cnnarch/gluongenerator/GenerationTest.java +++ b/src/test/java/de/monticore/lang/monticar/cnnarch/gluongenerator/GenerationTest.java @@ -32,7 +32,7 @@ import java.util.*; import static junit.framework.TestCase.assertTrue; -public class GenerationTest extends AbstractSymtabTest{ +public class GenerationTest extends AbstractSymtabTest { @Before public void setUp() { @@ -104,7 +104,7 @@ public class GenerationTest extends AbstractSymtabTest{ @Test public void testResNeXtGeneration() throws IOException, TemplateException { - Log.getFindings().clear();; + Log.getFindings().clear(); String[] args = {"-m", "src/test/resources/architectures", "-r", "ResNeXt50"}; CNNArch2GluonCli.main(args); assertTrue(Log.getFindings().isEmpty()); -- GitLab