diff --git a/.circleci/config.yml b/.circleci/config.yml
new file mode 100644
index 0000000000000000000000000000000000000000..9ec02adc5749985975f56d2151e96de4a98c4d4d
--- /dev/null
+++ b/.circleci/config.yml
@@ -0,0 +1,69 @@
+#
+#
+# ******************************************************************************
+# MontiCAR Modeling Family, www.se-rwth.de
+# Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
+# All rights reserved.
+#
+# This project is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 3.0 of the License, or (at your option) any later version.
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this project. If not, see .
+# *******************************************************************************
+#
+
+# Java Maven CircleCI 2.0 configuration file
+#
+# Check https://circleci.com/docs/2.0/language-java/ for more details
+#
+version: 2
+general:
+ branches:
+ ignore:
+ - gh-pages
+
+jobs:
+ build:
+ docker:
+ # specify the version you desire here
+ - image: circleci/openjdk:8-jdk
+
+ # Specify service dependencies here if necessary
+ # CircleCI maintains a library of pre-built images
+ # documented at https://circleci.com/docs/2.0/circleci-images/
+ # - image: circleci/postgres:9.4
+
+ working_directory: ~/repo
+
+ environment:
+ # Customize the JVM maximum heap limit
+ MAVEN_OPTS: -Xmx3200m
+
+ steps:
+ - checkout
+
+ # run tests!
+ - run: mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings "settings.xml"
+workflows:
+ version: 2
+ commit-workflow:
+ jobs:
+ - build
+ scheduled-workflow:
+ triggers:
+ - schedule:
+ cron: "30 1 * * *"
+ filters:
+ branches:
+ only: master
+
+ jobs:
+ - build
+
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..11780027328480bc4bac278fc1025f1fac376079
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,10 @@
+target
+nppBackup
+.project
+.settings
+.classpath
+.idea
+.git
+
+*.iml
+
diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
new file mode 100644
index 0000000000000000000000000000000000000000..a45cd7c87301d8847cb1237464de25df85737f9b
--- /dev/null
+++ b/.gitlab-ci.yml
@@ -0,0 +1,50 @@
+#
+#
+# ******************************************************************************
+# MontiCAR Modeling Family, www.se-rwth.de
+# Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
+# All rights reserved.
+#
+# This project is free software; you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public
+# License as published by the Free Software Foundation; either
+# version 3.0 of the License, or (at your option) any later version.
+# This library is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public
+# License along with this project. If not, see .
+# *******************************************************************************
+#
+
+stages:
+- windows
+- linux
+
+masterJobLinux:
+ stage: linux
+ image: maven:3-jdk-8
+ script:
+ - mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean deploy --settings settings.xml
+ - cat target/site/jacoco/index.html
+ - mvn package sonar:sonar -s settings.xml
+ only:
+ - master
+
+masterJobWindows:
+ stage: windows
+ script:
+ - mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml
+ tags:
+ - Windows10
+
+BranchJobLinux:
+ stage: linux
+ image: maven:3-jdk-8
+ script:
+ - mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml
+ - cat target/site/jacoco/index.html
+ except:
+ - master
diff --git a/.travis.yml b/.travis.yml
new file mode 100644
index 0000000000000000000000000000000000000000..5b54b0cf104b040c3fa25918a29d2da17626f6e0
--- /dev/null
+++ b/.travis.yml
@@ -0,0 +1,5 @@
+script:
+- git checkout ${TRAVIS_BRANCH}
+- mvn clean install cobertura:cobertura org.eluder.coveralls:coveralls-maven-plugin:report --settings "settings.xml"
+after_success:
+- if [ "${TRAVIS_BRANCH}" == "master" ]; then mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B deploy --debug --settings "./settings.xml"; fi
diff --git a/README.md b/README.md
index fdedc33a4caa509cbcd70f7f4ab1f390ee3d6d57..645554bb479a1d17718eb59978e7cfd737a842ec 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,2 @@
-# CNNArch2X
-
+
+
diff --git a/pom.xml b/pom.xml
new file mode 100644
index 0000000000000000000000000000000000000000..75e7a7fbbe01dc2f371c911e9917ff27ae4d6b2e
--- /dev/null
+++ b/pom.xml
@@ -0,0 +1,254 @@
+
+ 4.0.0
+
+
+
+
+ de.monticore.lang.monticar
+ cnnarch-generator
+ 0.0.1-SNAPSHOT
+
+
+
+
+
+
+ 0.3.1-SNAPSHOT
+ 0.3.2-SNAPSHOT
+ 0.1.4
+
+
+ 18.0
+ 4.12
+ 1.1.2
+ 4.3.1
+
+
+ 2.5.4
+ 3.3
+ 2.4
+ 2.4.3
+ 0.8.1
+
+
+ grammars
+ cli
+
+
+ 1.8
+
+ UTF-8
+ UTF-8
+
+
+
+
+
+ org.antlr
+ antlr4-runtime
+ 4.7.1
+
+
+
+ com.google.guava
+ guava
+ ${guava.version}
+
+
+
+
+
+ de.monticore.lang.monticar
+ cnn-arch
+ ${CNNArch.version}
+
+
+
+ de.monticore.lang.monticar
+ cnn-arch
+ ${CNNArch.version}
+ ${grammars.classifier}
+ provided
+
+
+
+ de.monticore.lang.monticar
+ cnn-train
+ ${CNNTrain.version}
+
+
+
+ de.monticore.lang.monticar
+ cnn-train
+ ${CNNTrain.version}
+ ${grammars.classifier}
+ provided
+
+
+
+ de.monticore.lang.monticar
+ embedded-montiarc-math-opt-generator
+ ${embedded-montiarc-math-opt-generator}
+
+
+
+
+
+ junit
+ junit
+ ${junit.version}
+ test
+
+
+
+ com.github.stefanbirkner
+ system-rules
+ 1.3.0
+
+
+
+ ch.qos.logback
+ logback-classic
+ ${logback.version}
+
+
+
+ org.jscience
+ jscience
+ ${jscience.version}
+
+
+
+
+
+
+
+
+
+
+ maven-deploy-plugin
+ 2.8.1
+
+
+
+ org.jacoco
+ jacoco-maven-plugin
+ ${jacoco.plugin}
+
+
+ pre-unit-test
+
+ prepare-agent
+
+
+
+ post-unit-test
+ test
+
+ report
+
+
+
+
+
+
+
+
+ maven-compiler-plugin
+ ${compiler.plugin}
+
+ true
+ ${java.version}
+ ${java.version}
+
+
+
+
+ maven-assembly-plugin
+ 3.1.0
+
+
+ jar-with-dependencies
+ package
+
+ single
+
+
+
+
+ de.monticore.lang.monticar.cnnarch.mxnetgenerator.CNNArch2MxNetCli
+
+
+
+ jar-with-dependencies
+
+
+
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-source-plugin
+ ${source.plugin}
+
+
+ create source jar
+ package
+
+ jar-no-fork
+
+
+ false
+
+ **/*.java
+ **/*.ftl
+
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-surefire-plugin
+ 2.19.1
+
+ false
+
+
+
+ org.eluder.coveralls
+ coveralls-maven-plugin
+ 4.3.0
+
+
+
+
+ org.codehaus.mojo
+ cobertura-maven-plugin
+ 2.7
+
+ xml
+ 256m
+
+ true
+
+
+
+
+
+
+
+ se-nexus
+ https://nexus.se.rwth-aachen.de/content/repositories/embeddedmontiarc-releases/
+
+
+ se-nexus
+ https://nexus.se.rwth-aachen.de/content/repositories/embeddedmontiarc-snapshots/
+
+
+
+
diff --git a/settings.xml b/settings.xml
new file mode 100644
index 0000000000000000000000000000000000000000..61887dffdaabd02ba790fb8620b8fcb11fc81c87
--- /dev/null
+++ b/settings.xml
@@ -0,0 +1,103 @@
+
+
+
+
+
+
+ org.mortbay.jetty
+ de.topobyte
+
+
+
+
+
+
+
+ se-nexus
+ cibuild
+ ${env.cibuild}
+
+
+
+ github
+ travisbuilduser
+ ${env.travisbuilduserpassword}
+
+
+
+
+
+ se-nexus
+ external:*
+ https://nexus.se.rwth-aachen.de/content/groups/public
+
+
+
+
+
+ sonar
+
+ true
+
+
+
+
+ https://metric.se.rwth-aachen.de
+
+
+ jenkins
+
+
+ ${env.sonar}
+
+
+
+
+
+ se-nexus
+
+
+ central
+ http://central
+
+
+
+
+
+
+ central
+ http://central
+
+
+
+
+
+
+
+
+ se-nexus
+
+
diff --git a/src/license/se/license.txt b/src/license/se/license.txt
new file mode 100644
index 0000000000000000000000000000000000000000..27ba0853ebb009d0220b03526fe684ebf1474eb3
--- /dev/null
+++ b/src/license/se/license.txt
@@ -0,0 +1,18 @@
+
+ ******************************************************************************
+ MontiCAR Modeling Family, www.se-rwth.de
+ Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
+ All rights reserved.
+
+ This project is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 3.0 of the License, or (at your option) any later version.
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this project. If not, see .
+*******************************************************************************
diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/generator/ArchitectureElementData.java b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/ArchitectureElementData.java
new file mode 100644
index 0000000000000000000000000000000000000000..a1e2fc2a33e3ff73e9d1b5e769a48905d7d9cec5
--- /dev/null
+++ b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/ArchitectureElementData.java
@@ -0,0 +1,203 @@
+/**
+ *
+ * ******************************************************************************
+ * MontiCAR Modeling Family, www.se-rwth.de
+ * Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
+ * All rights reserved.
+ *
+ * This project is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 3.0 of the License, or (at your option) any later version.
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this project. If not, see .
+ * *******************************************************************************
+ */
+package de.monticore.lang.monticar.cnnarch.generator;
+
+import de.monticore.lang.monticar.cnnarch._symboltable.ArchTypeSymbol;
+import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol;
+import de.monticore.lang.monticar.cnnarch._symboltable.ConstantSymbol;
+import de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol;
+import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers;
+import de.se_rwth.commons.logging.Log;
+
+import javax.annotation.Nullable;
+import java.util.Arrays;
+import java.util.List;
+
+public class ArchitectureElementData {
+
+ private String name;
+ private ArchitectureElementSymbol element;
+ private CNNArchTemplateController templateController;
+
+ public ArchitectureElementData(String name, ArchitectureElementSymbol element, CNNArchTemplateController templateController) {
+ this.name = name;
+ this.element = element;
+ this.templateController = templateController;
+ }
+
+ public String getName() {
+ return name;
+ }
+
+ public void setName(String name) {
+ this.name = name;
+ }
+
+ public ArchitectureElementSymbol getElement() {
+ return element;
+ }
+
+ public void setElement(ArchitectureElementSymbol element) {
+ this.element = element;
+ }
+
+ public CNNArchTemplateController getTemplateController() {
+ return templateController;
+ }
+
+ public void setTemplateController(CNNArchTemplateController templateController) {
+ this.templateController = templateController;
+ }
+
+ public List getInputs(){
+ return getTemplateController().getLayerInputs(getElement());
+ }
+
+ public boolean isLogisticRegressionOutput(){
+ return getTemplateController().isLogisticRegressionOutput(getElement());
+ }
+
+
+ public boolean isLinearRegressionOutput(){
+ boolean result = getTemplateController().isLinearRegressionOutput(getElement());
+ if (result){
+ Log.warn("The Output '" + getElement().getName() + "' is a linear regression output (squared loss) during training" +
+ " because the previous architecture element is not a softmax (cross-entropy loss) or sigmoid (logistic regression loss) activation. " +
+ "Other loss functions are currently not supported. "
+ , getElement().getSourcePosition());
+ }
+ return result;
+ }
+
+ public boolean isSoftmaxOutput(){
+ return getTemplateController().isSoftmaxOutput(getElement());
+ }
+
+ public int getConstValue() {
+ ConstantSymbol constant = (ConstantSymbol) getElement();
+ return constant.getExpression().getIntValue().get();
+ }
+
+ public List getKernel(){
+ return ((LayerSymbol) getElement())
+ .getIntTupleValue(AllPredefinedLayers.KERNEL_NAME).get();
+ }
+
+ public int getChannels(){
+ return ((LayerSymbol) getElement())
+ .getIntValue(AllPredefinedLayers.CHANNELS_NAME).get();
+ }
+
+ public List getStride(){
+ return ((LayerSymbol) getElement())
+ .getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get();
+ }
+
+ public int getUnits(){
+ return ((LayerSymbol) getElement())
+ .getIntValue(AllPredefinedLayers.UNITS_NAME).get();
+ }
+
+ public boolean getNoBias(){
+ return ((LayerSymbol) getElement())
+ .getBooleanValue(AllPredefinedLayers.NOBIAS_NAME).get();
+ }
+
+ public double getP(){
+ return ((LayerSymbol) getElement())
+ .getDoubleValue(AllPredefinedLayers.P_NAME).get();
+ }
+
+ public int getIndex(){
+ return ((LayerSymbol) getElement())
+ .getIntValue(AllPredefinedLayers.INDEX_NAME).get();
+ }
+
+ public int getNumOutputs(){
+ return ((LayerSymbol) getElement())
+ .getIntValue(AllPredefinedLayers.NUM_SPLITS_NAME).get();
+ }
+
+ public boolean getFixGamma(){
+ return ((LayerSymbol) getElement())
+ .getBooleanValue(AllPredefinedLayers.FIX_GAMMA_NAME).get();
+ }
+
+ public int getNsize(){
+ return ((LayerSymbol) getElement())
+ .getIntValue(AllPredefinedLayers.NSIZE_NAME).get();
+ }
+
+ public double getKnorm(){
+ return ((LayerSymbol) getElement())
+ .getDoubleValue(AllPredefinedLayers.KNORM_NAME).get();
+ }
+
+ public double getAlpha(){
+ return ((LayerSymbol) getElement())
+ .getDoubleValue(AllPredefinedLayers.ALPHA_NAME).get();
+ }
+
+ public double getBeta(){
+ return ((LayerSymbol) getElement())
+ .getDoubleValue(AllPredefinedLayers.BETA_NAME).get();
+ }
+
+ public int getSize(){
+ return ((LayerSymbol) getElement())
+ .getIntValue(AllPredefinedLayers.SIZE_NAME).get();
+ }
+
+ @Nullable
+ public String getPoolType(){
+ return ((LayerSymbol) getElement())
+ .getStringValue(AllPredefinedLayers.POOL_TYPE_NAME).get();
+ }
+
+ @Nullable
+ public List getPadding(){
+ return getPadding((LayerSymbol) getElement());
+ }
+
+ @Nullable
+ public List getPadding(LayerSymbol layer){
+ List kernel = layer.getIntTupleValue(AllPredefinedLayers.KERNEL_NAME).get();
+ List stride = layer.getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get();
+ ArchTypeSymbol inputType = layer.getInputTypes().get(0);
+ ArchTypeSymbol outputType = layer.getOutputTypes().get(0);
+
+ int heightWithPad = kernel.get(0) + stride.get(0)*(outputType.getHeight() - 1);
+ int widthWithPad = kernel.get(1) + stride.get(1)*(outputType.getWidth() - 1);
+ int heightPad = Math.max(0, heightWithPad - inputType.getHeight());
+ int widthPad = Math.max(0, widthWithPad - inputType.getWidth());
+
+ int topPad = (int)Math.ceil(heightPad / 2.0);
+ int bottomPad = (int)Math.floor(heightPad / 2.0);
+ int leftPad = (int)Math.ceil(widthPad / 2.0);
+ int rightPad = (int)Math.floor(widthPad / 2.0);
+
+ if (topPad == 0 && bottomPad == 0 && leftPad == 0 && rightPad == 0){
+ return null;
+ }
+
+ return Arrays.asList(0,0,0,0,topPad,bottomPad,leftPad,rightPad);
+ }
+}
diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/generator/ArchitectureSupportChecker.java b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/ArchitectureSupportChecker.java
new file mode 100644
index 0000000000000000000000000000000000000000..17c666b617163be3f5e5b4c28e8e467395e548a8
--- /dev/null
+++ b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/ArchitectureSupportChecker.java
@@ -0,0 +1,104 @@
+package de.monticore.lang.monticar.cnnarch.generator;
+
+import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
+import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol;
+import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol;
+import de.monticore.lang.monticar.cnnarch._symboltable.ConstantSymbol;
+import de.monticore.lang.monticar.cnnarch._symboltable.SerialCompositeElementSymbol;
+import de.se_rwth.commons.logging.Log;
+
+import java.util.List;
+
+public abstract class ArchitectureSupportChecker {
+
+ // Overload functions returning always true to enable the features
+ protected boolean checkMultipleStreams(ArchitectureSymbol architecture) {
+ if (architecture.getStreams().size() != 1) {
+ Log.error("This cnn architecture has multiple instructions, " +
+ "which is currently not supported by the code generator. "
+ , architecture.getSourcePosition());
+
+ return false;
+ }
+
+ return true;
+ }
+
+ protected boolean checkMultipleInputs(ArchitectureSymbol architecture) {
+ if (architecture.getInputs().size() > 1) {
+ Log.error("This cnn architecture has multiple inputs, " +
+ "which is currently not supported by the code generator. "
+ , architecture.getSourcePosition());
+
+ return false;
+ }
+
+ return true;
+ }
+
+ protected boolean checkMultipleOutputs(ArchitectureSymbol architecture) {
+ if (architecture.getOutputs().size() > 1) {
+ Log.error("This cnn architecture has multiple outputs, " +
+ "which is currently not supported by the code generator. "
+ , architecture.getSourcePosition());
+
+ return false;
+ }
+
+ return true;
+ }
+
+ protected boolean checkMultiDimensionalOutput(ArchitectureSymbol architecture) {
+ if (architecture.getOutputs().get(0).getDefinition().getType().getWidth() != 1 ||
+ architecture.getOutputs().get(0).getDefinition().getType().getHeight() != 1) {
+ Log.error("This cnn architecture has a multi-dimensional output, " +
+ "which is currently not supported by the code generator."
+ , architecture.getSourcePosition());
+
+ return false;
+ }
+
+ return true;
+ }
+
+ protected boolean hasConstant(ArchitectureElementSymbol element) {
+ ArchitectureElementSymbol resolvedElement = element.getResolvedThis().get();
+
+ if (resolvedElement instanceof CompositeElementSymbol) {
+ List constructedElements = ((CompositeElementSymbol) resolvedElement).getElements();
+
+ for (ArchitectureElementSymbol constructedElement : constructedElements) {
+ if (hasConstant(constructedElement)) {
+ return true;
+ }
+ }
+ }
+ else if (resolvedElement instanceof ConstantSymbol) {
+ return true;
+ }
+
+ return false;
+ }
+
+ protected boolean checkConstants(ArchitectureSymbol architecture) {
+ for (SerialCompositeElementSymbol stream : architecture.getStreams()) {
+ for (ArchitectureElementSymbol element : stream.getElements()) {
+ if (hasConstant(element)) {
+ Log.error("This cnn architecture has a constant, which is currently not supported by the code generator."
+ , architecture.getSourcePosition());
+ return false;
+ }
+ }
+ }
+
+ return true;
+ }
+
+ public boolean check(ArchitectureSymbol architecture) {
+ return checkMultipleStreams(architecture)
+ && checkMultipleInputs(architecture)
+ && checkMultipleOutputs(architecture)
+ && checkMultiDimensionalOutput(architecture)
+ && checkConstants(architecture);
+ }
+}
diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/generator/CNNArchGenerator.java b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/CNNArchGenerator.java
new file mode 100644
index 0000000000000000000000000000000000000000..04e97d6cbb9eb76e201079abb04888127c67ee29
--- /dev/null
+++ b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/CNNArchGenerator.java
@@ -0,0 +1,143 @@
+/**
+ *
+ * ******************************************************************************
+ * MontiCAR Modeling Family, www.se-rwth.de
+ * Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
+ * All rights reserved.
+ *
+ * This project is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 3.0 of the License, or (at your option) any later version.
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this project. If not, see .
+ * *******************************************************************************
+ */
+package de.monticore.lang.monticar.cnnarch.generator;
+
+import de.monticore.io.paths.ModelPath;
+import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
+import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchLanguage;
+import de.monticore.lang.monticar.generator.cmake.CMakeConfig;
+import de.monticore.lang.monticar.generator.cmake.CMakeFindModule;
+import de.monticore.lang.monticar.generator.FileContent;
+import de.monticore.lang.monticar.generator.cpp.GeneratorCPP;
+import de.monticore.symboltable.GlobalScope;
+import de.monticore.symboltable.Scope;
+import de.se_rwth.commons.logging.Log;
+
+import java.io.IOException;
+import java.nio.file.Path;
+import java.util.HashMap;
+import java.util.Map;
+
+public abstract class CNNArchGenerator {
+
+ protected ArchitectureSupportChecker architectureSupportChecker;
+ protected LayerSupportChecker layerSupportChecker;
+
+ private String generationTargetPath;
+ private String modelsDirPath;
+
+ protected CNNArchGenerator() {
+ setGenerationTargetPath("./target/generated-sources-cnnarch/");
+ }
+
+ public static void quitGeneration(){
+ Log.error("Code generation is aborted");
+ System.exit(1);
+ }
+
+ public boolean isCMakeRequired() {
+ return true;
+ }
+
+ public String getGenerationTargetPath(){
+ if (generationTargetPath.charAt(generationTargetPath.length() - 1) != '/') {
+ this.generationTargetPath = generationTargetPath + "/";
+ }
+ return generationTargetPath;
+ }
+
+ public void setGenerationTargetPath(String generationTargetPath){
+ this.generationTargetPath = generationTargetPath;
+ }
+
+ protected String getModelsDirPath() {
+ return this.modelsDirPath;
+ }
+
+ public void generate(Path modelsDirPath, String rootModelName){
+ this.modelsDirPath = modelsDirPath.toString();
+ final ModelPath mp = new ModelPath(modelsDirPath);
+ GlobalScope scope = new GlobalScope(mp, new CNNArchLanguage());
+ generate(scope, rootModelName);
+ }
+
+ // TODO: Rewrite so that CNNArchSymbolCompiler is used in EMADL2CPP instead of this method
+ public boolean check(ArchitectureSymbol architecture) {
+ return architectureSupportChecker.check(architecture) && layerSupportChecker.check(architecture);
+ }
+
+ public void generate(Scope scope, String rootModelName){
+ CNNArchSymbolCompiler symbolCompiler = new CNNArchSymbolCompiler(architectureSupportChecker, layerSupportChecker);
+ ArchitectureSymbol architectureSymbol = symbolCompiler.compileArchitectureSymbol(scope, rootModelName);
+
+ try{
+ String confPath = getModelsDirPath() + "/data_paths.txt";
+ DataPathConfigParser newParserConfig = new DataPathConfigParser(confPath);
+ String dataPath = newParserConfig.getDataPath(rootModelName);
+ architectureSymbol.setDataPath(dataPath);
+ architectureSymbol.setComponentName(rootModelName);
+ generateFiles(architectureSymbol);
+ } catch (IOException e){
+ Log.error(e.toString());
+ }
+ }
+
+ //check cocos with CNNArchCocos.checkAll(architecture) before calling this method.
+ public abstract Map generateStrings(ArchitectureSymbol architecture);
+
+ //check cocos with CNNArchCocos.checkAll(architecture) before calling this method.
+ public void generateFiles(ArchitectureSymbol architecture) throws IOException{
+ Map fileContentMap = generateStrings(architecture);
+ generateFromFilecontentsMap(fileContentMap);
+ }
+
+ public void generateFromFilecontentsMap(Map fileContentMap) throws IOException {
+ GeneratorCPP genCPP = new GeneratorCPP();
+ genCPP.setGenerationTargetPath(getGenerationTargetPath());
+ for (String fileName : fileContentMap.keySet()){
+ genCPP.generateFile(new FileContent(fileContentMap.get(fileName), fileName));
+ }
+ }
+
+ public void generateCMake(String rootModelName){
+ Map fileContentMap = generateCMakeContent(rootModelName);
+ try {
+ generateFromFilecontentsMap(fileContentMap);
+ } catch (IOException e) {
+ Log.error("CMake file could not be generated" + e.getMessage());
+ }
+ }
+
+ public Map generateCMakeContent(String rootModelName) {
+ // model name should start with a lower case letter. If it is a component, replace dot . by _
+ rootModelName = rootModelName.replace('.', '_').replace('[', '_').replace(']', '_');
+ rootModelName = rootModelName.substring(0, 1).toLowerCase() + rootModelName.substring(1);
+
+ CMakeConfig cMakeConfig = new CMakeConfig(rootModelName);
+ cMakeConfig.addModuleDependency(new CMakeFindModule("Armadillo", true));
+ cMakeConfig.addCMakeCommand("set(LIBS ${LIBS} mxnet)");
+
+ Map fileContentMap = new HashMap<>();
+ for (FileContent fileContent : cMakeConfig.generateCMakeFiles()){
+ fileContentMap.put(fileContent.getFileName(), fileContent.getFileContent());
+ }
+ return fileContentMap;
+ }}
diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/generator/CNNArchSymbolCompiler.java b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/CNNArchSymbolCompiler.java
new file mode 100644
index 0000000000000000000000000000000000000000..46fddd137a5aec44704769b9f2a8a50c20be9943
--- /dev/null
+++ b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/CNNArchSymbolCompiler.java
@@ -0,0 +1,52 @@
+package de.monticore.lang.monticar.cnnarch.generator;
+
+import de.monticore.io.paths.ModelPath;
+import de.monticore.lang.monticar.cnnarch._cocos.CNNArchCocos;
+import de.monticore.lang.monticar.cnnarch._symboltable.*;
+import de.monticore.symboltable.GlobalScope;
+import de.monticore.symboltable.Scope;
+import de.se_rwth.commons.logging.Log;
+
+import java.nio.file.Path;
+import java.util.List;
+import java.util.Optional;
+
+public class CNNArchSymbolCompiler {
+ private final ArchitectureSupportChecker architectureSupportChecker;
+ private final LayerSupportChecker layerSupportChecker;
+
+ public CNNArchSymbolCompiler(final ArchitectureSupportChecker architectureSupportChecker,
+ final LayerSupportChecker layerSupportChecker) {
+ this.architectureSupportChecker = architectureSupportChecker;
+ this.layerSupportChecker = layerSupportChecker;
+ }
+
+ public ArchitectureSymbol compileArchitectureSymbolFromModelsDir(
+ final Path modelsDirPath, final String rootModel) {
+ ModelPath mp = new ModelPath(modelsDirPath);
+ GlobalScope scope = new GlobalScope(mp, new CNNArchLanguage());
+ return compileArchitectureSymbol(scope, rootModel);
+ }
+
+ public ArchitectureSymbol compileArchitectureSymbol(Scope scope, String rootModelName) {
+ Optional compilationUnit = scope.resolve(rootModelName, CNNArchCompilationUnitSymbol.KIND);
+ if (!compilationUnit.isPresent()){
+ failWithMessage("Could not resolve architecture " + rootModelName);
+ }
+
+ CNNArchCocos.checkAll(compilationUnit.get());
+
+ ArchitectureSymbol architecture = compilationUnit.get().getArchitecture();
+
+ if (!architectureSupportChecker.check(architecture) || !layerSupportChecker.check(architecture)) {
+ failWithMessage("Architecture not supported by generator");
+ }
+
+ return architecture;
+ }
+
+ private void failWithMessage(final String message) {
+ Log.error(message);
+ System.exit(1);
+ }
+}
diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/generator/CNNArchTemplateController.java b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/CNNArchTemplateController.java
new file mode 100644
index 0000000000000000000000000000000000000000..ffbfc5be5d61134ecd2d88ef66008820fde13e12
--- /dev/null
+++ b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/CNNArchTemplateController.java
@@ -0,0 +1,247 @@
+/**
+ *
+ * ******************************************************************************
+ * MontiCAR Modeling Family, www.se-rwth.de
+ * Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
+ * All rights reserved.
+ *
+ * This project is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 3.0 of the License, or (at your option) any later version.
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this project. If not, see .
+ * *******************************************************************************
+ */
+package de.monticore.lang.monticar.cnnarch.generator;
+
+import de.monticore.lang.monticar.cnnarch._symboltable.*;
+import de.monticore.lang.monticar.cnnarch.predefined.Sigmoid;
+import de.monticore.lang.monticar.cnnarch.predefined.Softmax;
+
+import java.io.StringWriter;
+import java.io.Writer;
+import java.util.*;
+
+public abstract class CNNArchTemplateController {
+
+ public static final String FTL_FILE_ENDING = ".ftl";
+ public static final String TEMPLATE_ELEMENTS_DIR_PATH = "elements/";
+ public static final String TEMPLATE_CONTROLLER_KEY = "tc";
+ public static final String ELEMENT_DATA_KEY = "element";
+
+ private final TemplateConfiguration templateConfiguration;
+
+ private LayerNameCreator nameManager;
+ private ArchitectureSymbol architecture;
+
+ //temporary attributes. They are set after calling process()
+ private Writer writer;
+ private String mainTemplateNameWithoutEnding;
+ private Target targetLanguage;
+ private ArchitectureElementData dataElement;
+
+ protected CNNArchTemplateController(ArchitectureSymbol architecture, TemplateConfiguration templateConfiguration) {
+ setArchitecture(architecture);
+ this.templateConfiguration = templateConfiguration;
+ }
+
+ protected TemplateConfiguration getTemplateConfiguration() {
+ return templateConfiguration;
+ }
+
+ protected LayerNameCreator getNameManager() {
+ return nameManager;
+ }
+
+ protected void setNameManager(LayerNameCreator nameManager) {
+ this.nameManager = nameManager;
+ }
+
+ protected Writer getWriter() {
+ return writer;
+ }
+
+ protected void setWriter(Writer writer) {
+ this.writer = writer;
+ }
+
+ protected String getMainTemplateNameWithoutEnding() {
+ return mainTemplateNameWithoutEnding;
+ }
+
+ protected void setMainTemplateNameWithoutEnding(String mainTemplateNameWithoutEnding) {
+ this.mainTemplateNameWithoutEnding = mainTemplateNameWithoutEnding;
+ }
+
+ protected Target getTargetLanguage() {
+ return targetLanguage;
+ }
+
+ protected void setTargetLanguage(Target targetLanguage) {
+ this.targetLanguage = targetLanguage;
+ }
+
+ protected ArchitectureElementData getDataElement() {
+ return dataElement;
+ }
+
+ protected void setDataElement(ArchitectureElementData dataElement) {
+ this.dataElement = dataElement;
+ }
+
+ public String getFileNameWithoutEnding() {
+ return mainTemplateNameWithoutEnding + "_" + getFullArchitectureName();
+ }
+
+ public ArchitectureElementData getCurrentElement() {
+ return dataElement;
+ }
+
+ public void setCurrentElement(ArchitectureElementSymbol layer) {
+ this.dataElement = new ArchitectureElementData(getName(layer), layer, this);
+ }
+
+ public void setCurrentElement(ArchitectureElementData dataElement) {
+ this.dataElement = dataElement;
+ }
+
+ public ArchitectureSymbol getArchitecture() {
+ return architecture;
+ }
+
+ public void setArchitecture(ArchitectureSymbol architecture) {
+ this.architecture = architecture;
+ this.nameManager = new LayerNameCreator(architecture);
+ }
+
+ public String getName(ArchitectureElementSymbol layer){
+ return nameManager.getName(layer);
+ }
+
+ public String getArchitectureName(){
+ return getArchitecture().getEnclosingScope().getSpanningSymbol().get().getName().replaceAll("\\.","_");
+ }
+
+ public String getFullArchitectureName(){
+ return getArchitecture().getEnclosingScope().getSpanningSymbol().get().getFullName().replaceAll("\\.","_");
+ }
+
+ public String getDataPath(){
+ return getArchitecture().getDataPath();
+ }
+
+ public List getLayerInputs(ArchitectureElementSymbol layer){
+ List inputNames = new ArrayList<>();
+
+ if (isSoftmaxOutput(layer) || isLogisticRegressionOutput(layer)){
+ inputNames = getLayerInputs(layer.getInputElement().get());
+ } else {
+ for (ArchitectureElementSymbol input : layer.getPrevious()) {
+ if (input.getOutputTypes().size() == 1) {
+ inputNames.add(getName(input));
+ } else {
+ for (int i = 0; i < input.getOutputTypes().size(); i++) {
+ inputNames.add(getName(input) + "[" + i + "]");
+ }
+ }
+ }
+ }
+ return inputNames;
+
+ }
+
+ public List getArchitectureInputs(){
+ List list = new ArrayList<>();
+ for (IOSymbol ioElement : getArchitecture().getInputs()){
+ list.add(nameManager.getName(ioElement));
+ }
+ return list;
+ }
+
+ public List getArchitectureOutputs(){
+ List list = new ArrayList<>();
+ for (IOSymbol ioElement : getArchitecture().getOutputs()){
+ list.add(nameManager.getName(ioElement));
+ }
+ return list;
+ }
+
+ public String getComponentName(){
+ return getArchitecture().getComponentName();
+ }
+
+ public void include(String relativePath, String templateWithoutFileEnding, Writer writer){
+ String templatePath = relativePath + templateWithoutFileEnding + FTL_FILE_ENDING;
+ Map ftlContext = new HashMap<>();
+ ftlContext.put(TEMPLATE_CONTROLLER_KEY, this);
+ ftlContext.put(ELEMENT_DATA_KEY, getCurrentElement());
+ templateConfiguration.processTemplate(ftlContext, templatePath, writer);
+ }
+
+ public Map.Entry process(String templateNameWithoutEnding, Target targetLanguage){
+ StringWriter newWriter = new StringWriter();
+ this.mainTemplateNameWithoutEnding = templateNameWithoutEnding;
+ this.targetLanguage = targetLanguage;
+ this.writer = newWriter;
+
+ include("", templateNameWithoutEnding, newWriter);
+ String fileEnding = targetLanguage.toString();
+ String fileName = getFileNameWithoutEnding() + fileEnding;
+ Map.Entry fileContent = new AbstractMap.SimpleEntry<>(fileName, newWriter.toString());
+
+ this.mainTemplateNameWithoutEnding = null;
+ this.targetLanguage = null;
+ this.writer = null;
+ return fileContent;
+ }
+
+ public String join(Iterable iterable, String separator){
+ return join(iterable, separator, "", "");
+ }
+
+ public String join(Iterable iterable, String separator, String elementPrefix, String elementPostfix){
+ StringBuilder stringBuilder = new StringBuilder();
+ boolean isFirst = true;
+ for (Object element : iterable){
+ if (!isFirst){
+ stringBuilder.append(separator);
+ }
+ stringBuilder.append(elementPrefix);
+ stringBuilder.append(element.toString());
+ stringBuilder.append(elementPostfix);
+ isFirst = false;
+ }
+ return stringBuilder.toString();
+ }
+
+
+ public boolean isLogisticRegressionOutput(ArchitectureElementSymbol architectureElement){
+ return isTOutput(Sigmoid.class, architectureElement);
+ }
+
+ public boolean isLinearRegressionOutput(ArchitectureElementSymbol architectureElement){
+ return architectureElement.isOutput()
+ && !isLogisticRegressionOutput(architectureElement)
+ && !isSoftmaxOutput(architectureElement);
+ }
+
+ public boolean isSoftmaxOutput(ArchitectureElementSymbol architectureElement){
+ return isTOutput(Softmax.class, architectureElement);
+ }
+
+ private boolean isTOutput(Class inputPredefinedLayerClass, ArchitectureElementSymbol architectureElement){
+ if (architectureElement.isOutput()
+ && architectureElement.getInputElement().isPresent()
+ && architectureElement.getInputElement().get() instanceof LayerSymbol){
+ LayerSymbol inputLayer = (LayerSymbol) architectureElement.getInputElement().get();
+ return inputPredefinedLayerClass.isInstance(inputLayer.getDeclaration());
+ }
+ return false;
+ }
+}
diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/generator/CNNTrainGenerator.java b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/CNNTrainGenerator.java
new file mode 100644
index 0000000000000000000000000000000000000000..63a6321ac862cddbf66381df1c203c980f81e312
--- /dev/null
+++ b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/CNNTrainGenerator.java
@@ -0,0 +1,115 @@
+package de.monticore.lang.monticar.cnnarch.generator;
+
+import de.monticore.io.paths.ModelPath;
+import de.monticore.lang.monticar.cnnarch.generator.ConfigurationData;
+import de.monticore.lang.monticar.cnnarch.generator.TemplateConfiguration;
+import de.monticore.lang.monticar.cnntrain._ast.ASTCNNTrainNode;
+import de.monticore.lang.monticar.cnntrain._ast.ASTOptimizerEntry;
+import de.monticore.lang.monticar.cnntrain._cocos.CNNTrainCocos;
+import de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainCompilationUnitSymbol;
+import de.monticore.lang.monticar.cnntrain._symboltable.CNNTrainLanguage;
+import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
+import de.monticore.lang.monticar.generator.FileContent;
+import de.monticore.lang.monticar.generator.cpp.GeneratorCPP;
+import de.monticore.symboltable.GlobalScope;
+import de.se_rwth.commons.logging.Log;
+
+import java.io.IOException;
+import java.nio.file.Path;
+import java.util.*;
+
+public abstract class CNNTrainGenerator {
+
+ protected TrainParamSupportChecker trainParamSupportChecker;
+
+ private String generationTargetPath;
+ private String instanceName;
+
+ protected CNNTrainGenerator() {
+ setGenerationTargetPath("./target/generated-sources-cnnarch/");
+ }
+
+ private void supportCheck(ConfigurationSymbol configuration){
+ checkEntryParams(configuration);
+ checkOptimizerParams(configuration);
+ }
+
+ private void checkEntryParams(ConfigurationSymbol configuration){
+ Iterator it = configuration.getEntryMap().keySet().iterator();
+ while (it.hasNext()) {
+ String key = it.next().toString();
+ ASTCNNTrainNode astTrainEntryNode = (ASTCNNTrainNode) configuration.getEntryMap().get(key).getAstNode().get();
+ astTrainEntryNode.accept(trainParamSupportChecker);
+ }
+ it = configuration.getEntryMap().keySet().iterator();
+ while (it.hasNext()) {
+ String key = it.next().toString();
+ if (trainParamSupportChecker.getUnsupportedElemList().contains(key)) {
+ it.remove();
+ }
+ }
+ }
+
+ private void checkOptimizerParams(ConfigurationSymbol configuration){
+ if (configuration.getOptimizer() != null) {
+ ASTOptimizerEntry astOptimizer = (ASTOptimizerEntry) configuration.getOptimizer().getAstNode().get();
+ astOptimizer.accept(trainParamSupportChecker);
+ if (trainParamSupportChecker.getUnsupportedElemList().contains(trainParamSupportChecker.unsupportedOptFlag)) {
+ configuration.setOptimizer(null);
+ }else {
+ Iterator it = configuration.getOptimizer().getOptimizerParamMap().keySet().iterator();
+ while (it.hasNext()) {
+ String key = it.next().toString();
+ if (trainParamSupportChecker.getUnsupportedElemList().contains(key)) {
+ it.remove();
+ }
+ }
+ }
+ }
+ }
+
+ private static void quitGeneration(){
+ Log.error("Code generation is aborted");
+ System.exit(1);
+ }
+
+ public String getInstanceName() {
+ String parsedInstanceName = this.instanceName.replace('.', '_').replace('[', '_').replace(']', '_');
+ parsedInstanceName = parsedInstanceName.substring(0, 1).toLowerCase() + parsedInstanceName.substring(1);
+ return parsedInstanceName;
+ }
+
+ public void setInstanceName(String instanceName) {
+ this.instanceName = instanceName;
+ }
+
+ public String getGenerationTargetPath() {
+ if (generationTargetPath.charAt(generationTargetPath.length() - 1) != '/') {
+ this.generationTargetPath = generationTargetPath + "/";
+ }
+ return generationTargetPath;
+ }
+
+ public void setGenerationTargetPath(String generationTargetPath) {
+ this.generationTargetPath = generationTargetPath;
+ }
+
+ public ConfigurationSymbol getConfigurationSymbol(Path modelsDirPath, String rootModelName) {
+ final ModelPath mp = new ModelPath(modelsDirPath);
+ GlobalScope scope = new GlobalScope(mp, new CNNTrainLanguage());
+ Optional compilationUnit = scope.resolve(rootModelName, CNNTrainCompilationUnitSymbol.KIND);
+ if (!compilationUnit.isPresent()) {
+ Log.error("could not resolve training configuration " + rootModelName);
+ quitGeneration();
+ }
+ setInstanceName(compilationUnit.get().getFullName());
+ CNNTrainCocos.checkAll(compilationUnit.get());
+ supportCheck(compilationUnit.get().getConfiguration());
+ return compilationUnit.get().getConfiguration();
+ }
+
+ public abstract void generate(Path modelsDirPath, String rootModelNames);
+
+ //check cocos with CNNTrainCocos.checkAll(configuration) before calling this method.
+ public abstract Map generateStrings(ConfigurationSymbol configuration);
+}
\ No newline at end of file
diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/generator/ConfigurationData.java b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/ConfigurationData.java
new file mode 100644
index 0000000000000000000000000000000000000000..23f29db1b48424e1f85745d360efdd6615fe8235
--- /dev/null
+++ b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/ConfigurationData.java
@@ -0,0 +1,99 @@
+package de.monticore.lang.monticar.cnnarch.generator;
+
+import de.monticore.lang.monticar.cnntrain._symboltable.*;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+public class ConfigurationData {
+
+ ConfigurationSymbol configuration;
+ String instanceName;
+
+ public ConfigurationData(ConfigurationSymbol configuration, String instanceName) {
+ this.configuration = configuration;
+ this.instanceName = instanceName;
+ }
+
+ public ConfigurationSymbol getConfiguration() {
+ return configuration;
+ }
+
+ public String getInstanceName() {
+ return instanceName;
+ }
+
+ public String getNumEpoch() {
+ if (!getConfiguration().getEntryMap().containsKey("num_epoch")) {
+ return null;
+ }
+ return String.valueOf(getConfiguration().getEntry("num_epoch").getValue());
+ }
+
+ public String getBatchSize() {
+ if (!getConfiguration().getEntryMap().containsKey("batch_size")) {
+ return null;
+ }
+ return String.valueOf(getConfiguration().getEntry("batch_size") .getValue());
+ }
+
+ public Boolean getLoadCheckpoint() {
+ if (!getConfiguration().getEntryMap().containsKey("load_checkpoint")) {
+ return null;
+ }
+ return (Boolean) getConfiguration().getEntry("load_checkpoint").getValue().getValue();
+ }
+
+ public Boolean getNormalize() {
+ if (!getConfiguration().getEntryMap().containsKey("normalize")) {
+ return null;
+ }
+ return (Boolean) getConfiguration().getEntry("normalize").getValue().getValue();
+ }
+
+ public String getContext() {
+ if (!getConfiguration().getEntryMap().containsKey("context")) {
+ return null;
+ }
+ return getConfiguration().getEntry("context").getValue().toString();
+ }
+
+ public String getEvalMetric() {
+ if (!getConfiguration().getEntryMap().containsKey("eval_metric")) {
+ return null;
+ }
+ return getConfiguration().getEntry("eval_metric").getValue().toString();
+ }
+
+ public String getOptimizerName() {
+ if (getConfiguration().getOptimizer() == null) {
+ return null;
+ }
+ return getConfiguration().getOptimizer().getName();
+ }
+
+ public Map getOptimizerParams() {
+ // get classes for single enum values
+ List lrPolicyClasses = new ArrayList<>();
+ for (LRPolicy enum_value: LRPolicy.values()) {
+ lrPolicyClasses.add(enum_value.getClass());
+ }
+
+ Map mapToStrings = new HashMap<>();
+ Map optimizerParams = getConfiguration().getOptimizer().getOptimizerParamMap();
+ for (Map.Entry entry : optimizerParams.entrySet()) {
+ String paramName = entry.getKey();
+ String valueAsString = entry.getValue().toString();
+ Class realClass = entry.getValue().getValue().getValue().getClass();
+ if (realClass == Boolean.class) {
+ valueAsString = (Boolean) entry.getValue().getValue().getValue() ? "True" : "False";
+ } else if (lrPolicyClasses.contains(realClass)) {
+ valueAsString = "'" + valueAsString + "'";
+ }
+ mapToStrings.put(paramName, valueAsString);
+ }
+ return mapToStrings;
+ }
+}
diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/generator/DataPathConfigParser.java b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/DataPathConfigParser.java
new file mode 100644
index 0000000000000000000000000000000000000000..3a034b0ffd268b2f7c8c0c5cd4b68bd2ea6de466
--- /dev/null
+++ b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/DataPathConfigParser.java
@@ -0,0 +1,66 @@
+/**
+ *
+ * ******************************************************************************
+ * MontiCAR Modeling Family, www.se-rwth.de
+ * Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
+ * All rights reserved.
+ *
+ * This project is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 3.0 of the License, or (at your option) any later version.
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this project. If not, see .
+ * *******************************************************************************
+ */
+package de.monticore.lang.monticar.cnnarch.generator;
+
+import de.se_rwth.commons.logging.Log;
+
+import java.io.*;
+import java.net.URL;
+import java.util.Objects;
+import java.util.Properties;
+
+public class DataPathConfigParser{
+
+ private String configTargetPath;
+ private String configFileName;
+ private Properties properties;
+
+ public DataPathConfigParser(String configPath) {
+ setConfigPath(configPath);
+ properties = new Properties();
+ try
+ {
+ properties.load(new FileInputStream(configTargetPath));
+ } catch(IOException e)
+ {
+ Log.error("Config file " + configPath + " could not be found");
+ }
+ }
+
+ public String getConfigPath() {
+ if (configTargetPath.charAt(configTargetPath.length() - 1) != '/') {
+ this.configTargetPath = configTargetPath + "/";
+ }
+ return configTargetPath;
+ }
+
+ public void setConfigPath(String configTargetPath){
+ this.configTargetPath = configTargetPath;
+ }
+
+ public String getDataPath(String modelName) {
+ String path = properties.getProperty(modelName);
+ if(path == null) {
+ Log.error("Data path config file did not specify a path for component '" + modelName + "'");
+ }
+ return path;
+ }
+}
diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/generator/GenericCNNArchCli.java b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/GenericCNNArchCli.java
new file mode 100644
index 0000000000000000000000000000000000000000..a02e2a795781ce36bc91a42c3d3fbe2a01e26838
--- /dev/null
+++ b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/GenericCNNArchCli.java
@@ -0,0 +1,83 @@
+package de.monticore.lang.monticar.cnnarch.generator;
+
+import de.se_rwth.commons.logging.Log;
+import org.apache.commons.cli.*;
+
+import java.nio.file.Path;
+import java.nio.file.Paths;
+
+/**
+ *
+ */
+public class GenericCNNArchCli {
+ private final CNNArchGenerator cnnArchGenerator;
+
+ public static final Option OPTION_MODELS_PATH = Option.builder("m")
+ .longOpt("models-dir")
+ .desc("full path to the directory with the CNNArch model")
+ .hasArg(true)
+ .required(true)
+ .build();
+
+ public static final Option OPTION_ROOT_MODEL = Option.builder("r")
+ .longOpt("root-model")
+ .desc("name of the architecture")
+ .hasArg(true)
+ .required(true)
+ .build();
+
+ public static final Option OPTION_OUTPUT_PATH = Option.builder("o")
+ .longOpt("output-dir")
+ .desc("full path to output directory for tests")
+ .hasArg(true)
+ .required(false)
+ .build();
+
+ public GenericCNNArchCli(CNNArchGenerator cnnArchGenerator) {
+ this.cnnArchGenerator = cnnArchGenerator;
+ }
+
+ public void run(String[] args) {
+ Options options = getOptions();
+ CommandLineParser parser = new DefaultParser();
+ CommandLine cliArgs = parseArgs(options, parser, args);
+ if (cliArgs != null) {
+ runGenerator(cliArgs);
+ }
+ }
+
+ private Options getOptions() {
+ Options options = new Options();
+ options.addOption(OPTION_MODELS_PATH);
+ options.addOption(OPTION_ROOT_MODEL);
+ options.addOption(OPTION_OUTPUT_PATH);
+ return options;
+ }
+
+ private CommandLine parseArgs(Options options, CommandLineParser parser, String[] args) {
+ CommandLine cliArgs;
+ try {
+ cliArgs = parser.parse(options, args);
+ } catch (ParseException e) {
+ Log.error("argument parsing exception: " + e.getMessage());
+ quitGeneration();
+ return null;
+ }
+ return cliArgs;
+ }
+
+ private void quitGeneration(){
+ Log.error("Code generation is aborted");
+ System.exit(1);
+ }
+
+ private void runGenerator(CommandLine cliArgs) {
+ Path modelsDirPath = Paths.get(cliArgs.getOptionValue(OPTION_MODELS_PATH.getOpt()));
+ String rootModelName = cliArgs.getOptionValue(OPTION_ROOT_MODEL.getOpt());
+ String outputPath = cliArgs.getOptionValue(OPTION_OUTPUT_PATH.getOpt());
+ if (outputPath != null){
+ cnnArchGenerator.setGenerationTargetPath(outputPath);
+ }
+ cnnArchGenerator.generate(modelsDirPath, rootModelName);
+ }
+}
\ No newline at end of file
diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/generator/LayerNameCreator.java b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/LayerNameCreator.java
new file mode 100644
index 0000000000000000000000000000000000000000..63eeb851b45784bf2dcf5ac414d320b8c29d9f44
--- /dev/null
+++ b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/LayerNameCreator.java
@@ -0,0 +1,151 @@
+/**
+ *
+ * ******************************************************************************
+ * MontiCAR Modeling Family, www.se-rwth.de
+ * Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
+ * All rights reserved.
+ *
+ * This project is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 3.0 of the License, or (at your option) any later version.
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this project. If not, see .
+ * *******************************************************************************
+ */
+package de.monticore.lang.monticar.cnnarch.generator;
+
+import de.monticore.lang.monticar.cnnarch._symboltable.*;
+import de.monticore.lang.monticar.cnnarch.predefined.Convolution;
+import de.monticore.lang.monticar.cnnarch.predefined.FullyConnected;
+import de.monticore.lang.monticar.cnnarch.predefined.Pooling;
+
+import java.util.*;
+
+public class LayerNameCreator {
+
+ private Map elementToName = new HashMap<>();
+ private Map nameToElement = new HashMap<>();
+
+ public LayerNameCreator(ArchitectureSymbol architecture) {
+ int stage = 1;
+ for (SerialCompositeElementSymbol stream : architecture.getStreams()) {
+ stage = name(stream, stage, new ArrayList<>());
+ }
+ }
+
+ public ArchitectureElementSymbol getArchitectureElement(String name){
+ return nameToElement.get(name);
+ }
+
+ public String getName(ArchitectureElementSymbol architectureElement){
+ return elementToName.get(architectureElement);
+ }
+
+ protected int name(ArchitectureElementSymbol architectureElement, int stage, List streamIndices){
+ if (architectureElement instanceof SerialCompositeElementSymbol) {
+ return nameSerialComposite((SerialCompositeElementSymbol) architectureElement, stage, streamIndices);
+ } else if (architectureElement instanceof ParallelCompositeElementSymbol){
+ return nameParallelComposite((ParallelCompositeElementSymbol) architectureElement, stage, streamIndices);
+ } else{
+ if (architectureElement.isAtomic()){
+ if (architectureElement.getMaxSerialLength().get() > 0){
+ return add(architectureElement, stage, streamIndices);
+ } else {
+ return stage;
+ }
+ } else {
+ ArchitectureElementSymbol resolvedElement = architectureElement.getResolvedThis().get();
+ return name(resolvedElement, stage, streamIndices);
+ }
+ }
+ }
+
+ protected int nameSerialComposite(SerialCompositeElementSymbol compositeElement, int stage, List streamIndices){
+ int endStage = stage;
+ for (ArchitectureElementSymbol subElement : compositeElement.getElements()){
+ endStage = name(subElement, endStage, streamIndices);
+ }
+ return endStage;
+ }
+
+ protected int nameParallelComposite(ParallelCompositeElementSymbol compositeElement, int stage, List streamIndices){
+ int startStage = stage + 1;
+ streamIndices.add(1);
+ int lastIndex = streamIndices.size() - 1;
+
+ List endStages = new ArrayList<>();
+ for (ArchitectureElementSymbol subElement : compositeElement.getElements()){
+ endStages.add(name(subElement, startStage, streamIndices));
+ streamIndices.set(lastIndex, streamIndices.get(lastIndex) + 1);
+ }
+
+ streamIndices.remove(lastIndex);
+ return Collections.max(endStages) + 1;
+ }
+
+ protected int add(ArchitectureElementSymbol architectureElement, int stage, List streamIndices){
+ int endStage = stage;
+ if (!elementToName.containsKey(architectureElement)) {
+ String name = createName(architectureElement, endStage, streamIndices);
+
+ while (nameToElement.containsKey(name)) {
+ endStage++;
+ name = createName(architectureElement, endStage, streamIndices);
+ }
+
+ elementToName.put(architectureElement, name);
+ nameToElement.put(name, architectureElement);
+ }
+ return endStage;
+ }
+
+ protected String createName(ArchitectureElementSymbol architectureElement, int stage, List streamIndices){
+ if (architectureElement instanceof IOSymbol){
+ String name = createBaseName(architectureElement);
+ IOSymbol ioElement = (IOSymbol) architectureElement;
+ if (ioElement.getArrayAccess().isPresent()){
+ int arrayAccess = ioElement.getArrayAccess().get().getIntValue().get();
+ name = name + "_" + arrayAccess + "_";
+ }
+ return name;
+ } else {
+ return createBaseName(architectureElement) + stage + createStreamPostfix(streamIndices) + "_";
+ }
+ }
+
+
+ protected String createBaseName(ArchitectureElementSymbol architectureElement){
+ if (architectureElement instanceof LayerSymbol) {
+ LayerDeclarationSymbol layerDeclaration = ((LayerSymbol) architectureElement).getDeclaration();
+ if (layerDeclaration instanceof Convolution) {
+ return "conv";
+ } else if (layerDeclaration instanceof FullyConnected) {
+ return "fc";
+ } else if (layerDeclaration instanceof Pooling) {
+ return "pool";
+ } else {
+ return layerDeclaration.getName().toLowerCase();
+ }
+ } else if (architectureElement instanceof CompositeElementSymbol){
+ return "group";
+ } else {
+ return architectureElement.getName();
+ }
+ }
+
+ protected String createStreamPostfix(List streamIndices){
+ StringBuilder stringBuilder = new StringBuilder();
+ for (int streamIndex : streamIndices){
+ stringBuilder.append("_");
+ stringBuilder.append(streamIndex);
+ }
+ return stringBuilder.toString();
+ }
+}
+
diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/generator/LayerSupportChecker.java b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/LayerSupportChecker.java
new file mode 100644
index 0000000000000000000000000000000000000000..ace5288e3db2d0bf08cb2ec74e55ec28772c3b93
--- /dev/null
+++ b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/LayerSupportChecker.java
@@ -0,0 +1,70 @@
+package de.monticore.lang.monticar.cnnarch.generator;
+
+import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers;
+import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol;
+import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
+import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol;
+import de.monticore.lang.monticar.cnnarch._symboltable.ConstantSymbol;
+import de.monticore.lang.monticar.cnnarch._symboltable.LayerDeclarationSymbol;
+import de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol;
+import de.se_rwth.commons.logging.Log;
+
+import java.util.ArrayList;
+import java.util.List;
+
+public abstract class LayerSupportChecker {
+
+ protected List supportedLayerList = new ArrayList<>();
+
+ private boolean isSupportedLayer(ArchitectureElementSymbol element){
+ ArchitectureElementSymbol resolvedElement = element.getResolvedThis().get();
+ List constructLayerElemList;
+
+ if (resolvedElement instanceof CompositeElementSymbol) {
+ constructLayerElemList = ((CompositeElementSymbol) resolvedElement).getElements();
+ for (ArchitectureElementSymbol constructedLayerElement : constructLayerElemList) {
+ if (!isSupportedLayer(constructedLayerElement)) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ // Support all inputs and outputs
+ if (resolvedElement.isInput() || resolvedElement.isOutput()) {
+ return true;
+ }
+
+ // Support for constants is checked in ArchitectureSupportChecker
+ if (resolvedElement instanceof ConstantSymbol) {
+ return true;
+ }
+
+ // Support all layer declarations
+ if (resolvedElement instanceof LayerSymbol) {
+ if (!((LayerSymbol) resolvedElement).getDeclaration().isPredefined()) {
+ return true;
+ }
+ }
+
+ if (!supportedLayerList.contains(element.toString())) {
+ Log.error("Unsupported layer " + "'" + element.getName() + "'" + " for the current backend.");
+ return false;
+ } else {
+ return true;
+ }
+ }
+
+ public boolean check(ArchitectureSymbol architecture) {
+ for (CompositeElementSymbol stream : architecture.getStreams()) {
+ for (ArchitectureElementSymbol element : stream.getElements()) {
+ if (!isSupportedLayer(element)) {
+ return false;
+ }
+ }
+ }
+
+ return true;
+ }
+
+}
diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/generator/Target.java b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/Target.java
new file mode 100644
index 0000000000000000000000000000000000000000..2504df3afbb9d4e73736774cfd68948b07d4d32e
--- /dev/null
+++ b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/Target.java
@@ -0,0 +1,37 @@
+/**
+ *
+ * ******************************************************************************
+ * MontiCAR Modeling Family, www.se-rwth.de
+ * Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
+ * All rights reserved.
+ *
+ * This project is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 3.0 of the License, or (at your option) any later version.
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this project. If not, see .
+ * *******************************************************************************
+ */
+package de.monticore.lang.monticar.cnnarch.generator;
+
+//can be removed
+public enum Target {
+ PYTHON{
+ @Override
+ public String toString() {
+ return ".py";
+ }
+ },
+ CPP{
+ @Override
+ public String toString() {
+ return ".h";
+ }
+ }
+}
diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/generator/TemplateConfiguration.java b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/TemplateConfiguration.java
new file mode 100644
index 0000000000000000000000000000000000000000..dfb8a964031f741b0e1209a02656e387b4fede87
--- /dev/null
+++ b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/TemplateConfiguration.java
@@ -0,0 +1,73 @@
+/**
+ *
+ * ******************************************************************************
+ * MontiCAR Modeling Family, www.se-rwth.de
+ * Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
+ * All rights reserved.
+ *
+ * This project is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 3.0 of the License, or (at your option) any later version.
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this project. If not, see .
+ * *******************************************************************************
+ */
+package de.monticore.lang.monticar.cnnarch.generator;
+
+import de.se_rwth.commons.logging.Log;
+import freemarker.template.Configuration;
+import freemarker.template.Template;
+import freemarker.template.TemplateException;
+import freemarker.template.TemplateExceptionHandler;
+
+import java.io.IOException;
+import java.io.StringWriter;
+import java.io.Writer;
+import java.util.Map;
+
+public abstract class TemplateConfiguration {
+ abstract protected String getBaseTemplatePackagePath();
+ abstract public Configuration getConfiguration();
+
+ public TemplateConfiguration() {
+
+ }
+
+ protected Configuration createConfiguration() {
+ Configuration configuration = new Configuration(Configuration.VERSION_2_3_23);
+ configuration.setClassForTemplateLoading(TemplateConfiguration.class, getBaseTemplatePackagePath());
+ configuration.setDefaultEncoding("UTF-8");
+ configuration.setTemplateExceptionHandler(TemplateExceptionHandler.RETHROW_HANDLER);
+ return configuration;
+ }
+
+ private void quitGeneration(){
+ Log.error("Code generation is aborted");
+ System.exit(1);
+ }
+
+ public void processTemplate(Map ftlContext, String templatePath, Writer writer){
+ try{
+ Template template = getConfiguration().getTemplate(templatePath);
+ template.process(ftlContext, writer);
+ } catch (IOException e) {
+ Log.error("Freemarker could not find template " + templatePath + " :\n" + e.getMessage());
+ quitGeneration();
+ } catch (TemplateException e){
+ Log.error("An exception occured in template " + templatePath + " :\n" + e.getMessage());
+ quitGeneration();
+ }
+ }
+
+ public String processTemplate(Map ftlContext, String templatePath){
+ StringWriter writer = new StringWriter();
+ processTemplate(ftlContext, templatePath, writer);
+ return writer.toString();
+ }
+}
diff --git a/src/main/java/de/monticore/lang/monticar/cnnarch/generator/TrainParamSupportChecker.java b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/TrainParamSupportChecker.java
new file mode 100644
index 0000000000000000000000000000000000000000..ad1e89d9997cc282e179b19234f29abfc7891bb4
--- /dev/null
+++ b/src/main/java/de/monticore/lang/monticar/cnnarch/generator/TrainParamSupportChecker.java
@@ -0,0 +1,94 @@
+package de.monticore.lang.monticar.cnnarch.generator;
+
+import de.monticore.lang.monticar.cnntrain._ast.*;
+import de.monticore.lang.monticar.cnntrain._visitor.CNNTrainVisitor;
+import de.se_rwth.commons.logging.Log;
+import java.util.ArrayList;
+import java.util.List;
+
+public abstract class TrainParamSupportChecker implements CNNTrainVisitor {
+
+ protected List unsupportedElemList = new ArrayList<>();
+
+ protected void printUnsupportedEntryParam(String nodeName){
+ Log.warn("Unsupported training parameter " + "'" + nodeName + "'" + " for the current backend. It will be ignored.");
+ }
+
+ protected void printUnsupportedOptimizer(String nodeName){
+ Log.warn("Unsupported optimizer parameter " + "'" + nodeName + "'" + " for the current backend. It will be ignored.");
+ }
+
+ protected void printUnsupportedOptimizerParam(String nodeName){
+ Log.warn("Unsupported training optimizer parameter " + "'" + nodeName + "'" + " for the current backend. It will be ignored.");
+ }
+
+ public TrainParamSupportChecker() {
+ }
+
+ public static final String unsupportedOptFlag = "unsupported_optimizer";
+
+ public List getUnsupportedElemList(){
+ return this.unsupportedElemList;
+ }
+
+ //Empty visit method denotes that the corresponding training parameter is supported.
+ //To set a training parameter as unsupported, add the corresponding node to the unsupportedElemList
+ public void visit(ASTNumEpochEntry node){}
+
+ public void visit(ASTBatchSizeEntry node){}
+
+ public void visit(ASTLoadCheckpointEntry node){}
+
+ public void visit(ASTNormalizeEntry node){}
+
+ public void visit(ASTTrainContextEntry node){}
+
+ public void visit(ASTEvalMetricEntry node){}
+
+ public void visit(ASTSGDOptimizer node){}
+
+ public void visit(ASTAdamOptimizer node){}
+
+ public void visit(ASTRmsPropOptimizer node){}
+
+ public void visit(ASTAdaGradOptimizer node){}
+
+ public void visit(ASTNesterovOptimizer node){}
+
+ public void visit(ASTAdaDeltaOptimizer node){}
+
+ public void visit(ASTLearningRateEntry node){}
+
+ public void visit(ASTMinimumLearningRateEntry node){}
+
+ public void visit(ASTWeightDecayEntry node){}
+
+ public void visit(ASTLRDecayEntry node){}
+
+ public void visit(ASTLRPolicyEntry node){}
+
+ public void visit(ASTRescaleGradEntry node){}
+
+ public void visit(ASTClipGradEntry node){}
+
+ public void visit(ASTStepSizeEntry node){}
+
+ public void visit(ASTMomentumEntry node){}
+
+ public void visit(ASTBeta1Entry node){}
+
+ public void visit(ASTBeta2Entry node){}
+
+ public void visit(ASTEpsilonEntry node){}
+
+ public void visit(ASTGamma1Entry node){}
+
+ public void visit(ASTGamma2Entry node){}
+
+ public void visit(ASTCenteredEntry node){}
+
+ public void visit(ASTClipWeightsEntry node){}
+
+ public void visit(ASTRhoEntry node){}
+
+}
diff --git a/src/test/java/de/monticore/lang/monticar/cnnarch/generator/AbstractSymtabTest.java b/src/test/java/de/monticore/lang/monticar/cnnarch/generator/AbstractSymtabTest.java
new file mode 100644
index 0000000000000000000000000000000000000000..bc02c1ad757b0b6660401a88ac4b44fcb74206a7
--- /dev/null
+++ b/src/test/java/de/monticore/lang/monticar/cnnarch/generator/AbstractSymtabTest.java
@@ -0,0 +1,132 @@
+/**
+ *
+ * ******************************************************************************
+ * MontiCAR Modeling Family, www.se-rwth.de
+ * Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
+ * All rights reserved.
+ *
+ * This project is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 3.0 of the License, or (at your option) any later version.
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this project. If not, see .
+ * *******************************************************************************
+ */
+package de.monticore.lang.monticar.cnnarch.generator;
+
+import de.monticore.ModelingLanguageFamily;
+import de.monticore.io.paths.ModelPath;
+import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchCompilationUnitSymbol;
+import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchLanguage;
+import de.monticore.symboltable.GlobalScope;
+import de.monticore.symboltable.Scope;
+import org.junit.Assert;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Common methods for symboltable tests
+ */
+public class AbstractSymtabTest {
+
+ private static final String MODEL_PATH = "src/test/resources/";
+
+ protected static Scope createSymTab(String... modelPath) {
+ ModelingLanguageFamily fam = new ModelingLanguageFamily();
+
+ fam.addModelingLanguage(new CNNArchLanguage());
+
+ final ModelPath mp = new ModelPath();
+ for (String m : modelPath) {
+ mp.addEntry(Paths.get(m));
+ }
+
+ return new GlobalScope(mp, fam);
+ }
+
+ protected static CNNArchCompilationUnitSymbol getCompilationUnitSymbol(String modelPath, String model) {
+ Scope symTab = createSymTab(MODEL_PATH + modelPath);
+ CNNArchCompilationUnitSymbol comp = symTab. resolve(
+ model, CNNArchCompilationUnitSymbol.KIND).orElse(null);
+ assertNotNull("Could not resolve model " + model, comp);
+
+ return comp;
+ }
+
+
+
+ public static void checkFilesAreEqual(Path generationPath, Path resultsPath, List fileNames) {
+ for (String fileName : fileNames){
+ File genFile = new File(generationPath.toString() + "/" + fileName);
+ File fileTarget = new File(resultsPath.toString() + "/" + fileName);
+ assertTrue(areBothFilesEqual(genFile, fileTarget));
+ }
+ }
+
+ public static boolean areBothFilesEqual(File file1, File file2) {
+ if (!file1.exists()) {
+ Assert.fail("file does not exist: " + file1.getAbsolutePath());
+ return false;
+ }
+ if (!file2.exists()) {
+ Assert.fail("file does not exist: " + file2.getAbsolutePath());
+ return false;
+ }
+ List lines1;
+ List lines2;
+ try {
+ lines1 = Files.readAllLines(file1.toPath());
+ lines2 = Files.readAllLines(file2.toPath());
+ } catch (IOException e) {
+ e.printStackTrace();
+ Assert.fail("IO error: " + e.getMessage());
+ return false;
+ }
+ lines1 = discardEmptyLines(lines1);
+ lines2 = discardEmptyLines(lines2);
+ if (lines1.size() != lines2.size()) {
+ Assert.fail(
+ "files have different number of lines: "
+ + file1.getAbsolutePath()
+ + " has " + lines1
+ + " lines and " + file2.getAbsolutePath() + " has " + lines2 + " lines"
+ );
+ return false;
+ }
+ int len = lines1.size();
+ for (int i = 0; i < len; i++) {
+ String l1 = lines1.get(i);
+ String l2 = lines2.get(i);
+ Assert.assertEquals("files differ in " + i + " line: "
+ + file1.getAbsolutePath()
+ + " has " + l1
+ + " and " + file2.getAbsolutePath() + " has " + l2,
+ l1,
+ l2
+ );
+ }
+ return true;
+ }
+
+ private static List discardEmptyLines(List lines) {
+ return lines.stream()
+ .map(String::trim)
+ .filter(l -> !l.isEmpty())
+ .collect(Collectors.toList());
+ }
+}
diff --git a/src/test/java/de/monticore/lang/monticar/cnnarch/generator/DataPathConfigParserTest.java b/src/test/java/de/monticore/lang/monticar/cnnarch/generator/DataPathConfigParserTest.java
new file mode 100644
index 0000000000000000000000000000000000000000..f3a4b133ececf1bcc193eaed4d1c95f2f4727935
--- /dev/null
+++ b/src/test/java/de/monticore/lang/monticar/cnnarch/generator/DataPathConfigParserTest.java
@@ -0,0 +1,66 @@
+/**
+ *
+ * ******************************************************************************
+ * MontiCAR Modeling Family, www.se-rwth.de
+ * Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
+ * All rights reserved.
+ *
+ * This project is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 3.0 of the License, or (at your option) any later version.
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this project. If not, see .
+ * *******************************************************************************
+ */
+package de.monticore.lang.monticar.cnnarch.generator;
+
+import de.monticore.lang.monticar.cnnarch._symboltable.*;
+import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedVariables;
+import de.monticore.symboltable.Scope;
+import de.se_rwth.commons.logging.Log;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
+public class DataPathConfigParserTest extends AbstractSymtabTest {
+
+ @Before
+ public void setUp() {
+ // ensure an empty log
+ Log.getFindings().clear();
+ Log.enableFailQuick(false);
+ }
+
+ @Test
+ public void testDataPathConfigParserValidComponent() {
+ DataPathConfigParser parser = new DataPathConfigParser("src/test/resources/architectures/data_paths.txt");
+
+ String data_path = parser.getDataPath("ComponentName");
+ assertTrue("Wrong data path returned", data_path.equals("/path/to/training/data"));
+ }
+
+ @Test
+ public void testDataPathConfigParserInvalidComponent() {
+ DataPathConfigParser parser = new DataPathConfigParser("src/test/resources/architectures/data_paths.txt");
+
+ String data_path = parser.getDataPath("NotExistingComponent");
+ assertTrue("For not listed components, null should be returned", data_path == null);
+ assertTrue(Log.getFindings().size() == 1);
+ }
+
+ @Test
+ public void testDataPathConfigParserInvalidPath() {
+ DataPathConfigParser parser = new DataPathConfigParser("invalid/path/data_paths.txt");
+
+ assertTrue(Log.getFindings().size() == 1);
+ }
+}
diff --git a/src/test/resources/architectures/data_paths.txt b/src/test/resources/architectures/data_paths.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a4c785fe203697c5410b48498cef45c4f788db60
--- /dev/null
+++ b/src/test/resources/architectures/data_paths.txt
@@ -0,0 +1 @@
+ComponentName /path/to/training/data
\ No newline at end of file