Skip to content
Snippets Groups Projects
Commit aaa745b1 authored by Evgeny Kusmenko's avatar Evgeny Kusmenko
Browse files

Merge branch 'timmermanns' into 'master'

Removed github deploy plugin, changed repository to se nexus

See merge request CNNArch2MXNet!2
parents 0e3a557a 07d96962
No related branches found
No related tags found
No related merge requests found
Showing
with 1771 additions and 3 deletions
#
#
# ******************************************************************************
# MontiCAR Modeling Family, www.se-rwth.de
# Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
# All rights reserved.
#
# This project is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 3.0 of the License, or (at your option) any later version.
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this project. If not, see <http://www.gnu.org/licenses/>.
# *******************************************************************************
#
# Java Maven CircleCI 2.0 configuration file
#
# Check https://circleci.com/docs/2.0/language-java/ for more details
#
version: 2
general:
branches:
ignore:
- gh-pages
jobs:
build:
docker:
# specify the version you desire here
- image: circleci/openjdk:8-jdk
# Specify service dependencies here if necessary
# CircleCI maintains a library of pre-built images
# documented at https://circleci.com/docs/2.0/circleci-images/
# - image: circleci/postgres:9.4
working_directory: ~/repo
environment:
# Customize the JVM maximum heap limit
MAVEN_OPTS: -Xmx3200m
steps:
- checkout
# run tests!
- run: mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings "settings.xml"
workflows:
version: 2
commit-workflow:
jobs:
- build
scheduled-workflow:
triggers:
- schedule:
cron: "30 1 * * *"
filters:
branches:
only: master
jobs:
- build
target
nppBackup
.project
.settings
.classpath
.idea
.git
*.iml
image: maven:3-jdk-8 #
#
# ******************************************************************************
# MontiCAR Modeling Family, www.se-rwth.de
# Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
# All rights reserved.
#
# This project is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 3.0 of the License, or (at your option) any later version.
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this project. If not, see <http://www.gnu.org/licenses/>.
# *******************************************************************************
#
build: stages:
script: "mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml" - windows
- linux
masterJobLinux:
stage: linux
image: maven:3-jdk-8
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install deploy --settings settings.xml
only:
- master
masterJobWindows:
stage: windows
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml
tags:
- Windows10
BranchJobLinux:
stage: linux
image: maven:3-jdk-8
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml
except:
- master
script:
- git checkout ${TRAVIS_BRANCH}
- mvn clean install cobertura:cobertura org.eluder.coveralls:coveralls-maven-plugin:report --settings "settings.xml"
after_success:
- if [ "${TRAVIS_BRANCH}" == "master" ]; then mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B deploy --debug --settings "./settings.xml"; fi
pom.xml 0 → 100644
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<!-- == PROJECT COORDINATES ============================================= -->
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-mxnet-generator</artifactId>
<version>0.2.1-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
<properties>
<!-- .. SE-Libraries .................................................. -->
<CNNArch.version>0.2.1-SNAPSHOT</CNNArch.version>
<!-- .. Libraries .................................................. -->
<guava.version>18.0</guava.version>
<junit.version>4.12</junit.version>
<logback.version>1.1.2</logback.version>
<jscience.version>4.3.1</jscience.version>
<!-- .. Plugins ....................................................... -->
<assembly.plugin>2.5.4</assembly.plugin>
<compiler.plugin>3.3</compiler.plugin>
<source.plugin>2.4</source.plugin>
<shade.plugin>2.4.3</shade.plugin>
<!-- Classifiers -->
<grammars.classifier>grammars</grammars.classifier>
<cli.classifier>cli</cli.classifier>
<!-- .. Misc .......................................................... -->
<java.version>1.8</java.version>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
</properties>
<dependencies>
<dependency>
<groupId>org.antlr</groupId>
<artifactId>antlr4-runtime</artifactId>
<version>4.7.1</version>
</dependency>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>${guava.version}</version>
</dependency>
<!-- MontiCore Dependencies -->
<dependency>
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnn-arch</artifactId>
<version>${CNNArch.version}</version>
</dependency>
<dependency>
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnn-arch</artifactId>
<version>${CNNArch.version}</version>
<classifier>${grammars.classifier}</classifier>
<scope>provided</scope>
</dependency>
<!-- .. Test Libraries ............................................... -->
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>${junit.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
<version>${logback.version}</version>
</dependency>
<dependency>
<groupId>org.jscience</groupId>
<artifactId>jscience</artifactId>
<version>${jscience.version}</version>
</dependency>
</dependencies>
<!-- == PROJECT BUILD SETTINGS =========================================== -->
<build>
<plugins>
<plugin>
<artifactId>maven-deploy-plugin</artifactId>
<version>2.8.1</version>
</plugin>
<!-- Other Configuration -->
<plugin>
<artifactId>maven-compiler-plugin</artifactId>
<version>${compiler.plugin}</version>
<configuration>
<useIncrementalCompilation>true</useIncrementalCompilation>
<source>${java.version}</source>
<target>${java.version}</target>
</configuration>
</plugin>
<plugin>
<artifactId>maven-assembly-plugin</artifactId>
<version>3.1.0</version>
<executions>
<execution>
<id>jar-with-dependencies</id>
<phase>package</phase>
<goals>
<goal>single</goal>
</goals>
<configuration>
<archive>
<manifest>
<mainClass>de.monticore.lang.monticar.cnnarch.generator.CNNArchGeneratorCli</mainClass>
</manifest>
</archive>
<descriptorRefs>
<descriptorRef>jar-with-dependencies</descriptorRef>
</descriptorRefs>
</configuration>
</execution>
</executions>
</plugin>
<!-- Source Jar Configuration -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-source-plugin</artifactId>
<version>${source.plugin}</version>
<executions>
<execution>
<id>create source jar</id>
<phase>package</phase>
<goals>
<goal>jar-no-fork</goal>
</goals>
<configuration>
<excludeResources>false</excludeResources>
<includes>
<include>**/*.java</include>
<include>**/*.ftl</include>
</includes>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<version>2.19.1</version>
<configuration>
</configuration>
</plugin>
<plugin>
<groupId>org.eluder.coveralls</groupId>
<artifactId>coveralls-maven-plugin</artifactId>
<version>4.3.0</version>
<configuration>
</configuration>
</plugin>
<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>cobertura-maven-plugin</artifactId>
<version>2.7</version>
<configuration>
<format>xml</format>
<maxmem>256m</maxmem>
<!-- aggregated reports for multi-module projects -->
<aggregate>true</aggregate>
</configuration>
</plugin>
</plugins>
</build>
<distributionManagement>
<repository>
<id>se-nexus</id>
<url>https://nexus.se.rwth-aachen.de/content/repositories/embeddedmontiarc-releases/</url>
</repository>
<snapshotRepository>
<id>se-nexus</id>
<url>https://nexus.se.rwth-aachen.de/content/repositories/embeddedmontiarc-snapshots/</url>
</snapshotRepository>
</distributionManagement>
</project>
<?xml version="1.0" encoding="UTF-8"?>
<!--
******************************************************************************
MontiCAR Modeling Family, www.se-rwth.de
Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
All rights reserved.
This project is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 3.0 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this project. If not, see <http://www.gnu.org/licenses/>.
*******************************************************************************
-->
<settings xmlns="http://maven.apache.org/SETTINGS/1.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/SETTINGS/1.0.0 http://maven.apache.org/xsd/settings-1.0.0.xsd">
<pluginGroups>
<pluginGroup>org.mortbay.jetty</pluginGroup>
<pluginGroup>de.topobyte</pluginGroup>
</pluginGroups>
<proxies>
</proxies>
<servers>
<server>
<id>se-nexus</id>
<username>cibuild</username>
<password>${env.cibuild}</password>
</server>
<server>
<id>github</id>
<username>travisbuilduser</username>
<password>${env.travisbuilduserpassword}</password>
</server>
</servers>
<mirrors>
<mirror>
<id>se-nexus</id>
<mirrorOf>external:*</mirrorOf>
<url>https://nexus.se.rwth-aachen.de/content/groups/public</url>
</mirror>
</mirrors>
<profiles>
<profile>
<id>se-nexus</id>
<repositories>
<repository>
<id>central</id>
<url>http://central</url>
<releases><enabled /></releases>
<snapshots><enabled /></snapshots>
</repository>
</repositories>
<pluginRepositories>
<pluginRepository>
<id>central</id>
<url>http://central</url>
<releases><enabled /></releases>
<snapshots><enabled /></snapshots>
</pluginRepository>
</pluginRepositories>
</profile>
</profiles>
<activeProfiles>
<activeProfile>se-nexus</activeProfile>
</activeProfiles>
</settings>
\ No newline at end of file
******************************************************************************
MontiCAR Modeling Family, www.se-rwth.de
Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
All rights reserved.
This project is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 3.0 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this project. If not, see <http://www.gnu.org/licenses/>.
*******************************************************************************
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch.generator;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchTypeSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol;
import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers;
import de.se_rwth.commons.logging.Log;
import javax.annotation.Nullable;
import java.util.Arrays;
import java.util.List;
public class ArchitectureElementData {
private String name;
private ArchitectureElementSymbol element;
private CNNArchTemplateController templateController;
public ArchitectureElementData(String name, ArchitectureElementSymbol element, CNNArchTemplateController templateController) {
this.name = name;
this.element = element;
this.templateController = templateController;
}
public String getName() {
return name;
}
public void setName(String name) {
this.name = name;
}
public ArchitectureElementSymbol getElement() {
return element;
}
public void setElement(ArchitectureElementSymbol element) {
this.element = element;
}
public CNNArchTemplateController getTemplateController() {
return templateController;
}
public void setTemplateController(CNNArchTemplateController templateController) {
this.templateController = templateController;
}
public List<String> getInputs(){
return getTemplateController().getLayerInputs(getElement());
}
public boolean isLogisticRegressionOutput(){
return getTemplateController().isLogisticRegressionOutput(getElement());
}
public boolean isLinearRegressionOutput(){
boolean result = getTemplateController().isLinearRegressionOutput(getElement());
if (result){
Log.warn("The Output '" + getElement().getName() + "' is a linear regression output (squared loss) during training" +
" because the previous architecture element is not a softmax (cross-entropy loss) or sigmoid (logistic regression loss) activation. " +
"Other loss functions are currently not supported. "
, getElement().getSourcePosition());
}
return result;
}
public boolean isSoftmaxOutput(){
return getTemplateController().isSoftmaxOutput(getElement());
}
public List<Integer> getKernel(){
return ((LayerSymbol) getElement())
.getIntTupleValue(AllPredefinedLayers.KERNEL_NAME).get();
}
public int getChannels(){
return ((LayerSymbol) getElement())
.getIntValue(AllPredefinedLayers.CHANNELS_NAME).get();
}
public List<Integer> getStride(){
return ((LayerSymbol) getElement())
.getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get();
}
public int getUnits(){
return ((LayerSymbol) getElement())
.getIntValue(AllPredefinedLayers.UNITS_NAME).get();
}
public boolean getNoBias(){
return ((LayerSymbol) getElement())
.getBooleanValue(AllPredefinedLayers.NOBIAS_NAME).get();
}
public double getP(){
return ((LayerSymbol) getElement())
.getDoubleValue(AllPredefinedLayers.P_NAME).get();
}
public int getIndex(){
return ((LayerSymbol) getElement())
.getIntValue(AllPredefinedLayers.INDEX_NAME).get();
}
public int getNumOutputs(){
return ((LayerSymbol) getElement())
.getIntValue(AllPredefinedLayers.NUM_SPLITS_NAME).get();
}
public boolean getFixGamma(){
return ((LayerSymbol) getElement())
.getBooleanValue(AllPredefinedLayers.FIX_GAMMA_NAME).get();
}
public int getNsize(){
return ((LayerSymbol) getElement())
.getIntValue(AllPredefinedLayers.NSIZE_NAME).get();
}
public double getKnorm(){
return ((LayerSymbol) getElement())
.getDoubleValue(AllPredefinedLayers.KNORM_NAME).get();
}
public double getAlpha(){
return ((LayerSymbol) getElement())
.getDoubleValue(AllPredefinedLayers.ALPHA_NAME).get();
}
public double getBeta(){
return ((LayerSymbol) getElement())
.getDoubleValue(AllPredefinedLayers.BETA_NAME).get();
}
@Nullable
public String getPoolType(){
return ((LayerSymbol) getElement())
.getStringValue(AllPredefinedLayers.POOL_TYPE_NAME).get();
}
@Nullable
public List<Integer> getPadding(){
return getPadding((LayerSymbol) getElement());
}
@Nullable
public List<Integer> getPadding(LayerSymbol layer){
List<Integer> kernel = layer.getIntTupleValue(AllPredefinedLayers.KERNEL_NAME).get();
List<Integer> stride = layer.getIntTupleValue(AllPredefinedLayers.STRIDE_NAME).get();
ArchTypeSymbol inputType = layer.getInputTypes().get(0);
ArchTypeSymbol outputType = layer.getOutputTypes().get(0);
int heightWithPad = kernel.get(0) + stride.get(0)*(outputType.getHeight() - 1);
int widthWithPad = kernel.get(1) + stride.get(1)*(outputType.getWidth() - 1);
int heightPad = Math.max(0, heightWithPad - inputType.getHeight());
int widthPad = Math.max(0, widthWithPad - inputType.getWidth());
int topPad = (int)Math.ceil(heightPad / 2.0);
int bottomPad = (int)Math.floor(heightPad / 2.0);
int leftPad = (int)Math.ceil(widthPad / 2.0);
int rightPad = (int)Math.floor(widthPad / 2.0);
if (topPad == 0 && bottomPad == 0 && leftPad == 0 && rightPad == 0){
return null;
}
return Arrays.asList(0,0,0,0,topPad,bottomPad,leftPad,rightPad);
}
}
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch.generator;
import de.monticore.io.paths.ModelPath;
import de.monticore.lang.monticar.cnnarch._cocos.CNNArchCocos;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchCompilationUnitSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchLanguage;
import de.monticore.symboltable.GlobalScope;
import de.monticore.symboltable.Scope;
import de.se_rwth.commons.logging.Log;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
public class CNNArchGenerator {
private String generationTargetPath;
public CNNArchGenerator() {
setGenerationTargetPath("./target/generated-sources-cnnarch/");
}
public String getGenerationTargetPath() {
if (generationTargetPath.charAt(generationTargetPath.length() - 1) != '/') {
this.generationTargetPath = generationTargetPath + "/";
}
return generationTargetPath;
}
public void setGenerationTargetPath(String generationTargetPath) {
this.generationTargetPath = generationTargetPath;
}
public void generate(Path modelsDirPath, String rootModelName){
final ModelPath mp = new ModelPath(modelsDirPath);
GlobalScope scope = new GlobalScope(mp, new CNNArchLanguage());
generate(scope, rootModelName);
}
public void generate(Scope scope, String rootModelName){
Optional<CNNArchCompilationUnitSymbol> compilationUnit = scope.resolve(rootModelName, CNNArchCompilationUnitSymbol.KIND);
if (!compilationUnit.isPresent()){
Log.error("could not resolve architecture " + rootModelName);
System.exit(1);
}
CNNArchCocos.checkAll(compilationUnit.get());
try{
generateFiles(compilationUnit.get().getArchitecture());
}
catch (IOException e){
Log.error(e.toString());
}
}
//check cocos with CNNArchCocos.checkAll(architecture) before calling this method.
public Map<String, String> generateStrings(ArchitectureSymbol architecture){
Map<String, String> fileContentMap = new HashMap<>();
CNNArchTemplateController archTc = new CNNArchTemplateController(architecture);
Map.Entry<String, String> temp;
temp = archTc.process("CNNPredictor", Target.CPP);
fileContentMap.put(temp.getKey(), temp.getValue());
temp = archTc.process("CNNCreator", Target.PYTHON);
fileContentMap.put(temp.getKey(), temp.getValue());
temp = archTc.process("execute", Target.CPP);
fileContentMap.put(temp.getKey().replace(".h", ""), temp.getValue());
temp = archTc.process("CNNBufferFile", Target.CPP);
fileContentMap.put("CNNBufferFile.h", temp.getValue());
checkValidGeneration(architecture);
return fileContentMap;
}
private void checkValidGeneration(ArchitectureSymbol architecture){
if (architecture.getInputs().size() > 1){
Log.warn("This cnn architecture has multiple inputs, " +
"which is currently not supported by the generator. " +
"The generated code will not work correctly."
, architecture.getSourcePosition());
}
if (architecture.getOutputs().size() > 1){
Log.warn("This cnn architecture has multiple outputs, " +
"which is currently not supported by the generator. " +
"The generated code will not work correctly."
, architecture.getSourcePosition());
}
if (architecture.getOutputs().get(0).getDefinition().getType().getWidth() != 1 ||
architecture.getOutputs().get(0).getDefinition().getType().getHeight() != 1){
Log.error("This cnn architecture has a multi-dimensional output, " +
"which is currently not supported by the generator."
, architecture.getSourcePosition());
}
}
//check cocos with CNNArchCocos.checkAll(architecture) before calling this method.
public void generateFiles(ArchitectureSymbol architecture) throws IOException{
CNNArchTemplateController archTc = new CNNArchTemplateController(architecture);
Map<String, String> fileContentMap = generateStrings(architecture);
for (String fileName : fileContentMap.keySet()){
File f = new File(getGenerationTargetPath() + fileName);
Log.info(f.getName(), "FileCreation:");
if (!f.exists()) {
f.getParentFile().mkdirs();
if (!f.createNewFile()) {
Log.error("File could not be created");
}
}
FileWriter writer = new FileWriter(f);
writer.write(fileContentMap.get(fileName));
writer.close();
}
}
}
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch.generator;
import org.apache.commons.cli.*;
import java.nio.file.Path;
import java.nio.file.Paths;
public class CNNArchGeneratorCli {
public static final Option OPTION_MODELS_PATH = Option.builder("m")
.longOpt("models-dir")
.desc("full path to the directory with the CNNArch model")
.hasArg(true)
.required(true)
.build();
public static final Option OPTION_ROOT_MODEL = Option.builder("r")
.longOpt("root-model")
.desc("name of the architecture")
.hasArg(true)
.required(true)
.build();
public static final Option OPTION_OUTPUT_PATH = Option.builder("o")
.longOpt("output-dir")
.desc("full path to output directory for tests")
.hasArg(true)
.required(false)
.build();
private CNNArchGeneratorCli() {
}
public static void main(String[] args) {
Options options = getOptions();
CommandLineParser parser = new DefaultParser();
CommandLine cliArgs = parseArgs(options, parser, args);
if (cliArgs != null) {
runGenerator(cliArgs);
}
}
private static Options getOptions() {
Options options = new Options();
options.addOption(OPTION_MODELS_PATH);
options.addOption(OPTION_ROOT_MODEL);
options.addOption(OPTION_OUTPUT_PATH);
return options;
}
private static CommandLine parseArgs(Options options, CommandLineParser parser, String[] args) {
CommandLine cliArgs;
try {
cliArgs = parser.parse(options, args);
} catch (ParseException e) {
System.err.println("argument parsing exception: " + e.getMessage());
System.exit(1);
return null;
}
return cliArgs;
}
private static void runGenerator(CommandLine cliArgs) {
Path modelsDirPath = Paths.get(cliArgs.getOptionValue(OPTION_MODELS_PATH.getOpt()));
String rootModelName = cliArgs.getOptionValue(OPTION_ROOT_MODEL.getOpt());
String outputPath = cliArgs.getOptionValue(OPTION_OUTPUT_PATH.getOpt());
CNNArchGenerator generator = new CNNArchGenerator();
if (outputPath != null){
generator.setGenerationTargetPath(outputPath);
}
generator.generate(modelsDirPath, rootModelName);
}
}
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch.generator;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.predefined.Sigmoid;
import de.monticore.lang.monticar.cnnarch.predefined.Softmax;
import de.se_rwth.commons.logging.Log;
import freemarker.template.Configuration;
import freemarker.template.Template;
import freemarker.template.TemplateException;
import java.io.IOException;
import java.io.StringWriter;
import java.io.Writer;
import java.util.*;
public class CNNArchTemplateController {
public static final String FTL_FILE_ENDING = ".ftl";
public static final String TEMPLATE_ELEMENTS_DIR_PATH = "elements/";
public static final String TEMPLATE_CONTROLLER_KEY = "tc";
public static final String ELEMENT_DATA_KEY = "element";
private LayerNameCreator nameManager;
private Configuration freemarkerConfig = TemplateConfiguration.get();
private ArchitectureSymbol architecture;
private Writer writer;
private String mainTemplateNameWithoutEnding;
private Target targetLanguage;
private ArchitectureElementData dataElement;
public CNNArchTemplateController(ArchitectureSymbol architecture) {
setArchitecture(architecture);
}
public String getFileNameWithoutEnding() {
return mainTemplateNameWithoutEnding + "_" + getFullArchitectureName();
}
public Target getTargetLanguage(){
return targetLanguage;
}
public void setTargetLanguage(Target targetLanguage) {
this.targetLanguage = targetLanguage;
}
public ArchitectureElementData getCurrentElement() {
return dataElement;
}
public void setCurrentElement(ArchitectureElementSymbol layer) {
this.dataElement = new ArchitectureElementData(getName(layer), layer, this);
}
public void setCurrentElement(ArchitectureElementData dataElement) {
this.dataElement = dataElement;
}
public ArchitectureSymbol getArchitecture() {
return architecture;
}
public void setArchitecture(ArchitectureSymbol architecture) {
this.architecture = architecture;
this.nameManager = new LayerNameCreator(architecture);
}
public String getName(ArchitectureElementSymbol layer){
return nameManager.getName(layer);
}
public String getArchitectureName(){
return getArchitecture().getEnclosingScope().getSpanningSymbol().get().getName().replaceAll("\\.","_");
}
public String getFullArchitectureName(){
return getArchitecture().getEnclosingScope().getSpanningSymbol().get().getFullName().replaceAll("\\.","_");
}
public List<String> getLayerInputs(ArchitectureElementSymbol layer){
List<String> inputNames = new ArrayList<>();
if (isSoftmaxOutput(layer) || isLogisticRegressionOutput(layer)){
inputNames = getLayerInputs(layer.getInputElement().get());
}
else {
for (ArchitectureElementSymbol input : layer.getPrevious()) {
if (input.getOutputTypes().size() == 1) {
inputNames.add(getName(input));
} else {
for (int i = 0; i < input.getOutputTypes().size(); i++) {
inputNames.add(getName(input) + "[" + i + "]");
}
}
}
}
return inputNames;
}
public List<String> getArchitectureInputs(){
List<String> list = new ArrayList<>();
for (IOSymbol ioElement : getArchitecture().getInputs()){
list.add(nameManager.getName(ioElement));
}
return list;
}
public List<String> getArchitectureOutputs(){
List<String> list = new ArrayList<>();
for (IOSymbol ioElement : getArchitecture().getOutputs()){
list.add(nameManager.getName(ioElement));
}
return list;
}
public void include(String relativePath, String templateWithoutFileEnding, Writer writer){
String templatePath = relativePath + templateWithoutFileEnding + FTL_FILE_ENDING;
try {
Template template = freemarkerConfig.getTemplate(templatePath);
Map<String, Object> ftlContext = new HashMap<>();
ftlContext.put(TEMPLATE_CONTROLLER_KEY, this);
ftlContext.put(ELEMENT_DATA_KEY, getCurrentElement());
this.writer = writer;
template.process(ftlContext, writer);
this.writer = null;
}
catch (IOException e) {
Log.error("Freemarker could not find template " + templatePath + " :\n" + e.getMessage());
System.exit(1);
}
catch (TemplateException e){
Log.error("An exception occured in template " + templatePath + " :\n" + e.getMessage());
System.exit(1);
}
}
public void include(IOSymbol ioElement, Writer writer){
ArchitectureElementData previousElement = getCurrentElement();
setCurrentElement(ioElement);
if (ioElement.isAtomic()){
if (ioElement.isInput()){
include(TEMPLATE_ELEMENTS_DIR_PATH, "Input", writer);
}
else {
include(TEMPLATE_ELEMENTS_DIR_PATH, "Output", writer);
}
}
else {
include(ioElement.getResolvedThis().get(), writer);
}
setCurrentElement(previousElement);
}
public void include(LayerSymbol layer, Writer writer){
ArchitectureElementData previousElement = getCurrentElement();
setCurrentElement(layer);
if (layer.isAtomic()){
ArchitectureElementSymbol nextElement = layer.getOutputElement().get();
if (!isSoftmaxOutput(nextElement) && !isLogisticRegressionOutput(nextElement)){
String templateName = layer.getDeclaration().getName();
include(TEMPLATE_ELEMENTS_DIR_PATH, templateName, writer);
}
}
else {
include(layer.getResolvedThis().get(), writer);
}
setCurrentElement(previousElement);
}
public void include(CompositeElementSymbol compositeElement, Writer writer){
ArchitectureElementData previousElement = getCurrentElement();
setCurrentElement(compositeElement);
for (ArchitectureElementSymbol element : compositeElement.getElements()){
include(element, writer);
}
setCurrentElement(previousElement);
}
public void include(ArchitectureElementSymbol architectureElement, Writer writer){
if (architectureElement instanceof CompositeElementSymbol){
include((CompositeElementSymbol) architectureElement, writer);
}
else if (architectureElement instanceof LayerSymbol){
include((LayerSymbol) architectureElement, writer);
}
else {
include((IOSymbol) architectureElement, writer);
}
}
public void include(ArchitectureElementSymbol architectureElement){
if (writer == null){
throw new IllegalStateException("missing writer");
}
include(architectureElement, writer);
}
public Map.Entry<String,String> process(String templateNameWithoutEnding, Target targetLanguage){
StringWriter writer = new StringWriter();
this.mainTemplateNameWithoutEnding = templateNameWithoutEnding;
this.targetLanguage = targetLanguage;
include("", templateNameWithoutEnding, writer);
String fileEnding = targetLanguage.toString();
if (targetLanguage == Target.CPP){
fileEnding = ".h";
}
String fileName = getFileNameWithoutEnding() + fileEnding;
Map.Entry<String,String> fileContent = new AbstractMap.SimpleEntry<>(fileName, writer.toString());
this.mainTemplateNameWithoutEnding = null;
this.targetLanguage = null;
return fileContent;
}
public String join(Iterable iterable, String separator){
return join(iterable, separator, "", "");
}
public String join(Iterable iterable, String separator, String elementPrefix, String elementPostfix){
StringBuilder stringBuilder = new StringBuilder();
boolean isFirst = true;
for (Object element : iterable){
if (!isFirst){
stringBuilder.append(separator);
}
stringBuilder.append(elementPrefix);
stringBuilder.append(element.toString());
stringBuilder.append(elementPostfix);
isFirst = false;
}
return stringBuilder.toString();
}
public boolean isLogisticRegressionOutput(ArchitectureElementSymbol architectureElement){
return isTOutput(Sigmoid.class, architectureElement);
}
public boolean isLinearRegressionOutput(ArchitectureElementSymbol architectureElement){
return architectureElement.isOutput()
&& !isLogisticRegressionOutput(architectureElement)
&& !isSoftmaxOutput(architectureElement);
}
public boolean isSoftmaxOutput(ArchitectureElementSymbol architectureElement){
return isTOutput(Softmax.class, architectureElement);
}
private boolean isTOutput(Class inputPredefinedLayerClass, ArchitectureElementSymbol architectureElement){
if (architectureElement.isOutput()){
if (architectureElement.getInputElement().isPresent() && architectureElement.getInputElement().get() instanceof LayerSymbol){
LayerSymbol inputLayer = (LayerSymbol) architectureElement.getInputElement().get();
if (inputPredefinedLayerClass.isInstance(inputLayer.getDeclaration())){
return true;
}
}
}
return false;
}
}
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch.generator;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.predefined.Convolution;
import de.monticore.lang.monticar.cnnarch.predefined.FullyConnected;
import de.monticore.lang.monticar.cnnarch.predefined.Pooling;
import java.util.*;
public class LayerNameCreator {
private Map<ArchitectureElementSymbol, String> elementToName = new HashMap<>();
private Map<String, ArchitectureElementSymbol> nameToElement = new HashMap<>();
public LayerNameCreator(ArchitectureSymbol architecture) {
name(architecture.getBody(), 1, new ArrayList<>());
}
public ArchitectureElementSymbol getArchitectureElement(String name){
return nameToElement.get(name);
}
public String getName(ArchitectureElementSymbol architectureElement){
return elementToName.get(architectureElement);
}
protected int name(ArchitectureElementSymbol architectureElement, int stage, List<Integer> streamIndices){
if (architectureElement instanceof CompositeElementSymbol){
return nameComposite((CompositeElementSymbol) architectureElement, stage, streamIndices);
}
else{
if (architectureElement.isAtomic()){
if (architectureElement.getMaxSerialLength().get() > 0){
return add(architectureElement, stage, streamIndices);
}
else {
return stage;
}
}
else {
ArchitectureElementSymbol resolvedElement = architectureElement.getResolvedThis().get();
return name(resolvedElement, stage, streamIndices);
}
}
}
protected int nameComposite(CompositeElementSymbol compositeElement, int stage, List<Integer> streamIndices){
if (compositeElement.isParallel()){
int startStage = stage + 1;
streamIndices.add(1);
int lastIndex = streamIndices.size() - 1;
List<Integer> endStages = new ArrayList<>();
for (ArchitectureElementSymbol subElement : compositeElement.getElements()){
endStages.add(name(subElement, startStage, streamIndices));
streamIndices.set(lastIndex, streamIndices.get(lastIndex) + 1);
}
streamIndices.remove(lastIndex);
return Collections.max(endStages) + 1;
}
else {
int endStage = stage;
for (ArchitectureElementSymbol subElement : compositeElement.getElements()){
endStage = name(subElement, endStage, streamIndices);
}
return endStage;
}
}
protected int add(ArchitectureElementSymbol architectureElement, int stage, List<Integer> streamIndices){
int endStage = stage;
if (!elementToName.containsKey(architectureElement)) {
String name = createName(architectureElement, endStage, streamIndices);
while (nameToElement.containsKey(name)) {
endStage++;
name = createName(architectureElement, endStage, streamIndices);
}
elementToName.put(architectureElement, name);
nameToElement.put(name, architectureElement);
}
return endStage;
}
protected String createName(ArchitectureElementSymbol architectureElement, int stage, List<Integer> streamIndices){
if (architectureElement instanceof IOSymbol){
String name = createBaseName(architectureElement);
IOSymbol ioElement = (IOSymbol) architectureElement;
if (ioElement.getArrayAccess().isPresent()){
int arrayAccess = ioElement.getArrayAccess().get().getIntValue().get();
name = name + "_" + arrayAccess + "_";
}
return name;
}
else {
return createBaseName(architectureElement) + stage + createStreamPostfix(streamIndices) + "_";
}
}
protected String createBaseName(ArchitectureElementSymbol architectureElement){
if (architectureElement instanceof LayerSymbol) {
LayerDeclarationSymbol layerDeclaration = ((LayerSymbol) architectureElement).getDeclaration();
if (layerDeclaration instanceof Convolution) {
return "conv";
} else if (layerDeclaration instanceof FullyConnected) {
return "fc";
} else if (layerDeclaration instanceof Pooling) {
return "pool";
} else {
return layerDeclaration.getName().toLowerCase();
}
}
else if (architectureElement instanceof CompositeElementSymbol){
return "group";
}
else {
return architectureElement.getName();
}
}
protected String createStreamPostfix(List<Integer> streamIndices){
StringBuilder stringBuilder = new StringBuilder();
for (int streamIndex : streamIndices){
stringBuilder.append("_");
stringBuilder.append(streamIndex);
}
return stringBuilder.toString();
}
}
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch.generator;
//can be removed
public enum Target {
PYTHON{
@Override
public String toString() {
return ".py";
}
},
CPP{
@Override
public String toString() {
return ".cpp";
}
};
public static Target fromString(String target){
switch (target.toLowerCase()){
case "python":
return PYTHON;
case "py":
return PYTHON;
case "cpp":
return CPP;
case "c++":
return CPP;
default:
throw new IllegalArgumentException();
}
}
}
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch.generator;
import freemarker.template.Configuration;
import freemarker.template.TemplateExceptionHandler;
public class TemplateConfiguration {
private static TemplateConfiguration instance;
private Configuration configuration;
private TemplateConfiguration() {
configuration = new Configuration(Configuration.VERSION_2_3_23);
configuration.setClassForTemplateLoading(TemplateConfiguration.class, "/templates/");
configuration.setDefaultEncoding("UTF-8");
configuration.setTemplateExceptionHandler(TemplateExceptionHandler.RETHROW_HANDLER);
}
public Configuration getConfiguration() {
return configuration;
}
public static Configuration get(){
if (instance == null){
instance = new TemplateConfiguration();
}
return instance.getConfiguration();
}
}
#ifndef CNNBUFFERFILE_H
#define CNNBUFFERFILE_H
#include <stdio.h>
#include <iostream>
#include <fstream>
// Read file to buffer
class BufferFile {
public :
std::string file_path_;
int length_;
char* buffer_;
explicit BufferFile(std::string file_path)
:file_path_(file_path) {
std::ifstream ifs(file_path.c_str(), std::ios::in | std::ios::binary);
if (!ifs) {
std::cerr << "Can't open the file. Please check " << file_path << ". \n";
length_ = 0;
buffer_ = NULL;
return;
}
ifs.seekg(0, std::ios::end);
length_ = ifs.tellg();
ifs.seekg(0, std::ios::beg);
std::cout << file_path.c_str() << " ... "<< length_ << " bytes\n";
buffer_ = new char[sizeof(char) * length_];
ifs.read(buffer_, length_);
ifs.close();
}
int GetLength() {
return length_;
}
char* GetBuffer() {
return buffer_;
}
~BufferFile() {
if (buffer_) {
delete[] buffer_;
buffer_ = NULL;
}
}
};
#endif // CNNBUFFERFILE_H
import mxnet as mx
import logging
import os
import errno
import shutil
import h5py
import sys
import numpy as np
@mx.init.register
class MyConstant(mx.init.Initializer):
def __init__(self, value):
super(MyConstant, self).__init__(value=value)
self.value = value
def _init_weight(self, _, arr):
arr[:] = mx.nd.array(self.value)
class ${tc.fileNameWithoutEnding}:
module = None
_data_dir_ = "data/${tc.fullArchitectureName}/"
_model_dir_ = "model/${tc.fullArchitectureName}/"
_model_prefix_ = "${tc.architectureName}"
_input_names_ = [${tc.join(tc.architectureInputs, ",", "'", "'")}]
_input_shapes_ = [<#list tc.architecture.inputs as input>(${tc.join(input.definition.type.dimensions, ",")})</#list>]
_output_names_ = [${tc.join(tc.architectureOutputs, ",", "'", "_label'")}]
def load(self, context):
lastEpoch = 0
param_file = None
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_newest-0000.params")
except OSError:
pass
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_newest-symbol.json")
except OSError:
pass
if os.path.isdir(self._model_dir_):
for file in os.listdir(self._model_dir_):
if ".params" in file and self._model_prefix_ in file:
epochStr = file.replace(".params","").replace(self._model_prefix_ + "-","")
epoch = int(epochStr)
if epoch > lastEpoch:
lastEpoch = epoch
param_file = file
if param_file is None:
return 0
else:
logging.info("Loading checkpoint: " + param_file)
self.module.load(prefix=self._model_dir_ + self._model_prefix_,
epoch=lastEpoch,
data_names=self._input_names_,
label_names=self._output_names_,
context=context)
return lastEpoch
def load_data(self, batch_size):
train_h5, test_h5 = self.load_h5_files()
data_mean = train_h5[self._input_names_[0]][:].mean(axis=0)
data_std = train_h5[self._input_names_[0]][:].std(axis=0) + 1e-5
train_iter = mx.io.NDArrayIter(train_h5[self._input_names_[0]],
train_h5[self._output_names_[0]],
batch_size=batch_size,
data_name=self._input_names_[0],
label_name=self._output_names_[0])
test_iter = None
if test_h5 != None:
test_iter = mx.io.NDArrayIter(test_h5[self._input_names_[0]],
test_h5[self._output_names_[0]],
batch_size=batch_size,
data_name=self._input_names_[0],
label_name=self._output_names_[0])
return train_iter, test_iter, data_mean, data_std
def load_h5_files(self):
train_h5 = None
test_h5 = None
train_path = self._data_dir_ + "train.h5"
test_path = self._data_dir_ + "test.h5"
if os.path.isfile(train_path):
train_h5 = h5py.File(train_path, 'r')
if not (self._input_names_[0] in train_h5 and self._output_names_[0] in train_h5):
logging.error("The HDF5 file '" + os.path.abspath(train_path) + "' has to contain the datasets: "
+ "'" + self._input_names_[0] + "', '" + self._output_names_[0] + "'")
sys.exit(1)
test_iter = None
if os.path.isfile(test_path):
test_h5 = h5py.File(test_path, 'r')
if not (self._input_names_[0] in test_h5 and self._output_names_[0] in test_h5):
logging.error("The HDF5 file '" + os.path.abspath(test_path) + "' has to contain the datasets: "
+ "'" + self._input_names_[0] + "', '" + self._output_names_[0] + "'")
sys.exit(1)
else:
logging.warning("Couldn't load test set. File '" + os.path.abspath(test_path) + "' does not exist.")
return train_h5, test_h5
else:
logging.error("Data loading failure. File '" + os.path.abspath(train_path) + "' does not exist.")
sys.exit(1)
def train(self, batch_size,
num_epoch=10,
optimizer='adam',
optimizer_params=(('learning_rate', 0.001),),
load_checkpoint=True,
context='gpu',
checkpoint_period=5,
normalize=True):
if context == 'gpu':
mx_context = mx.gpu()
elif context == 'cpu':
mx_context = mx.cpu()
else:
logging.error("Context argument is '" + context + "'. Only 'cpu' and 'gpu are valid arguments'.")
if 'weight_decay' in optimizer_params:
optimizer_params['wd'] = optimizer_params['weight_decay']
del optimizer_params['weight_decay']
if 'learning_rate_decay' in optimizer_params:
min_learning_rate = 1e-08
if 'learning_rate_minimum' in optimizer_params:
min_learning_rate = optimizer_params['learning_rate_minimum']
del optimizer_params['learning_rate_minimum']
optimizer_params['lr_scheduler'] = mx.lr_scheduler.FactorScheduler(
optimizer_params['step_size'],
factor=optimizer_params['learning_rate_decay'],
stop_factor_lr=min_learning_rate)
del optimizer_params['step_size']
del optimizer_params['learning_rate_decay']
train_iter, test_iter, data_mean, data_std = self.load_data(batch_size)
if self.module == None:
if normalize:
self.construct(mx_context, data_mean, data_std)
else:
self.construct(mx_context)
begin_epoch = 0
if load_checkpoint:
begin_epoch = self.load(mx_context)
else:
if os.path.isdir(self._model_dir_):
shutil.rmtree(self._model_dir_)
try:
os.makedirs(self._model_dir_)
except OSError:
if not os.path.isdir(self._model_dir_):
raise
self.module.fit(
train_data=train_iter,
eval_data=test_iter,
optimizer=optimizer,
optimizer_params=optimizer_params,
batch_end_callback=mx.callback.Speedometer(batch_size),
epoch_end_callback=mx.callback.do_checkpoint(prefix=self._model_dir_ + self._model_prefix_, period=checkpoint_period),
begin_epoch=begin_epoch,
num_epoch=num_epoch + begin_epoch)
self.module.save_checkpoint(self._model_dir_ + self._model_prefix_, num_epoch + begin_epoch)
self.module.save_checkpoint(self._model_dir_ + self._model_prefix_ + '_newest', 0)
def construct(self, context, data_mean=None, data_std=None):
${tc.include(tc.architecture.body)}
self.module = mx.mod.Module(symbol=mx.symbol.Group([${tc.join(tc.architectureOutputs, ",")}]),
data_names=self._input_names_,
label_names=self._output_names_,
context=context)
#ifndef ${tc.fileNameWithoutEnding?upper_case}
#define ${tc.fileNameWithoutEnding?upper_case}
#include <mxnet/c_predict_api.h>
#include <cassert>
#include <string>
#include <vector>
#include <CNNBufferFile.h>
class ${tc.fileNameWithoutEnding}{
public:
const std::string json_file = "model/${tc.fullArchitectureName}/${tc.architectureName}_newest-symbol.json";
const std::string param_file = "model/${tc.fullArchitectureName}/${tc.architectureName}_newest-0000.params";
const std::vector<std::string> input_keys = {"data"};
//const std::vector<std::string> input_keys = {${tc.join(tc.architectureInputs, ",", "\"", "\"")}};
const std::vector<std::vector<mx_uint>> input_shapes = {<#list tc.architecture.inputs as input>{1,${tc.join(input.definition.type.dimensions, ",")}}<#if input?has_next>,</#if></#list>};
const bool use_gpu = false;
PredictorHandle handle;
explicit ${tc.fileNameWithoutEnding}(){
init(json_file, param_file, input_keys, input_shapes, use_gpu);
}
~${tc.fileNameWithoutEnding}(){
if(handle) MXPredFree(handle);
}
void predict(${tc.join(tc.architectureInputs, ", ", "const vector<float> &", "")},
${tc.join(tc.architectureOutputs, ", ", "vector<float> &", "")}){
<#list tc.architectureInputs as inputName>
MXPredSetInput(handle, "data", ${inputName}.data(), ${inputName}.size());
//MXPredSetInput(handle, "${inputName}", ${inputName}.data(), ${inputName}.size());
</#list>
MXPredForward(handle);
mx_uint output_index;
mx_uint *shape = 0;
mx_uint shape_len;
size_t size;
<#list tc.architectureOutputs as outputName>
output_index = ${outputName?index?c};
MXPredGetOutputShape(handle, output_index, &shape, &shape_len);
size = 1;
for (mx_uint i = 0; i < shape_len; ++i) size *= shape[i];
assert(size == ${outputName}.size());
MXPredGetOutput(handle, ${outputName?index?c}, &(${outputName}[0]), ${outputName}.size());
</#list>
}
void init(const std::string &json_file,
const std::string &param_file,
const std::vector<std::string> &input_keys,
const std::vector<std::vector<mx_uint>> &input_shapes,
const bool &use_gpu){
BufferFile json_data(json_file);
BufferFile param_data(param_file);
int dev_type = use_gpu ? 2 : 1;
int dev_id = 0;
handle = 0;
if (json_data.GetLength() == 0 ||
param_data.GetLength() == 0) {
std::exit(-1);
}
const mx_uint num_input_nodes = input_keys.size();
const char* input_keys_ptr[num_input_nodes];
for(mx_uint i = 0; i < num_input_nodes; i++){
input_keys_ptr[i] = input_keys[i].c_str();
}
mx_uint shape_data_size = 0;
mx_uint input_shape_indptr[input_shapes.size() + 1];
input_shape_indptr[0] = 0;
for(mx_uint i = 0; i < input_shapes.size(); i++){
input_shape_indptr[i+1] = input_shapes[i].size();
shape_data_size += input_shapes[i].size();
}
mx_uint input_shape_data[shape_data_size];
mx_uint index = 0;
for(mx_uint i = 0; i < input_shapes.size(); i++){
for(mx_uint j = 0; j < input_shapes[i].size(); j++){
input_shape_data[index] = input_shapes[i][j];
index++;
}
}
MXPredCreate((const char*)json_data.GetBuffer(),
(const char*)param_data.GetBuffer(),
static_cast<size_t>(param_data.GetLength()),
dev_type,
dev_id,
num_input_nodes,
input_keys_ptr,
input_shape_indptr,
input_shape_data,
&handle);
assert(handle);
}
};
#endif // ${tc.fileNameWithoutEnding?upper_case}
${element.name} = ${tc.join(element.inputs, " + ")}
<#include "OutputShape.ftl">
\ No newline at end of file
${element.name} = mx.symbol.BatchNorm(data=${element.inputs[0]},
fix_gamma=${element.fixGamma?string("True","False")},
name="${element.name}")
${element.name} = mx.symbol.concat(${tc.join(element.inputs, ", ")},
dim=1,
name="${element.name}")
<#include "OutputShape.ftl">
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment