diff --git a/.circleci/config.yml b/.circleci/config.yml new file mode 100644 index 0000000000000000000000000000000000000000..9ec02adc5749985975f56d2151e96de4a98c4d4d --- /dev/null +++ b/.circleci/config.yml @@ -0,0 +1,69 @@ +# +# +# ****************************************************************************** +# MontiCAR Modeling Family, www.se-rwth.de +# Copyright (c) 2017, Software Engineering Group at RWTH Aachen, +# All rights reserved. +# +# This project is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 3.0 of the License, or (at your option) any later version. +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this project. If not, see . +# ******************************************************************************* +# + +# Java Maven CircleCI 2.0 configuration file +# +# Check https://circleci.com/docs/2.0/language-java/ for more details +# +version: 2 +general: + branches: + ignore: + - gh-pages + +jobs: + build: + docker: + # specify the version you desire here + - image: circleci/openjdk:8-jdk + + # Specify service dependencies here if necessary + # CircleCI maintains a library of pre-built images + # documented at https://circleci.com/docs/2.0/circleci-images/ + # - image: circleci/postgres:9.4 + + working_directory: ~/repo + + environment: + # Customize the JVM maximum heap limit + MAVEN_OPTS: -Xmx3200m + + steps: + - checkout + + # run tests! + - run: mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings "settings.xml" +workflows: + version: 2 + commit-workflow: + jobs: + - build + scheduled-workflow: + triggers: + - schedule: + cron: "30 1 * * *" + filters: + branches: + only: master + + jobs: + - build + diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..11780027328480bc4bac278fc1025f1fac376079 --- /dev/null +++ b/.gitignore @@ -0,0 +1,10 @@ +target +nppBackup +.project +.settings +.classpath +.idea +.git + +*.iml + diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml new file mode 100644 index 0000000000000000000000000000000000000000..a45cd7c87301d8847cb1237464de25df85737f9b --- /dev/null +++ b/.gitlab-ci.yml @@ -0,0 +1,50 @@ +# +# +# ****************************************************************************** +# MontiCAR Modeling Family, www.se-rwth.de +# Copyright (c) 2017, Software Engineering Group at RWTH Aachen, +# All rights reserved. +# +# This project is free software; you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public +# License as published by the Free Software Foundation; either +# version 3.0 of the License, or (at your option) any later version. +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this project. If not, see . +# ******************************************************************************* +# + +stages: +- windows +- linux + +masterJobLinux: + stage: linux + image: maven:3-jdk-8 + script: + - mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean deploy --settings settings.xml + - cat target/site/jacoco/index.html + - mvn package sonar:sonar -s settings.xml + only: + - master + +masterJobWindows: + stage: windows + script: + - mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml + tags: + - Windows10 + +BranchJobLinux: + stage: linux + image: maven:3-jdk-8 + script: + - mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml + - cat target/site/jacoco/index.html + except: + - master diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000000000000000000000000000000000000..5b54b0cf104b040c3fa25918a29d2da17626f6e0 --- /dev/null +++ b/.travis.yml @@ -0,0 +1,5 @@ +script: +- git checkout ${TRAVIS_BRANCH} +- mvn clean install cobertura:cobertura org.eluder.coveralls:coveralls-maven-plugin:report --settings "settings.xml" +after_success: +- if [ "${TRAVIS_BRANCH}" == "master" ]; then mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B deploy --debug --settings "./settings.xml"; fi diff --git a/README.md b/README.md index fdedc33a4caa509cbcd70f7f4ab1f390ee3d6d57..645554bb479a1d17718eb59978e7cfd737a842ec 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,2 @@ -# CNNArch2X - +![pipeline](https://git.rwth-aachen.de/monticore/EmbeddedMontiArc/generators/cnnarch2x/badges/master/build.svg) +![coverage](https://git.rwth-aachen.de/monticore/EmbeddedMontiArc/generators/cnnarch2x/badges/master/coverage.svg) diff --git a/pom.xml b/pom.xml new file mode 100644 index 0000000000000000000000000000000000000000..75e7a7fbbe01dc2f371c911e9917ff27ae4d6b2e --- /dev/null +++ b/pom.xml @@ -0,0 +1,254 @@ + + 4.0.0 + + + + + de.monticore.lang.monticar + cnnarch-generator + 0.0.1-SNAPSHOT + + + + + + + 0.3.1-SNAPSHOT + 0.3.2-SNAPSHOT + 0.1.4 + + + 18.0 + 4.12 + 1.1.2 + 4.3.1 + + + 2.5.4 + 3.3 + 2.4 + 2.4.3 + 0.8.1 + + + grammars + cli + + + 1.8 + + UTF-8 + UTF-8 + + + + + + org.antlr + antlr4-runtime + 4.7.1 + + + + com.google.guava + guava + ${guava.version} + + + + + + de.monticore.lang.monticar + cnn-arch + ${CNNArch.version} + + + + de.monticore.lang.monticar + cnn-arch + ${CNNArch.version} + ${grammars.classifier} + provided + + + + de.monticore.lang.monticar + cnn-train + ${CNNTrain.version} + + + + de.monticore.lang.monticar + cnn-train + ${CNNTrain.version} + ${grammars.classifier} + provided + + + + de.monticore.lang.monticar + embedded-montiarc-math-opt-generator + ${embedded-montiarc-math-opt-generator} + + + + + + junit + junit + ${junit.version} + test + + + + com.github.stefanbirkner + system-rules + 1.3.0 + + + + ch.qos.logback + logback-classic + ${logback.version} + + + + org.jscience + jscience + ${jscience.version} + + + + + + + + + + + maven-deploy-plugin + 2.8.1 + + + + org.jacoco + jacoco-maven-plugin + ${jacoco.plugin} + + + pre-unit-test + + prepare-agent + + + + post-unit-test + test + + report + + + + + + + + + maven-compiler-plugin + ${compiler.plugin} + + true + ${java.version} + ${java.version} + + + + + maven-assembly-plugin + 3.1.0 + + + jar-with-dependencies + package + + single + + + + + de.monticore.lang.monticar.cnnarch.mxnetgenerator.CNNArch2MxNetCli + + + + jar-with-dependencies + + + + + + + + + org.apache.maven.plugins + maven-source-plugin + ${source.plugin} + + + create source jar + package + + jar-no-fork + + + false + + **/*.java + **/*.ftl + + + + + + + org.apache.maven.plugins + maven-surefire-plugin + 2.19.1 + + false + + + + org.eluder.coveralls + coveralls-maven-plugin + 4.3.0 + + + + + org.codehaus.mojo + cobertura-maven-plugin + 2.7 + + xml + 256m + + true + + + + + + + + se-nexus + https://nexus.se.rwth-aachen.de/content/repositories/embeddedmontiarc-releases/ + + + se-nexus + https://nexus.se.rwth-aachen.de/content/repositories/embeddedmontiarc-snapshots/ + + + + diff --git a/settings.xml b/settings.xml new file mode 100644 index 0000000000000000000000000000000000000000..61887dffdaabd02ba790fb8620b8fcb11fc81c87 --- /dev/null +++ b/settings.xml @@ -0,0 +1,103 @@ + + + + + + + org.mortbay.jetty + de.topobyte + + + + + + + + se-nexus + cibuild + ${env.cibuild} + + + + github + travisbuilduser + ${env.travisbuilduserpassword} + + + + + + se-nexus + external:* + https://nexus.se.rwth-aachen.de/content/groups/public + + + + + + sonar + + true + + + + + https://metric.se.rwth-aachen.de + + + jenkins + + + ${env.sonar} + + + + + + se-nexus + + + central + http://central + + + + + + + central + http://central + + + + + + + + + se-nexus + + diff --git a/src/license/se/license.txt b/src/license/se/license.txt new file mode 100644 index 0000000000000000000000000000000000000000..27ba0853ebb009d0220b03526fe684ebf1474eb3 --- /dev/null +++ b/src/license/se/license.txt @@ -0,0 +1,18 @@ + + ****************************************************************************** + MontiCAR Modeling Family, www.se-rwth.de + Copyright (c) 2017, Software Engineering Group at RWTH Aachen, + All rights reserved. + + This project is free software; you can redistribute it and/or + modify it under the terms of the GNU Lesser General Public + License as published by the Free Software Foundation; either + version 3.0 of the License, or (at your option) any later version. + This library is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public + License along with this project. If not, see . +******************************************************************************* diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/generator/ArchitectureElementData.java b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/ArchitectureElementData.java new file mode 100644 index 0000000000000000000000000000000000000000..a1e2fc2a33e3ff73e9d1b5e769a48905d7d9cec5 --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/ArchitectureElementData.java @@ -0,0 +1,203 @@ +/** + * + * ****************************************************************************** + * MontiCAR Modeling Family, www.se-rwth.de + * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, + * All rights reserved. + * + * This project is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 3.0 of the License, or (at your option) any later version. + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this project. If not, see . + * ******************************************************************************* + */ +package de.monticore.lang.monticar.cnnarch.generator; + +import de.monticore.lang.monticar.cnnarch._symboltable.ArchTypeSymbol; +import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol; +import de.monticore.lang.monticar.cnnarch._symboltable.ConstantSymbol; +import de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol; +import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers; +import de.se_rwth.commons.logging.Log; + +import javax.annotation.Nullable; +import java.util.Arrays; +import java.util.List; + +public class ArchitectureElementData { + + private String name; + private ArchitectureElementSymbol element; + private CNNArchTemplateController templateController; + + public ArchitectureElementData(String name, ArchitectureElementSymbol element, CNNArchTemplateController templateController) { + this.name = name; + this.element = element; + this.templateController = templateController; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public ArchitectureElementSymbol getElement() { + return element; + } + + public void setElement(ArchitectureElementSymbol element) { + this.element = element; + } + + public CNNArchTemplateController getTemplateController() { + return templateController; + } + + public void setTemplateController(CNNArchTemplateController templateController) { + this.templateController = templateController; + } + + public List getInputs(){ + return getTemplateController().getLayerInputs(getElement()); + } + + public boolean isLogisticRegressionOutput(){ + return getTemplateController().isLogisticRegressionOutput(getElement()); + } + + + public boolean isLinearRegressionOutput(){ + boolean result = getTemplateController().isLinearRegressionOutput(getElement()); + if (result){ + Log.warn("The Output '" + getElement().getName() + "' is a linear regression output (squared loss) during training" + + " because the previous architecture element is not a softmax (cross-entropy loss) or sigmoid (logistic regression loss) activation. " + + "Other loss functions are currently not supported. " + , getElement().getSourcePosition()); + } + return result; + } + + public boolean isSoftmaxOutput(){ + return getTemplateController().isSoftmaxOutput(getElement()); + } + + public int getConstValue() { + ConstantSymbol constant = (ConstantSymbol) getElement(); + return constant.getExpression().getIntValue().get(); + } + + public List getKernel(){ + return ((LayerSymbol) getElement()) + .getIntTupleValue(AllPredefinedLayers.KERNEL_NAME).get(); + } + + public int getChannels(){ + return ((LayerSymbol) getElement()) + .getIntValue(AllPredefinedLayers.CHANNELS_NAME).get(); + } + + public List getStride(){ + return ((LayerSymbol) getElement()) + .getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get(); + } + + public int getUnits(){ + return ((LayerSymbol) getElement()) + .getIntValue(AllPredefinedLayers.UNITS_NAME).get(); + } + + public boolean getNoBias(){ + return ((LayerSymbol) getElement()) + .getBooleanValue(AllPredefinedLayers.NOBIAS_NAME).get(); + } + + public double getP(){ + return ((LayerSymbol) getElement()) + .getDoubleValue(AllPredefinedLayers.P_NAME).get(); + } + + public int getIndex(){ + return ((LayerSymbol) getElement()) + .getIntValue(AllPredefinedLayers.INDEX_NAME).get(); + } + + public int getNumOutputs(){ + return ((LayerSymbol) getElement()) + .getIntValue(AllPredefinedLayers.NUM_SPLITS_NAME).get(); + } + + public boolean getFixGamma(){ + return ((LayerSymbol) getElement()) + .getBooleanValue(AllPredefinedLayers.FIX_GAMMA_NAME).get(); + } + + public int getNsize(){ + return ((LayerSymbol) getElement()) + .getIntValue(AllPredefinedLayers.NSIZE_NAME).get(); + } + + public double getKnorm(){ + return ((LayerSymbol) getElement()) + .getDoubleValue(AllPredefinedLayers.KNORM_NAME).get(); + } + + public double getAlpha(){ + return ((LayerSymbol) getElement()) + .getDoubleValue(AllPredefinedLayers.ALPHA_NAME).get(); + } + + public double getBeta(){ + return ((LayerSymbol) getElement()) + .getDoubleValue(AllPredefinedLayers.BETA_NAME).get(); + } + + public int getSize(){ + return ((LayerSymbol) getElement()) + .getIntValue(AllPredefinedLayers.SIZE_NAME).get(); + } + + @Nullable + public String getPoolType(){ + return ((LayerSymbol) getElement()) + .getStringValue(AllPredefinedLayers.POOL_TYPE_NAME).get(); + } + + @Nullable + public List getPadding(){ + return getPadding((LayerSymbol) getElement()); + } + + @Nullable + public List getPadding(LayerSymbol layer){ + List kernel = layer.getIntTupleValue(AllPredefinedLayers.KERNEL_NAME).get(); + List stride = layer.getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get(); + ArchTypeSymbol inputType = layer.getInputTypes().get(0); + ArchTypeSymbol outputType = layer.getOutputTypes().get(0); + + int heightWithPad = kernel.get(0) + stride.get(0)*(outputType.getHeight() - 1); + int widthWithPad = kernel.get(1) + stride.get(1)*(outputType.getWidth() - 1); + int heightPad = Math.max(0, heightWithPad - inputType.getHeight()); + int widthPad = Math.max(0, widthWithPad - inputType.getWidth()); + + int topPad = (int)Math.ceil(heightPad / 2.0); + int bottomPad = (int)Math.floor(heightPad / 2.0); + int leftPad = (int)Math.ceil(widthPad / 2.0); + int rightPad = (int)Math.floor(widthPad / 2.0); + + if (topPad == 0 && bottomPad == 0 && leftPad == 0 && rightPad == 0){ + return null; + } + + return Arrays.asList(0,0,0,0,topPad,bottomPad,leftPad,rightPad); + } +} diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/generator/ArchitectureSupportChecker.java b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/ArchitectureSupportChecker.java new file mode 100644 index 0000000000000000000000000000000000000000..17c666b617163be3f5e5b4c28e8e467395e548a8 --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/ArchitectureSupportChecker.java @@ -0,0 +1,104 @@ +package de.monticore.lang.monticar.cnnarch.generator; + +import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol; +import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol; +import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol; +import de.monticore.lang.monticar.cnnarch._symboltable.ConstantSymbol; +import de.monticore.lang.monticar.cnnarch._symboltable.SerialCompositeElementSymbol; +import de.se_rwth.commons.logging.Log; + +import java.util.List; + +public abstract class ArchitectureSupportChecker { + + // Overload functions returning always true to enable the features + protected boolean checkMultipleStreams(ArchitectureSymbol architecture) { + if (architecture.getStreams().size() != 1) { + Log.error("This cnn architecture has multiple instructions, " + + "which is currently not supported by the code generator. " + , architecture.getSourcePosition()); + + return false; + } + + return true; + } + + protected boolean checkMultipleInputs(ArchitectureSymbol architecture) { + if (architecture.getInputs().size() > 1) { + Log.error("This cnn architecture has multiple inputs, " + + "which is currently not supported by the code generator. " + , architecture.getSourcePosition()); + + return false; + } + + return true; + } + + protected boolean checkMultipleOutputs(ArchitectureSymbol architecture) { + if (architecture.getOutputs().size() > 1) { + Log.error("This cnn architecture has multiple outputs, " + + "which is currently not supported by the code generator. " + , architecture.getSourcePosition()); + + return false; + } + + return true; + } + + protected boolean checkMultiDimensionalOutput(ArchitectureSymbol architecture) { + if (architecture.getOutputs().get(0).getDefinition().getType().getWidth() != 1 || + architecture.getOutputs().get(0).getDefinition().getType().getHeight() != 1) { + Log.error("This cnn architecture has a multi-dimensional output, " + + "which is currently not supported by the code generator." + , architecture.getSourcePosition()); + + return false; + } + + return true; + } + + protected boolean hasConstant(ArchitectureElementSymbol element) { + ArchitectureElementSymbol resolvedElement = element.getResolvedThis().get(); + + if (resolvedElement instanceof CompositeElementSymbol) { + List constructedElements = ((CompositeElementSymbol) resolvedElement).getElements(); + + for (ArchitectureElementSymbol constructedElement : constructedElements) { + if (hasConstant(constructedElement)) { + return true; + } + } + } + else if (resolvedElement instanceof ConstantSymbol) { + return true; + } + + return false; + } + + protected boolean checkConstants(ArchitectureSymbol architecture) { + for (SerialCompositeElementSymbol stream : architecture.getStreams()) { + for (ArchitectureElementSymbol element : stream.getElements()) { + if (hasConstant(element)) { + Log.error("This cnn architecture has a constant, which is currently not supported by the code generator." + , architecture.getSourcePosition()); + return false; + } + } + } + + return true; + } + + public boolean check(ArchitectureSymbol architecture) { + return checkMultipleStreams(architecture) + && checkMultipleInputs(architecture) + && checkMultipleOutputs(architecture) + && checkMultiDimensionalOutput(architecture) + && checkConstants(architecture); + } +} diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/generator/CNNArchGenerator.java b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/CNNArchGenerator.java new file mode 100644 index 0000000000000000000000000000000000000000..04e97d6cbb9eb76e201079abb04888127c67ee29 --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/CNNArchGenerator.java @@ -0,0 +1,143 @@ +/** + * + * ****************************************************************************** + * MontiCAR Modeling Family, www.se-rwth.de + * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, + * All rights reserved. + * + * This project is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 3.0 of the License, or (at your option) any later version. + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this project. If not, see . + * ******************************************************************************* + */ +package de.monticore.lang.monticar.cnnarch.generator; + +import de.monticore.io.paths.ModelPath; +import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol; +import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchLanguage; +import de.monticore.lang.monticar.generator.cmake.CMakeConfig; +import de.monticore.lang.monticar.generator.cmake.CMakeFindModule; +import de.monticore.lang.monticar.generator.FileContent; +import de.monticore.lang.monticar.generator.cpp.GeneratorCPP; +import de.monticore.symboltable.GlobalScope; +import de.monticore.symboltable.Scope; +import de.se_rwth.commons.logging.Log; + +import java.io.IOException; +import java.nio.file.Path; +import java.util.HashMap; +import java.util.Map; + +public abstract class CNNArchGenerator { + + protected ArchitectureSupportChecker architectureSupportChecker; + protected LayerSupportChecker layerSupportChecker; + + private String generationTargetPath; + private String modelsDirPath; + + protected CNNArchGenerator() { + setGenerationTargetPath("./target/generated-sources-cnnarch/"); + } + + public static void quitGeneration(){ + Log.error("Code generation is aborted"); + System.exit(1); + } + + public boolean isCMakeRequired() { + return true; + } + + public String getGenerationTargetPath(){ + if (generationTargetPath.charAt(generationTargetPath.length() - 1) != '/') { + this.generationTargetPath = generationTargetPath + "/"; + } + return generationTargetPath; + } + + public void setGenerationTargetPath(String generationTargetPath){ + this.generationTargetPath = generationTargetPath; + } + + protected String getModelsDirPath() { + return this.modelsDirPath; + } + + public void generate(Path modelsDirPath, String rootModelName){ + this.modelsDirPath = modelsDirPath.toString(); + final ModelPath mp = new ModelPath(modelsDirPath); + GlobalScope scope = new GlobalScope(mp, new CNNArchLanguage()); + generate(scope, rootModelName); + } + + // TODO: Rewrite so that CNNArchSymbolCompiler is used in EMADL2CPP instead of this method + public boolean check(ArchitectureSymbol architecture) { + return architectureSupportChecker.check(architecture) && layerSupportChecker.check(architecture); + } + + public void generate(Scope scope, String rootModelName){ + CNNArchSymbolCompiler symbolCompiler = new CNNArchSymbolCompiler(architectureSupportChecker, layerSupportChecker); + ArchitectureSymbol architectureSymbol = symbolCompiler.compileArchitectureSymbol(scope, rootModelName); + + try{ + String confPath = getModelsDirPath() + "/data_paths.txt"; + DataPathConfigParser newParserConfig = new DataPathConfigParser(confPath); + String dataPath = newParserConfig.getDataPath(rootModelName); + architectureSymbol.setDataPath(dataPath); + architectureSymbol.setComponentName(rootModelName); + generateFiles(architectureSymbol); + } catch (IOException e){ + Log.error(e.toString()); + } + } + + //check cocos with CNNArchCocos.checkAll(architecture) before calling this method. + public abstract Map generateStrings(ArchitectureSymbol architecture); + + //check cocos with CNNArchCocos.checkAll(architecture) before calling this method. + public void generateFiles(ArchitectureSymbol architecture) throws IOException{ + Map fileContentMap = generateStrings(architecture); + generateFromFilecontentsMap(fileContentMap); + } + + public void generateFromFilecontentsMap(Map fileContentMap) throws IOException { + GeneratorCPP genCPP = new GeneratorCPP(); + genCPP.setGenerationTargetPath(getGenerationTargetPath()); + for (String fileName : fileContentMap.keySet()){ + genCPP.generateFile(new FileContent(fileContentMap.get(fileName), fileName)); + } + } + + public void generateCMake(String rootModelName){ + Map fileContentMap = generateCMakeContent(rootModelName); + try { + generateFromFilecontentsMap(fileContentMap); + } catch (IOException e) { + Log.error("CMake file could not be generated" + e.getMessage()); + } + } + + public Map generateCMakeContent(String rootModelName) { + // model name should start with a lower case letter. If it is a component, replace dot . by _ + rootModelName = rootModelName.replace('.', '_').replace('[', '_').replace(']', '_'); + rootModelName = rootModelName.substring(0, 1).toLowerCase() + rootModelName.substring(1); + + CMakeConfig cMakeConfig = new CMakeConfig(rootModelName); + cMakeConfig.addModuleDependency(new CMakeFindModule("Armadillo", true)); + cMakeConfig.addCMakeCommand("set(LIBS ${LIBS} mxnet)"); + + Map fileContentMap = new HashMap<>(); + for (FileContent fileContent : cMakeConfig.generateCMakeFiles()){ + fileContentMap.put(fileContent.getFileName(), fileContent.getFileContent()); + } + return fileContentMap; + }} diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/generator/CNNArchSymbolCompiler.java b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/CNNArchSymbolCompiler.java new file mode 100644 index 0000000000000000000000000000000000000000..46fddd137a5aec44704769b9f2a8a50c20be9943 --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/CNNArchSymbolCompiler.java @@ -0,0 +1,52 @@ +package de.monticore.lang.monticar.cnnarch.generator; + +import de.monticore.io.paths.ModelPath; +import de.monticore.lang.monticar.cnnarch._cocos.CNNArchCocos; +import de.monticore.lang.monticar.cnnarch._symboltable.*; +import de.monticore.symboltable.GlobalScope; +import de.monticore.symboltable.Scope; +import de.se_rwth.commons.logging.Log; + +import java.nio.file.Path; +import java.util.List; +import java.util.Optional; + +public class CNNArchSymbolCompiler { + private final ArchitectureSupportChecker architectureSupportChecker; + private final LayerSupportChecker layerSupportChecker; + + public CNNArchSymbolCompiler(final ArchitectureSupportChecker architectureSupportChecker, + final LayerSupportChecker layerSupportChecker) { + this.architectureSupportChecker = architectureSupportChecker; + this.layerSupportChecker = layerSupportChecker; + } + + public ArchitectureSymbol compileArchitectureSymbolFromModelsDir( + final Path modelsDirPath, final String rootModel) { + ModelPath mp = new ModelPath(modelsDirPath); + GlobalScope scope = new GlobalScope(mp, new CNNArchLanguage()); + return compileArchitectureSymbol(scope, rootModel); + } + + public ArchitectureSymbol compileArchitectureSymbol(Scope scope, String rootModelName) { + Optional compilationUnit = scope.resolve(rootModelName, CNNArchCompilationUnitSymbol.KIND); + if (!compilationUnit.isPresent()){ + failWithMessage("Could not resolve architecture " + rootModelName); + } + + CNNArchCocos.checkAll(compilationUnit.get()); + + ArchitectureSymbol architecture = compilationUnit.get().getArchitecture(); + + if (!architectureSupportChecker.check(architecture) || !layerSupportChecker.check(architecture)) { + failWithMessage("Architecture not supported by generator"); + } + + return architecture; + } + + private void failWithMessage(final String message) { + Log.error(message); + System.exit(1); + } +} diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/generator/CNNArchTemplateController.java b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/CNNArchTemplateController.java new file mode 100644 index 0000000000000000000000000000000000000000..ffbfc5be5d61134ecd2d88ef66008820fde13e12 --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/CNNArchTemplateController.java @@ -0,0 +1,247 @@ +/** + * + * ****************************************************************************** + * MontiCAR Modeling Family, www.se-rwth.de + * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, + * All rights reserved. + * + * This project is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 3.0 of the License, or (at your option) any later version. + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this project. If not, see . + * ******************************************************************************* + */ +package de.monticore.lang.monticar.cnnarch.generator; + +import de.monticore.lang.monticar.cnnarch._symboltable.*; +import de.monticore.lang.monticar.cnnarch.predefined.Sigmoid; +import de.monticore.lang.monticar.cnnarch.predefined.Softmax; + +import java.io.StringWriter; +import java.io.Writer; +import java.util.*; + +public abstract class CNNArchTemplateController { + + public static final String FTL_FILE_ENDING = ".ftl"; + public static final String TEMPLATE_ELEMENTS_DIR_PATH = "elements/"; + public static final String TEMPLATE_CONTROLLER_KEY = "tc"; + public static final String ELEMENT_DATA_KEY = "element"; + + private final TemplateConfiguration templateConfiguration; + + private LayerNameCreator nameManager; + private ArchitectureSymbol architecture; + + //temporary attributes. They are set after calling process() + private Writer writer; + private String mainTemplateNameWithoutEnding; + private Target targetLanguage; + private ArchitectureElementData dataElement; + + protected CNNArchTemplateController(ArchitectureSymbol architecture, TemplateConfiguration templateConfiguration) { + setArchitecture(architecture); + this.templateConfiguration = templateConfiguration; + } + + protected TemplateConfiguration getTemplateConfiguration() { + return templateConfiguration; + } + + protected LayerNameCreator getNameManager() { + return nameManager; + } + + protected void setNameManager(LayerNameCreator nameManager) { + this.nameManager = nameManager; + } + + protected Writer getWriter() { + return writer; + } + + protected void setWriter(Writer writer) { + this.writer = writer; + } + + protected String getMainTemplateNameWithoutEnding() { + return mainTemplateNameWithoutEnding; + } + + protected void setMainTemplateNameWithoutEnding(String mainTemplateNameWithoutEnding) { + this.mainTemplateNameWithoutEnding = mainTemplateNameWithoutEnding; + } + + protected Target getTargetLanguage() { + return targetLanguage; + } + + protected void setTargetLanguage(Target targetLanguage) { + this.targetLanguage = targetLanguage; + } + + protected ArchitectureElementData getDataElement() { + return dataElement; + } + + protected void setDataElement(ArchitectureElementData dataElement) { + this.dataElement = dataElement; + } + + public String getFileNameWithoutEnding() { + return mainTemplateNameWithoutEnding + "_" + getFullArchitectureName(); + } + + public ArchitectureElementData getCurrentElement() { + return dataElement; + } + + public void setCurrentElement(ArchitectureElementSymbol layer) { + this.dataElement = new ArchitectureElementData(getName(layer), layer, this); + } + + public void setCurrentElement(ArchitectureElementData dataElement) { + this.dataElement = dataElement; + } + + public ArchitectureSymbol getArchitecture() { + return architecture; + } + + public void setArchitecture(ArchitectureSymbol architecture) { + this.architecture = architecture; + this.nameManager = new LayerNameCreator(architecture); + } + + public String getName(ArchitectureElementSymbol layer){ + return nameManager.getName(layer); + } + + public String getArchitectureName(){ + return getArchitecture().getEnclosingScope().getSpanningSymbol().get().getName().replaceAll("\\.","_"); + } + + public String getFullArchitectureName(){ + return getArchitecture().getEnclosingScope().getSpanningSymbol().get().getFullName().replaceAll("\\.","_"); + } + + public String getDataPath(){ + return getArchitecture().getDataPath(); + } + + public List getLayerInputs(ArchitectureElementSymbol layer){ + List inputNames = new ArrayList<>(); + + if (isSoftmaxOutput(layer) || isLogisticRegressionOutput(layer)){ + inputNames = getLayerInputs(layer.getInputElement().get()); + } else { + for (ArchitectureElementSymbol input : layer.getPrevious()) { + if (input.getOutputTypes().size() == 1) { + inputNames.add(getName(input)); + } else { + for (int i = 0; i < input.getOutputTypes().size(); i++) { + inputNames.add(getName(input) + "[" + i + "]"); + } + } + } + } + return inputNames; + + } + + public List getArchitectureInputs(){ + List list = new ArrayList<>(); + for (IOSymbol ioElement : getArchitecture().getInputs()){ + list.add(nameManager.getName(ioElement)); + } + return list; + } + + public List getArchitectureOutputs(){ + List list = new ArrayList<>(); + for (IOSymbol ioElement : getArchitecture().getOutputs()){ + list.add(nameManager.getName(ioElement)); + } + return list; + } + + public String getComponentName(){ + return getArchitecture().getComponentName(); + } + + public void include(String relativePath, String templateWithoutFileEnding, Writer writer){ + String templatePath = relativePath + templateWithoutFileEnding + FTL_FILE_ENDING; + Map ftlContext = new HashMap<>(); + ftlContext.put(TEMPLATE_CONTROLLER_KEY, this); + ftlContext.put(ELEMENT_DATA_KEY, getCurrentElement()); + templateConfiguration.processTemplate(ftlContext, templatePath, writer); + } + + public Map.Entry process(String templateNameWithoutEnding, Target targetLanguage){ + StringWriter newWriter = new StringWriter(); + this.mainTemplateNameWithoutEnding = templateNameWithoutEnding; + this.targetLanguage = targetLanguage; + this.writer = newWriter; + + include("", templateNameWithoutEnding, newWriter); + String fileEnding = targetLanguage.toString(); + String fileName = getFileNameWithoutEnding() + fileEnding; + Map.Entry fileContent = new AbstractMap.SimpleEntry<>(fileName, newWriter.toString()); + + this.mainTemplateNameWithoutEnding = null; + this.targetLanguage = null; + this.writer = null; + return fileContent; + } + + public String join(Iterable iterable, String separator){ + return join(iterable, separator, "", ""); + } + + public String join(Iterable iterable, String separator, String elementPrefix, String elementPostfix){ + StringBuilder stringBuilder = new StringBuilder(); + boolean isFirst = true; + for (Object element : iterable){ + if (!isFirst){ + stringBuilder.append(separator); + } + stringBuilder.append(elementPrefix); + stringBuilder.append(element.toString()); + stringBuilder.append(elementPostfix); + isFirst = false; + } + return stringBuilder.toString(); + } + + + public boolean isLogisticRegressionOutput(ArchitectureElementSymbol architectureElement){ + return isTOutput(Sigmoid.class, architectureElement); + } + + public boolean isLinearRegressionOutput(ArchitectureElementSymbol architectureElement){ + return architectureElement.isOutput() + && !isLogisticRegressionOutput(architectureElement) + && !isSoftmaxOutput(architectureElement); + } + + public boolean isSoftmaxOutput(ArchitectureElementSymbol architectureElement){ + return isTOutput(Softmax.class, architectureElement); + } + + private boolean isTOutput(Class inputPredefinedLayerClass, ArchitectureElementSymbol architectureElement){ + if (architectureElement.isOutput() + && architectureElement.getInputElement().isPresent() + && architectureElement.getInputElement().get() instanceof LayerSymbol){ + LayerSymbol inputLayer = (LayerSymbol) architectureElement.getInputElement().get(); + return inputPredefinedLayerClass.isInstance(inputLayer.getDeclaration()); + } + return false; + } +} diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/generator/CNNTrainGenerator.java b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/CNNTrainGenerator.java new file mode 100644 index 0000000000000000000000000000000000000000..63a6321ac862cddbf66381df1c203c980f81e312 --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/CNNTrainGenerator.java @@ -0,0 +1,115 @@ +package de.monticore.lang.monticar.cnnarch.generator; + +import de.monticore.io.paths.ModelPath; +import de.monticore.lang.monticar.cnnarch.generator.ConfigurationData; +import de.monticore.lang.monticar.cnnarch.generator.TemplateConfiguration; +import de.monticore.lang.monticar.cnntrain._ast.ASTCNNTrainNode; +import de.monticore.lang.monticar.cnntrain._ast.ASTOptimizerEntry; +import de.monticore.lang.monticar.cnntrain._cocos.CNNTrainCocos; +import de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainCompilationUnitSymbol; +import de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainLanguage; +import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol; +import de.monticore.lang.monticar.generator.FileContent; +import de.monticore.lang.monticar.generator.cpp.GeneratorCPP; +import de.monticore.symboltable.GlobalScope; +import de.se_rwth.commons.logging.Log; + +import java.io.IOException; +import java.nio.file.Path; +import java.util.*; + +public abstract class CNNTrainGenerator { + + protected TrainParamSupportChecker trainParamSupportChecker; + + private String generationTargetPath; + private String instanceName; + + protected CNNTrainGenerator() { + setGenerationTargetPath("./target/generated-sources-cnnarch/"); + } + + private void supportCheck(ConfigurationSymbol configuration){ + checkEntryParams(configuration); + checkOptimizerParams(configuration); + } + + private void checkEntryParams(ConfigurationSymbol configuration){ + Iterator it = configuration.getEntryMap().keySet().iterator(); + while (it.hasNext()) { + String key = it.next().toString(); + ASTCNNTrainNode astTrainEntryNode = (ASTCNNTrainNode) configuration.getEntryMap().get(key).getAstNode().get(); + astTrainEntryNode.accept(trainParamSupportChecker); + } + it = configuration.getEntryMap().keySet().iterator(); + while (it.hasNext()) { + String key = it.next().toString(); + if (trainParamSupportChecker.getUnsupportedElemList().contains(key)) { + it.remove(); + } + } + } + + private void checkOptimizerParams(ConfigurationSymbol configuration){ + if (configuration.getOptimizer() != null) { + ASTOptimizerEntry astOptimizer = (ASTOptimizerEntry) configuration.getOptimizer().getAstNode().get(); + astOptimizer.accept(trainParamSupportChecker); + if (trainParamSupportChecker.getUnsupportedElemList().contains(trainParamSupportChecker.unsupportedOptFlag)) { + configuration.setOptimizer(null); + }else { + Iterator it = configuration.getOptimizer().getOptimizerParamMap().keySet().iterator(); + while (it.hasNext()) { + String key = it.next().toString(); + if (trainParamSupportChecker.getUnsupportedElemList().contains(key)) { + it.remove(); + } + } + } + } + } + + private static void quitGeneration(){ + Log.error("Code generation is aborted"); + System.exit(1); + } + + public String getInstanceName() { + String parsedInstanceName = this.instanceName.replace('.', '_').replace('[', '_').replace(']', '_'); + parsedInstanceName = parsedInstanceName.substring(0, 1).toLowerCase() + parsedInstanceName.substring(1); + return parsedInstanceName; + } + + public void setInstanceName(String instanceName) { + this.instanceName = instanceName; + } + + public String getGenerationTargetPath() { + if (generationTargetPath.charAt(generationTargetPath.length() - 1) != '/') { + this.generationTargetPath = generationTargetPath + "/"; + } + return generationTargetPath; + } + + public void setGenerationTargetPath(String generationTargetPath) { + this.generationTargetPath = generationTargetPath; + } + + public ConfigurationSymbol getConfigurationSymbol(Path modelsDirPath, String rootModelName) { + final ModelPath mp = new ModelPath(modelsDirPath); + GlobalScope scope = new GlobalScope(mp, new CNNTrainLanguage()); + Optional compilationUnit = scope.resolve(rootModelName, CNNTrainCompilationUnitSymbol.KIND); + if (!compilationUnit.isPresent()) { + Log.error("could not resolve training configuration " + rootModelName); + quitGeneration(); + } + setInstanceName(compilationUnit.get().getFullName()); + CNNTrainCocos.checkAll(compilationUnit.get()); + supportCheck(compilationUnit.get().getConfiguration()); + return compilationUnit.get().getConfiguration(); + } + + public abstract void generate(Path modelsDirPath, String rootModelNames); + + //check cocos with CNNTrainCocos.checkAll(configuration) before calling this method. + public abstract Map generateStrings(ConfigurationSymbol configuration); +} \ No newline at end of file diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/generator/ConfigurationData.java b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/ConfigurationData.java new file mode 100644 index 0000000000000000000000000000000000000000..23f29db1b48424e1f85745d360efdd6615fe8235 --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/ConfigurationData.java @@ -0,0 +1,99 @@ +package de.monticore.lang.monticar.cnnarch.generator; + +import de.monticore.lang.monticar.cnntrain._symboltable.*; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class ConfigurationData { + + ConfigurationSymbol configuration; + String instanceName; + + public ConfigurationData(ConfigurationSymbol configuration, String instanceName) { + this.configuration = configuration; + this.instanceName = instanceName; + } + + public ConfigurationSymbol getConfiguration() { + return configuration; + } + + public String getInstanceName() { + return instanceName; + } + + public String getNumEpoch() { + if (!getConfiguration().getEntryMap().containsKey("num_epoch")) { + return null; + } + return String.valueOf(getConfiguration().getEntry("num_epoch").getValue()); + } + + public String getBatchSize() { + if (!getConfiguration().getEntryMap().containsKey("batch_size")) { + return null; + } + return String.valueOf(getConfiguration().getEntry("batch_size") .getValue()); + } + + public Boolean getLoadCheckpoint() { + if (!getConfiguration().getEntryMap().containsKey("load_checkpoint")) { + return null; + } + return (Boolean) getConfiguration().getEntry("load_checkpoint").getValue().getValue(); + } + + public Boolean getNormalize() { + if (!getConfiguration().getEntryMap().containsKey("normalize")) { + return null; + } + return (Boolean) getConfiguration().getEntry("normalize").getValue().getValue(); + } + + public String getContext() { + if (!getConfiguration().getEntryMap().containsKey("context")) { + return null; + } + return getConfiguration().getEntry("context").getValue().toString(); + } + + public String getEvalMetric() { + if (!getConfiguration().getEntryMap().containsKey("eval_metric")) { + return null; + } + return getConfiguration().getEntry("eval_metric").getValue().toString(); + } + + public String getOptimizerName() { + if (getConfiguration().getOptimizer() == null) { + return null; + } + return getConfiguration().getOptimizer().getName(); + } + + public Map getOptimizerParams() { + // get classes for single enum values + List lrPolicyClasses = new ArrayList<>(); + for (LRPolicy enum_value: LRPolicy.values()) { + lrPolicyClasses.add(enum_value.getClass()); + } + + Map mapToStrings = new HashMap<>(); + Map optimizerParams = getConfiguration().getOptimizer().getOptimizerParamMap(); + for (Map.Entry entry : optimizerParams.entrySet()) { + String paramName = entry.getKey(); + String valueAsString = entry.getValue().toString(); + Class realClass = entry.getValue().getValue().getValue().getClass(); + if (realClass == Boolean.class) { + valueAsString = (Boolean) entry.getValue().getValue().getValue() ? "True" : "False"; + } else if (lrPolicyClasses.contains(realClass)) { + valueAsString = "'" + valueAsString + "'"; + } + mapToStrings.put(paramName, valueAsString); + } + return mapToStrings; + } +} diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/generator/DataPathConfigParser.java b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/DataPathConfigParser.java new file mode 100644 index 0000000000000000000000000000000000000000..3a034b0ffd268b2f7c8c0c5cd4b68bd2ea6de466 --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/DataPathConfigParser.java @@ -0,0 +1,66 @@ +/** + * + * ****************************************************************************** + * MontiCAR Modeling Family, www.se-rwth.de + * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, + * All rights reserved. + * + * This project is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 3.0 of the License, or (at your option) any later version. + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this project. If not, see . + * ******************************************************************************* + */ +package de.monticore.lang.monticar.cnnarch.generator; + +import de.se_rwth.commons.logging.Log; + +import java.io.*; +import java.net.URL; +import java.util.Objects; +import java.util.Properties; + +public class DataPathConfigParser{ + + private String configTargetPath; + private String configFileName; + private Properties properties; + + public DataPathConfigParser(String configPath) { + setConfigPath(configPath); + properties = new Properties(); + try + { + properties.load(new FileInputStream(configTargetPath)); + } catch(IOException e) + { + Log.error("Config file " + configPath + " could not be found"); + } + } + + public String getConfigPath() { + if (configTargetPath.charAt(configTargetPath.length() - 1) != '/') { + this.configTargetPath = configTargetPath + "/"; + } + return configTargetPath; + } + + public void setConfigPath(String configTargetPath){ + this.configTargetPath = configTargetPath; + } + + public String getDataPath(String modelName) { + String path = properties.getProperty(modelName); + if(path == null) { + Log.error("Data path config file did not specify a path for component '" + modelName + "'"); + } + return path; + } +} diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/generator/GenericCNNArchCli.java b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/GenericCNNArchCli.java new file mode 100644 index 0000000000000000000000000000000000000000..a02e2a795781ce36bc91a42c3d3fbe2a01e26838 --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/GenericCNNArchCli.java @@ -0,0 +1,83 @@ +package de.monticore.lang.monticar.cnnarch.generator; + +import de.se_rwth.commons.logging.Log; +import org.apache.commons.cli.*; + +import java.nio.file.Path; +import java.nio.file.Paths; + +/** + * + */ +public class GenericCNNArchCli { + private final CNNArchGenerator cnnArchGenerator; + + public static final Option OPTION_MODELS_PATH = Option.builder("m") + .longOpt("models-dir") + .desc("full path to the directory with the CNNArch model") + .hasArg(true) + .required(true) + .build(); + + public static final Option OPTION_ROOT_MODEL = Option.builder("r") + .longOpt("root-model") + .desc("name of the architecture") + .hasArg(true) + .required(true) + .build(); + + public static final Option OPTION_OUTPUT_PATH = Option.builder("o") + .longOpt("output-dir") + .desc("full path to output directory for tests") + .hasArg(true) + .required(false) + .build(); + + public GenericCNNArchCli(CNNArchGenerator cnnArchGenerator) { + this.cnnArchGenerator = cnnArchGenerator; + } + + public void run(String[] args) { + Options options = getOptions(); + CommandLineParser parser = new DefaultParser(); + CommandLine cliArgs = parseArgs(options, parser, args); + if (cliArgs != null) { + runGenerator(cliArgs); + } + } + + private Options getOptions() { + Options options = new Options(); + options.addOption(OPTION_MODELS_PATH); + options.addOption(OPTION_ROOT_MODEL); + options.addOption(OPTION_OUTPUT_PATH); + return options; + } + + private CommandLine parseArgs(Options options, CommandLineParser parser, String[] args) { + CommandLine cliArgs; + try { + cliArgs = parser.parse(options, args); + } catch (ParseException e) { + Log.error("argument parsing exception: " + e.getMessage()); + quitGeneration(); + return null; + } + return cliArgs; + } + + private void quitGeneration(){ + Log.error("Code generation is aborted"); + System.exit(1); + } + + private void runGenerator(CommandLine cliArgs) { + Path modelsDirPath = Paths.get(cliArgs.getOptionValue(OPTION_MODELS_PATH.getOpt())); + String rootModelName = cliArgs.getOptionValue(OPTION_ROOT_MODEL.getOpt()); + String outputPath = cliArgs.getOptionValue(OPTION_OUTPUT_PATH.getOpt()); + if (outputPath != null){ + cnnArchGenerator.setGenerationTargetPath(outputPath); + } + cnnArchGenerator.generate(modelsDirPath, rootModelName); + } +} \ No newline at end of file diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/generator/LayerNameCreator.java b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/LayerNameCreator.java new file mode 100644 index 0000000000000000000000000000000000000000..63eeb851b45784bf2dcf5ac414d320b8c29d9f44 --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/LayerNameCreator.java @@ -0,0 +1,151 @@ +/** + * + * ****************************************************************************** + * MontiCAR Modeling Family, www.se-rwth.de + * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, + * All rights reserved. + * + * This project is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 3.0 of the License, or (at your option) any later version. + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this project. If not, see . + * ******************************************************************************* + */ +package de.monticore.lang.monticar.cnnarch.generator; + +import de.monticore.lang.monticar.cnnarch._symboltable.*; +import de.monticore.lang.monticar.cnnarch.predefined.Convolution; +import de.monticore.lang.monticar.cnnarch.predefined.FullyConnected; +import de.monticore.lang.monticar.cnnarch.predefined.Pooling; + +import java.util.*; + +public class LayerNameCreator { + + private Map elementToName = new HashMap<>(); + private Map nameToElement = new HashMap<>(); + + public LayerNameCreator(ArchitectureSymbol architecture) { + int stage = 1; + for (SerialCompositeElementSymbol stream : architecture.getStreams()) { + stage = name(stream, stage, new ArrayList<>()); + } + } + + public ArchitectureElementSymbol getArchitectureElement(String name){ + return nameToElement.get(name); + } + + public String getName(ArchitectureElementSymbol architectureElement){ + return elementToName.get(architectureElement); + } + + protected int name(ArchitectureElementSymbol architectureElement, int stage, List streamIndices){ + if (architectureElement instanceof SerialCompositeElementSymbol) { + return nameSerialComposite((SerialCompositeElementSymbol) architectureElement, stage, streamIndices); + } else if (architectureElement instanceof ParallelCompositeElementSymbol){ + return nameParallelComposite((ParallelCompositeElementSymbol) architectureElement, stage, streamIndices); + } else{ + if (architectureElement.isAtomic()){ + if (architectureElement.getMaxSerialLength().get() > 0){ + return add(architectureElement, stage, streamIndices); + } else { + return stage; + } + } else { + ArchitectureElementSymbol resolvedElement = architectureElement.getResolvedThis().get(); + return name(resolvedElement, stage, streamIndices); + } + } + } + + protected int nameSerialComposite(SerialCompositeElementSymbol compositeElement, int stage, List streamIndices){ + int endStage = stage; + for (ArchitectureElementSymbol subElement : compositeElement.getElements()){ + endStage = name(subElement, endStage, streamIndices); + } + return endStage; + } + + protected int nameParallelComposite(ParallelCompositeElementSymbol compositeElement, int stage, List streamIndices){ + int startStage = stage + 1; + streamIndices.add(1); + int lastIndex = streamIndices.size() - 1; + + List endStages = new ArrayList<>(); + for (ArchitectureElementSymbol subElement : compositeElement.getElements()){ + endStages.add(name(subElement, startStage, streamIndices)); + streamIndices.set(lastIndex, streamIndices.get(lastIndex) + 1); + } + + streamIndices.remove(lastIndex); + return Collections.max(endStages) + 1; + } + + protected int add(ArchitectureElementSymbol architectureElement, int stage, List streamIndices){ + int endStage = stage; + if (!elementToName.containsKey(architectureElement)) { + String name = createName(architectureElement, endStage, streamIndices); + + while (nameToElement.containsKey(name)) { + endStage++; + name = createName(architectureElement, endStage, streamIndices); + } + + elementToName.put(architectureElement, name); + nameToElement.put(name, architectureElement); + } + return endStage; + } + + protected String createName(ArchitectureElementSymbol architectureElement, int stage, List streamIndices){ + if (architectureElement instanceof IOSymbol){ + String name = createBaseName(architectureElement); + IOSymbol ioElement = (IOSymbol) architectureElement; + if (ioElement.getArrayAccess().isPresent()){ + int arrayAccess = ioElement.getArrayAccess().get().getIntValue().get(); + name = name + "_" + arrayAccess + "_"; + } + return name; + } else { + return createBaseName(architectureElement) + stage + createStreamPostfix(streamIndices) + "_"; + } + } + + + protected String createBaseName(ArchitectureElementSymbol architectureElement){ + if (architectureElement instanceof LayerSymbol) { + LayerDeclarationSymbol layerDeclaration = ((LayerSymbol) architectureElement).getDeclaration(); + if (layerDeclaration instanceof Convolution) { + return "conv"; + } else if (layerDeclaration instanceof FullyConnected) { + return "fc"; + } else if (layerDeclaration instanceof Pooling) { + return "pool"; + } else { + return layerDeclaration.getName().toLowerCase(); + } + } else if (architectureElement instanceof CompositeElementSymbol){ + return "group"; + } else { + return architectureElement.getName(); + } + } + + protected String createStreamPostfix(List streamIndices){ + StringBuilder stringBuilder = new StringBuilder(); + for (int streamIndex : streamIndices){ + stringBuilder.append("_"); + stringBuilder.append(streamIndex); + } + return stringBuilder.toString(); + } +} + diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/generator/LayerSupportChecker.java b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/LayerSupportChecker.java new file mode 100644 index 0000000000000000000000000000000000000000..ace5288e3db2d0bf08cb2ec74e55ec28772c3b93 --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/LayerSupportChecker.java @@ -0,0 +1,70 @@ +package de.monticore.lang.monticar.cnnarch.generator; + +import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers; +import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol; +import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol; +import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol; +import de.monticore.lang.monticar.cnnarch._symboltable.ConstantSymbol; +import de.monticore.lang.monticar.cnnarch._symboltable.LayerDeclarationSymbol; +import de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol; +import de.se_rwth.commons.logging.Log; + +import java.util.ArrayList; +import java.util.List; + +public abstract class LayerSupportChecker { + + protected List supportedLayerList = new ArrayList<>(); + + private boolean isSupportedLayer(ArchitectureElementSymbol element){ + ArchitectureElementSymbol resolvedElement = element.getResolvedThis().get(); + List constructLayerElemList; + + if (resolvedElement instanceof CompositeElementSymbol) { + constructLayerElemList = ((CompositeElementSymbol) resolvedElement).getElements(); + for (ArchitectureElementSymbol constructedLayerElement : constructLayerElemList) { + if (!isSupportedLayer(constructedLayerElement)) { + return false; + } + } + return true; + } + + // Support all inputs and outputs + if (resolvedElement.isInput() || resolvedElement.isOutput()) { + return true; + } + + // Support for constants is checked in ArchitectureSupportChecker + if (resolvedElement instanceof ConstantSymbol) { + return true; + } + + // Support all layer declarations + if (resolvedElement instanceof LayerSymbol) { + if (!((LayerSymbol) resolvedElement).getDeclaration().isPredefined()) { + return true; + } + } + + if (!supportedLayerList.contains(element.toString())) { + Log.error("Unsupported layer " + "'" + element.getName() + "'" + " for the current backend."); + return false; + } else { + return true; + } + } + + public boolean check(ArchitectureSymbol architecture) { + for (CompositeElementSymbol stream : architecture.getStreams()) { + for (ArchitectureElementSymbol element : stream.getElements()) { + if (!isSupportedLayer(element)) { + return false; + } + } + } + + return true; + } + +} diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/generator/Target.java b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/Target.java new file mode 100644 index 0000000000000000000000000000000000000000..2504df3afbb9d4e73736774cfd68948b07d4d32e --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/Target.java @@ -0,0 +1,37 @@ +/** + * + * ****************************************************************************** + * MontiCAR Modeling Family, www.se-rwth.de + * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, + * All rights reserved. + * + * This project is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 3.0 of the License, or (at your option) any later version. + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this project. If not, see . + * ******************************************************************************* + */ +package de.monticore.lang.monticar.cnnarch.generator; + +//can be removed +public enum Target { + PYTHON{ + @Override + public String toString() { + return ".py"; + } + }, + CPP{ + @Override + public String toString() { + return ".h"; + } + } +} diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/generator/TemplateConfiguration.java b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/TemplateConfiguration.java new file mode 100644 index 0000000000000000000000000000000000000000..dfb8a964031f741b0e1209a02656e387b4fede87 --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/TemplateConfiguration.java @@ -0,0 +1,73 @@ +/** + * + * ****************************************************************************** + * MontiCAR Modeling Family, www.se-rwth.de + * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, + * All rights reserved. + * + * This project is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 3.0 of the License, or (at your option) any later version. + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this project. If not, see . + * ******************************************************************************* + */ +package de.monticore.lang.monticar.cnnarch.generator; + +import de.se_rwth.commons.logging.Log; +import freemarker.template.Configuration; +import freemarker.template.Template; +import freemarker.template.TemplateException; +import freemarker.template.TemplateExceptionHandler; + +import java.io.IOException; +import java.io.StringWriter; +import java.io.Writer; +import java.util.Map; + +public abstract class TemplateConfiguration { + abstract protected String getBaseTemplatePackagePath(); + abstract public Configuration getConfiguration(); + + public TemplateConfiguration() { + + } + + protected Configuration createConfiguration() { + Configuration configuration = new Configuration(Configuration.VERSION_2_3_23); + configuration.setClassForTemplateLoading(TemplateConfiguration.class, getBaseTemplatePackagePath()); + configuration.setDefaultEncoding("UTF-8"); + configuration.setTemplateExceptionHandler(TemplateExceptionHandler.RETHROW_HANDLER); + return configuration; + } + + private void quitGeneration(){ + Log.error("Code generation is aborted"); + System.exit(1); + } + + public void processTemplate(Map ftlContext, String templatePath, Writer writer){ + try{ + Template template = getConfiguration().getTemplate(templatePath); + template.process(ftlContext, writer); + } catch (IOException e) { + Log.error("Freemarker could not find template " + templatePath + " :\n" + e.getMessage()); + quitGeneration(); + } catch (TemplateException e){ + Log.error("An exception occured in template " + templatePath + " :\n" + e.getMessage()); + quitGeneration(); + } + } + + public String processTemplate(Map ftlContext, String templatePath){ + StringWriter writer = new StringWriter(); + processTemplate(ftlContext, templatePath, writer); + return writer.toString(); + } +} diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/generator/TrainParamSupportChecker.java b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/TrainParamSupportChecker.java new file mode 100644 index 0000000000000000000000000000000000000000..ad1e89d9997cc282e179b19234f29abfc7891bb4 --- /dev/null +++ b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/TrainParamSupportChecker.java @@ -0,0 +1,94 @@ +package de.monticore.lang.monticar.cnnarch.generator; + +import de.monticore.lang.monticar.cnntrain._ast.*; +import de.monticore.lang.monticar.cnntrain._visitor.CNNTrainVisitor; +import de.se_rwth.commons.logging.Log; +import java.util.ArrayList; +import java.util.List; + +public abstract class TrainParamSupportChecker implements CNNTrainVisitor { + + protected List unsupportedElemList = new ArrayList<>(); + + protected void printUnsupportedEntryParam(String nodeName){ + Log.warn("Unsupported training parameter " + "'" + nodeName + "'" + " for the current backend. It will be ignored."); + } + + protected void printUnsupportedOptimizer(String nodeName){ + Log.warn("Unsupported optimizer parameter " + "'" + nodeName + "'" + " for the current backend. It will be ignored."); + } + + protected void printUnsupportedOptimizerParam(String nodeName){ + Log.warn("Unsupported training optimizer parameter " + "'" + nodeName + "'" + " for the current backend. It will be ignored."); + } + + public TrainParamSupportChecker() { + } + + public static final String unsupportedOptFlag = "unsupported_optimizer"; + + public List getUnsupportedElemList(){ + return this.unsupportedElemList; + } + + //Empty visit method denotes that the corresponding training parameter is supported. + //To set a training parameter as unsupported, add the corresponding node to the unsupportedElemList + public void visit(ASTNumEpochEntry node){} + + public void visit(ASTBatchSizeEntry node){} + + public void visit(ASTLoadCheckpointEntry node){} + + public void visit(ASTNormalizeEntry node){} + + public void visit(ASTTrainContextEntry node){} + + public void visit(ASTEvalMetricEntry node){} + + public void visit(ASTSGDOptimizer node){} + + public void visit(ASTAdamOptimizer node){} + + public void visit(ASTRmsPropOptimizer node){} + + public void visit(ASTAdaGradOptimizer node){} + + public void visit(ASTNesterovOptimizer node){} + + public void visit(ASTAdaDeltaOptimizer node){} + + public void visit(ASTLearningRateEntry node){} + + public void visit(ASTMinimumLearningRateEntry node){} + + public void visit(ASTWeightDecayEntry node){} + + public void visit(ASTLRDecayEntry node){} + + public void visit(ASTLRPolicyEntry node){} + + public void visit(ASTRescaleGradEntry node){} + + public void visit(ASTClipGradEntry node){} + + public void visit(ASTStepSizeEntry node){} + + public void visit(ASTMomentumEntry node){} + + public void visit(ASTBeta1Entry node){} + + public void visit(ASTBeta2Entry node){} + + public void visit(ASTEpsilonEntry node){} + + public void visit(ASTGamma1Entry node){} + + public void visit(ASTGamma2Entry node){} + + public void visit(ASTCenteredEntry node){} + + public void visit(ASTClipWeightsEntry node){} + + public void visit(ASTRhoEntry node){} + +} diff --git a/src/test/java/de/monticore/lang/monticar/cnnarch/generator/AbstractSymtabTest.java b/src/test/java/de/monticore/lang/monticar/cnnarch/generator/AbstractSymtabTest.java new file mode 100644 index 0000000000000000000000000000000000000000..bc02c1ad757b0b6660401a88ac4b44fcb74206a7 --- /dev/null +++ b/src/test/java/de/monticore/lang/monticar/cnnarch/generator/AbstractSymtabTest.java @@ -0,0 +1,132 @@ +/** + * + * ****************************************************************************** + * MontiCAR Modeling Family, www.se-rwth.de + * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, + * All rights reserved. + * + * This project is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 3.0 of the License, or (at your option) any later version. + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this project. If not, see . + * ******************************************************************************* + */ +package de.monticore.lang.monticar.cnnarch.generator; + +import de.monticore.ModelingLanguageFamily; +import de.monticore.io.paths.ModelPath; +import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchCompilationUnitSymbol; +import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchLanguage; +import de.monticore.symboltable.GlobalScope; +import de.monticore.symboltable.Scope; +import org.junit.Assert; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.List; +import java.util.stream.Collectors; + +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +/** + * Common methods for symboltable tests + */ +public class AbstractSymtabTest { + + private static final String MODEL_PATH = "src/test/resources/"; + + protected static Scope createSymTab(String... modelPath) { + ModelingLanguageFamily fam = new ModelingLanguageFamily(); + + fam.addModelingLanguage(new CNNArchLanguage()); + + final ModelPath mp = new ModelPath(); + for (String m : modelPath) { + mp.addEntry(Paths.get(m)); + } + + return new GlobalScope(mp, fam); + } + + protected static CNNArchCompilationUnitSymbol getCompilationUnitSymbol(String modelPath, String model) { + Scope symTab = createSymTab(MODEL_PATH + modelPath); + CNNArchCompilationUnitSymbol comp = symTab. resolve( + model, CNNArchCompilationUnitSymbol.KIND).orElse(null); + assertNotNull("Could not resolve model " + model, comp); + + return comp; + } + + + + public static void checkFilesAreEqual(Path generationPath, Path resultsPath, List fileNames) { + for (String fileName : fileNames){ + File genFile = new File(generationPath.toString() + "/" + fileName); + File fileTarget = new File(resultsPath.toString() + "/" + fileName); + assertTrue(areBothFilesEqual(genFile, fileTarget)); + } + } + + public static boolean areBothFilesEqual(File file1, File file2) { + if (!file1.exists()) { + Assert.fail("file does not exist: " + file1.getAbsolutePath()); + return false; + } + if (!file2.exists()) { + Assert.fail("file does not exist: " + file2.getAbsolutePath()); + return false; + } + List lines1; + List lines2; + try { + lines1 = Files.readAllLines(file1.toPath()); + lines2 = Files.readAllLines(file2.toPath()); + } catch (IOException e) { + e.printStackTrace(); + Assert.fail("IO error: " + e.getMessage()); + return false; + } + lines1 = discardEmptyLines(lines1); + lines2 = discardEmptyLines(lines2); + if (lines1.size() != lines2.size()) { + Assert.fail( + "files have different number of lines: " + + file1.getAbsolutePath() + + " has " + lines1 + + " lines and " + file2.getAbsolutePath() + " has " + lines2 + " lines" + ); + return false; + } + int len = lines1.size(); + for (int i = 0; i < len; i++) { + String l1 = lines1.get(i); + String l2 = lines2.get(i); + Assert.assertEquals("files differ in " + i + " line: " + + file1.getAbsolutePath() + + " has " + l1 + + " and " + file2.getAbsolutePath() + " has " + l2, + l1, + l2 + ); + } + return true; + } + + private static List discardEmptyLines(List lines) { + return lines.stream() + .map(String::trim) + .filter(l -> !l.isEmpty()) + .collect(Collectors.toList()); + } +} diff --git a/src/test/java/de/monticore/lang/monticar/cnnarch/generator/DataPathConfigParserTest.java b/src/test/java/de/monticore/lang/monticar/cnnarch/generator/DataPathConfigParserTest.java new file mode 100644 index 0000000000000000000000000000000000000000..f3a4b133ececf1bcc193eaed4d1c95f2f4727935 --- /dev/null +++ b/src/test/java/de/monticore/lang/monticar/cnnarch/generator/DataPathConfigParserTest.java @@ -0,0 +1,66 @@ +/** + * + * ****************************************************************************** + * MontiCAR Modeling Family, www.se-rwth.de + * Copyright (c) 2017, Software Engineering Group at RWTH Aachen, + * All rights reserved. + * + * This project is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 3.0 of the License, or (at your option) any later version. + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this project. If not, see . + * ******************************************************************************* + */ +package de.monticore.lang.monticar.cnnarch.generator; + +import de.monticore.lang.monticar.cnnarch._symboltable.*; +import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedVariables; +import de.monticore.symboltable.Scope; +import de.se_rwth.commons.logging.Log; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +public class DataPathConfigParserTest extends AbstractSymtabTest { + + @Before + public void setUp() { + // ensure an empty log + Log.getFindings().clear(); + Log.enableFailQuick(false); + } + + @Test + public void testDataPathConfigParserValidComponent() { + DataPathConfigParser parser = new DataPathConfigParser("src/test/resources/architectures/data_paths.txt"); + + String data_path = parser.getDataPath("ComponentName"); + assertTrue("Wrong data path returned", data_path.equals("/path/to/training/data")); + } + + @Test + public void testDataPathConfigParserInvalidComponent() { + DataPathConfigParser parser = new DataPathConfigParser("src/test/resources/architectures/data_paths.txt"); + + String data_path = parser.getDataPath("NotExistingComponent"); + assertTrue("For not listed components, null should be returned", data_path == null); + assertTrue(Log.getFindings().size() == 1); + } + + @Test + public void testDataPathConfigParserInvalidPath() { + DataPathConfigParser parser = new DataPathConfigParser("invalid/path/data_paths.txt"); + + assertTrue(Log.getFindings().size() == 1); + } +} diff --git a/src/test/resources/architectures/data_paths.txt b/src/test/resources/architectures/data_paths.txt new file mode 100644 index 0000000000000000000000000000000000000000..a4c785fe203697c5410b48498cef45c4f788db60 --- /dev/null +++ b/src/test/resources/architectures/data_paths.txt @@ -0,0 +1 @@ +ComponentName /path/to/training/data \ No newline at end of file