Commit 99234812 authored by Thomas Michael Timmermanns's avatar Thomas Michael Timmermanns Committed by Thomas Michael Timmermanns

Added Generator and generation tests.

parent f6a5e5a3
script:
- git checkout ${TRAVIS_BRANCH}
- mvn clean install cobertura:cobertura org.eluder.coveralls:coveralls-maven-plugin:report --settings "settings.xml"
after_success:
- if [ "${TRAVIS_BRANCH}" == "master" ]; then mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B deploy --debug --settings "./settings.xml"; fi
......@@ -43,11 +43,13 @@
<SIUnit.version>0.0.10-SNAPSHOT</SIUnit.version>
<Common-MontiCar.version>0.0.10-SNAPSHOT</Common-MontiCar.version>
<Math.version>0.0.11-SNAPSHOT</Math.version>
<!-- .. Libraries .................................................. -->
<guava.version>18.0</guava.version>
<junit.version>4.12</junit.version>
<logback.version>1.1.2</logback.version>
<jscience.version>4.3.1</jscience.version>
<commons-cli.version>1.4</commons-cli.version>
<!-- .. Plugins ....................................................... -->
<monticore.plugin>4.5.3-SNAPSHOT</monticore.plugin>
......@@ -159,6 +161,12 @@
<scope>provided</scope>
</dependency>
<dependency>
<groupId>commons-cli</groupId>
<artifactId>commons-cli</artifactId>
<version>${commons-cli.version}</version>
</dependency>
<!-- .. Test Libraries ............................................... -->
<dependency>
......
......@@ -20,7 +20,7 @@ grammar CNNArch extends de.monticore.lang.math.Math {
(in:"input" | out:"output")
type:ArchType
Name&
(ArrayDeclaration)? NEWLINETOKEN*;
(ArrayDeclaration)?;
ArchType = ElementType "^" Shape;
......
......@@ -22,6 +22,7 @@ package de.monticore.lang.monticar.cnnarch._cocos;
import de.monticore.lang.monticar.cnnarch._ast.ASTCNNArchNode;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchCompilationUnitSymbol;
import de.se_rwth.commons.logging.Log;
//check all cocos
......@@ -29,10 +30,23 @@ public class CNNArchCocos {
public static void checkAll(ArchitectureSymbol architecture){
ASTCNNArchNode node = (ASTCNNArchNode) architecture.getAstNode().get();
int findings = Log.getFindings().size();
createPreResolveChecker().checkAll(node);
if (Log.getFindings().isEmpty()){
if (findings == Log.getFindings().size()){
architecture.resolve();
if (Log.getFindings().isEmpty()){
if (findings == Log.getFindings().size()){
createPostResolveChecker().checkAll(node);
}
}
}
public static void checkAll(CNNArchCompilationUnitSymbol compilationUnit){
ASTCNNArchNode node = (ASTCNNArchNode) compilationUnit.getAstNode().get();
int findings = Log.getFindings().size();
createPreResolveChecker().checkAll(node);
if (findings == Log.getFindings().size()){
compilationUnit.getArchitecture().resolve();
if (findings == Log.getFindings().size()){
createPostResolveChecker().checkAll(node);
}
}
......
......@@ -31,7 +31,6 @@ import de.se_rwth.commons.Joiners;
import de.se_rwth.commons.logging.Log;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
public class CheckMethodLayer implements CNNArchASTMethodLayerCoCo{
......
......@@ -24,7 +24,6 @@
package de.monticore.lang.monticar.cnnarch._symboltable;
import de.monticore.symboltable.CommonScopeSpanningSymbol;
import de.monticore.symboltable.Symbol;
import java.util.*;
......
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch.generator;
import de.monticore.io.paths.ModelPath;
import de.monticore.lang.monticar.cnnarch._cocos.CNNArchCocos;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchCompilationUnitSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchLanguage;
import de.monticore.symboltable.GlobalScope;
import de.monticore.symboltable.Scope;
import de.se_rwth.commons.logging.Log;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.nio.file.Path;
import java.util.Optional;
public class CNNArchGenerator {
private Target targetLanguage;
private String generationTargetPath;
public CNNArchGenerator() {
setTargetLanguage(Target.CPP);
setGenerationTargetPath("./target/generated-sources-cnnarch/");
}
public Target getTargetLanguage() {
return targetLanguage;
}
public void setTargetLanguage(Target targetLanguage) {
this.targetLanguage = targetLanguage;
}
public String getGenerationTargetPath() {
return generationTargetPath;
}
public void setGenerationTargetPath(String generationTargetPath) {
if (generationTargetPath.charAt(generationTargetPath.length() - 1) != '/') {
this.generationTargetPath = generationTargetPath + "/";
}
else {
this.generationTargetPath = generationTargetPath;
}
}
public void generate(Path modelsDirPath, String rootModelName){
final ModelPath mp = new ModelPath(modelsDirPath);
GlobalScope scope = new GlobalScope(mp, new CNNArchLanguage());
generate(scope, rootModelName);
}
public void generate(Scope scope, String rootModelName){
Optional<CNNArchCompilationUnitSymbol> compilationUnit = scope.resolve(rootModelName, CNNArchCompilationUnitSymbol.KIND);
if (!compilationUnit.isPresent()){
Log.error("could not resolve architecture " + rootModelName);
System.exit(1);
}
CNNArchCocos.checkAll(compilationUnit.get());
try{
ArchitectureSymbol architecture = compilationUnit.get().getArchitecture();
generateNetworkFile(architecture);
}
catch (IOException e){
Log.error(e.toString());
}
}
public String generateNetworkString(ArchitectureSymbol architecture){
TemplateController archTc = new TemplateController(architecture, targetLanguage);
return archTc.process();
}
public void generateNetworkFile(ArchitectureSymbol architecture) throws IOException{
File f = new File(getGenerationTargetPath() + getFileName(architecture));
Log.info(f.getName(), "FileCreation:");
if (!f.exists()) {
f.getParentFile().mkdirs();
if (!f.createNewFile()) {
Log.error("File could not be created");
}
}
FileWriter writer = new FileWriter(f);
TemplateController archTc = new TemplateController(architecture, targetLanguage);
archTc.process(writer);
writer.close();
}
public String getFileName(ArchitectureSymbol architecture){
String name = architecture.getEnclosingScope().getSpanningSymbol().get().getFullName();
name = name.replaceAll("\\.", "_").replaceAll("\\[", "_").replaceAll("\\]", "_");
String fileEnding = getTargetLanguage().toString();
if (getTargetLanguage() == Target.CPP){
fileEnding = ".h";
}
return name + "__network" + fileEnding;
}
}
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch.generator;
import org.apache.commons.cli.*;
import java.nio.file.Path;
import java.nio.file.Paths;
public class CNNArchGeneratorCli {
public static final Option OPTION_MODELS_PATH = Option.builder("m")
.longOpt("models-dir")
.desc("full path to directory with CNNArch models e.g. C:\\Users\\vpupkin\\proj\\MyAwesomeAutopilot\\src\\main\\emam")
.hasArg(true)
.required(true)
.build();
public static final Option OPTION_ROOT_MODEL = Option.builder("r")
.longOpt("root-model")
.desc("fully qualified name of the root model e.g. de.rwth.vpupkin.modeling.mySuperAwesomeAutopilotComponent")
.hasArg(true)
.required(true)
.build();
public static final Option OPTION_OUTPUT_PATH = Option.builder("o")
.longOpt("output-dir")
.desc("full path to output directory for tests e.g. C:\\Users\\vpupkin\\proj\\MyAwesomeAutopilot\\target\\gen-cpp")
.hasArg(true)
.required(false)
.build();
public static final Option OPTION_TARGET_LANG = Option.builder("t")
.longOpt("target-language")
.desc("target language of network e.g. c++ or python")
.hasArg(true)
.required(false)
.build();
private CNNArchGeneratorCli() {
}
public static void main(String[] args) {
Options options = getOptions();
CommandLineParser parser = new DefaultParser();
CommandLine cliArgs = parseArgs(options, parser, args);
if (cliArgs != null) {
runGenerator(cliArgs);
}
}
private static Options getOptions() {
Options options = new Options();
options.addOption(OPTION_MODELS_PATH);
options.addOption(OPTION_ROOT_MODEL);
options.addOption(OPTION_OUTPUT_PATH);
options.addOption(OPTION_TARGET_LANG);
return options;
}
private static CommandLine parseArgs(Options options, CommandLineParser parser, String[] args) {
CommandLine cliArgs;
try {
cliArgs = parser.parse(options, args);
} catch (ParseException e) {
System.err.println("argument parsing exception: " + e.getMessage());
System.exit(1);
return null;
}
return cliArgs;
}
private static void runGenerator(CommandLine cliArgs) {
Path modelsDirPath = Paths.get(cliArgs.getOptionValue(OPTION_MODELS_PATH.getOpt()));
String rootModelName = cliArgs.getOptionValue(OPTION_ROOT_MODEL.getOpt());
String outputPath = cliArgs.getOptionValue(OPTION_OUTPUT_PATH.getOpt());
String targetLanguage = cliArgs.getOptionValue(OPTION_TARGET_LANG.getOpt());
CNNArchGenerator generator = new CNNArchGenerator();
if (outputPath != null){
generator.setGenerationTargetPath(outputPath);
}
if (targetLanguage != null){
generator.setTargetLanguage(Target.fromString(targetLanguage));
}
generator.generate(modelsDirPath, rootModelName);
}
}
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch.generator;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.predefined.Convolution;
import de.monticore.lang.monticar.cnnarch.predefined.FullyConnected;
import de.monticore.lang.monticar.cnnarch.predefined.Pooling;
import java.util.*;
public class LayerNameCreator {
private Map<LayerSymbol, String> layerToName = new HashMap<>();
private Map<String, LayerSymbol> nameToLayer = new HashMap<>();
public LayerNameCreator(ArchitectureSymbol architecture) {
name(architecture.getBody(), 1, new ArrayList<>());
}
public LayerSymbol getLayer(String name){
return nameToLayer.get(name);
}
public String getName(LayerSymbol layer){
return layerToName.get(layer);
}
protected int name(CompositeLayerSymbol compositeLayer, int stage, List<Integer> streamIndices){
if (compositeLayer.isParallel()){
int startStage = stage + 1;
streamIndices.add(1);
int lastIndex = streamIndices.size() - 1;
List<Integer> endStages = new ArrayList<>();
for (LayerSymbol subLayer : compositeLayer.getLayers()){
endStages.add(name(subLayer, startStage, streamIndices));
streamIndices.set(lastIndex, streamIndices.get(lastIndex) + 1);
}
streamIndices.remove(lastIndex);
return Collections.max(endStages) + 1;
}
else {
int endStage = stage;
for (LayerSymbol subLayer : compositeLayer.getLayers()){
endStage = name(subLayer, endStage, streamIndices);
}
return endStage;
}
}
protected int name(LayerSymbol layer, int stage, List<Integer> streamIndices){
if (layer instanceof CompositeLayerSymbol){
return name((CompositeLayerSymbol) layer, stage, streamIndices);
}
else if (layer instanceof MethodLayerSymbol){
if (layer.isAtomic()){
if (layer.getMaxSerialLength().get() > 0){
return add(layer, stage, streamIndices);
}
else {
return stage;
}
}
else {
LayerSymbol resolvedLayer = ((MethodLayerSymbol) layer).getResolvedThis().get();
return (name(resolvedLayer, stage, streamIndices));
}
}
else {
return add(layer, stage, streamIndices);
}
}
protected int add(LayerSymbol layer, int stage, List<Integer> streamIndices){
int endStage = stage;
if (!layerToName.containsKey(layer)) {
String name = createName(layer, endStage, streamIndices);
while (nameToLayer.containsKey(name)) {
endStage++;
name = createName(layer, endStage, streamIndices);
}
layerToName.put(layer, name);
nameToLayer.put(name, layer);
}
return endStage;
}
protected String createName(LayerSymbol layer, int stage, List<Integer> streamIndices){
if (layer instanceof IOLayerSymbol){
String name = createBaseName(layer);
IOLayerSymbol ioLayer = (IOLayerSymbol) layer;
if (ioLayer.getArrayAccess().isPresent()){
int arrayAccess = ioLayer.getArrayAccess().get().getIntValue().get();
name = name + arrayAccess;
}
return name;
}
else {
return createBaseName(layer) + stage + createStreamPostfix(streamIndices);
}
}
protected String createBaseName(LayerSymbol layer){
if (layer instanceof MethodLayerSymbol) {
MethodDeclarationSymbol method = ((MethodLayerSymbol) layer).getMethod();
if (method instanceof Convolution) {
return "conv";
} else if (method instanceof FullyConnected) {
return "fc";
} else if (method instanceof Pooling) {
return "pool";
} else {
return method.getName().toLowerCase();
}
}
else if (layer instanceof CompositeLayerSymbol){
return "group";
}
else {
return layer.getName();
}
}
protected String createStreamPostfix(List<Integer> streamIndices){
StringBuilder stringBuilder = new StringBuilder();
for (int streamIndex : streamIndices){
stringBuilder.append("_");
stringBuilder.append(streamIndex);
}
return stringBuilder.toString();
}
}
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch.generator;
public enum Target {
PYTHON{
@Override
public String toString() {
return ".py";
}
},
CPP{
@Override
public String toString() {
return ".cpp";
}
};
public static Target fromString(String target){
switch (target.toLowerCase()){
case "python":
return PYTHON;
case "py":
return PYTHON;
case "cpp":
return CPP;
case "c++":
return CPP;
default:
throw new IllegalArgumentException();
}
}
}
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch.generator;
import freemarker.template.Configuration;
import freemarker.template.TemplateExceptionHandler;
public class TemplateConfiguration {
public static Configuration get(){
Configuration cfg = new Configuration(Configuration.VERSION_2_3_23);
cfg.setClassForTemplateLoading(TemplateConfiguration.class, "/templates/");
cfg.setDefaultEncoding("UTF-8");
cfg.setTemplateExceptionHandler(TemplateExceptionHandler.RETHROW_HANDLER);
return cfg;
}
}
<#if tc.target == ".py">
import mxnet as mx
import logging
import os
import errno
import shutil
import numpy as np
from collections import namedtuple
Batch = namedtuple('Batch', ['data'])
logging.basicConfig(level=logging.DEBUG)
class Network:
<#list tc.architectureInputs as input>
${input} = None
</#list>
<#list tc.architectureOutputs as output>
${output} = None
</#list>
Module = None
_checkpoint_dir = 'checkpoints/'
def load(self):
self.Module.load(prefix=self._checkpoint_dir)
self.Module.bind(for_training=False,
data_shapes=[('data', (1,3,224,224))],
label_shapes=self.Module._label_shapes)
def predict(self, image):
# compute the predict probabilities
self.Module.forward(Batch([mx.nd.array(image)]))
prob = self.Module.get_outputs()[0].asnumpy()
# top-5
prob = np.squeeze(prob)
return np.argsort(prob)[::-1]
def train(self, train_iter, test_iter, batch_size, optimizer, num_epoch, checkpoint_period):
shutil.rmtree(self._checkpoint_dir)
try:
os.makedirs(self._checkpoint_dir)
except OSError:
if not os.path.isdir(self._checkpoint_dir):
raise
self.Module.fit(
train_data=train_iter,
eval_data=test_iter,
optimizer=optimizer,
batch_end_callback=mx.callback.Speedometer(batch_size, 50),
epoch_end_callback=mx.callback.do_checkpoint(prefix=self._checkpoint_dir+'${tc.architecture.name}', period=checkpoint_period),
num_epoch=num_epoch)
def __init__(self, context=mx.gpu()):
${tc.include(tc.architecture.body)}
self.Module = mx.mod.Module(symbol=mx.symbol.Group([${tc.join(tc.architectureOutputs, ",", "self.", "")}]),
data_names=[${tc.join(tc.architectureInputs, ",", "'", "'")}],
label_names=[${tc.join(tc.architectureOutputs, ",", "'", "_label'")}],
context=context)
<#elseif tc.target == ".cpp">
#include "mxnet-cpp/MxNetCpp.h"
using namespace std;
using namespace mxnet::cpp;
class Network{
<#list tc.architectureInputs as input>
Symbol m_${input};
</#list>
<#list tc.architectureOutputs as output>
Symbol m_${output};
</#list>
Module m_module;
public:
Network(Context context = Context::gpu());
<#list tc.architectureInputs as input>
Symbol get${input?capitalize}();
</#list>
<#list tc.architectureOutputs as output>
Symbol get${output?capitalize}();
</#list>
Module getModule();
};
<#list tc.architectureInputs as input>
Symbol Network::get${input?capitalize}(){
return m_${input};
}
</#list>
<#list tc.architectureOutputs as output>
Symbol Network::get${output?capitalize}(){
return m_${output};
}
</#list>