Commit 99e219f0 authored by Christopher Jan-Steffen Brix's avatar Christopher Jan-Steffen Brix

Merge branch 'master' into oneclick_nn_training

parents 97585803 8099a9c9
......@@ -27,7 +27,7 @@ masterJobLinux:
stage: linux
image: maven:3-jdk-8
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install deploy --settings settings.xml
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean deploy --settings settings.xml
- cat target/site/jacoco/index.html
- mvn package sonar:sonar -s settings.xml
only:
......
......@@ -8,16 +8,16 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-mxnet-generator</artifactId>
<version>0.2.8</version>
<version>0.2.12</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
<properties>
<!-- .. SE-Libraries .................................................. -->
<CNNArch.version>0.2.7-SNAPSHOT</CNNArch.version>
<CNNTrain.version>0.2.5</CNNTrain.version>
<embedded-montiarc-math-generator>0.0.25-20180812.120330-2</embedded-montiarc-math-generator>
<CNNArch.version>0.2.9</CNNArch.version>
<CNNTrain.version>0.2.6</CNNTrain.version>
<embedded-montiarc-math-opt-generator>0.1.4</embedded-montiarc-math-opt-generator>
<!-- .. Libraries .................................................. -->
<guava.version>18.0</guava.version>
......@@ -89,8 +89,8 @@
<dependency>
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>embedded-montiarc-math-generator</artifactId>
<version>${embedded-montiarc-math-generator}</version>
<artifactId>embedded-montiarc-math-opt-generator</artifactId>
<version>${embedded-montiarc-math-opt-generator}</version>
</dependency>
......
......@@ -20,10 +20,11 @@
*/
package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import de.monticore.io.paths.ModelPath;
import de.monticore.lang.monticar.cnnarch.CNNArchGenerator;
import de.monticore.lang.monticar.cnnarch._cocos.CNNArchCocos;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchCompilationUnitSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchLanguage;
import de.monticore.lang.monticar.cnnarch.DataPathConfigParser;
......@@ -31,7 +32,6 @@ import de.monticore.lang.monticar.generator.FileContent;
import de.monticore.lang.monticar.generator.cmake.CMakeConfig;
import de.monticore.lang.monticar.generator.cmake.CMakeFindModule;
import de.monticore.lang.monticar.generator.cpp.GeneratorCPP;
import de.monticore.symboltable.GlobalScope;
import de.monticore.symboltable.Scope;
import de.se_rwth.commons.logging.Log;
......@@ -41,65 +41,63 @@ import java.nio.file.Path;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.List;
public class CNNArch2MxNet implements CNNArchGenerator {
public class CNNArch2MxNet extends CNNArchGenerator {
private String generationTargetPath;
private String modelPath;
private boolean isSupportedLayer(ArchitectureElementSymbol element, LayerSupportChecker layerChecker){
List<ArchitectureElementSymbol> constructLayerElemList;
public CNNArch2MxNet() {
setGenerationTargetPath("./target/generated-sources-cnnarch/");
}
@Override
public boolean isCMakeRequired() {
return true;
}
public String getModelPath(){
return modelPath;
}
public void setModelPath(Path modelPath){
this.modelPath = modelPath.toString();
}
public String getGenerationTargetPath() {
if (generationTargetPath.charAt(generationTargetPath.length() - 1) != '/') {
this.generationTargetPath = generationTargetPath + "/";
if (element.getResolvedThis().get() instanceof CompositeElementSymbol) {
constructLayerElemList = ((CompositeElementSymbol)element.getResolvedThis().get()).getElements();
for (ArchitectureElementSymbol constructedLayerElement : constructLayerElemList) {
if (!isSupportedLayer(constructedLayerElement, layerChecker)) {
return false;
}
}
}
if (!layerChecker.isSupported(element.toString())) {
Log.error("Unsupported layer " + "'" + element.getName() + "'" + " for the backend MXNET.");
return false;
} else {
return true;
}
return generationTargetPath;
}
public void setGenerationTargetPath(String generationTargetPath) {
this.generationTargetPath = generationTargetPath;
private boolean supportCheck(ArchitectureSymbol architecture){
LayerSupportChecker layerChecker = new LayerSupportChecker();
for (ArchitectureElementSymbol element : ((CompositeElementSymbol)architecture.getBody()).getElements()){
if(!isSupportedLayer(element, layerChecker)) {
return false;
}
}
return true;
}
public void generate(Path modelsDirPath, String rootModelName){
final ModelPath mp = new ModelPath(modelsDirPath);
GlobalScope scope = new GlobalScope(mp, new CNNArchLanguage());
setModelPath(modelsDirPath);
generate(scope, rootModelName);
public CNNArch2MxNet() {
setGenerationTargetPath("./target/generated-sources-cnnarch/");
}
public void generate(Scope scope, String rootModelName){
Optional<CNNArchCompilationUnitSymbol> compilationUnit = scope.resolve(rootModelName, CNNArchCompilationUnitSymbol.KIND);
if (!compilationUnit.isPresent()){
Log.error("could not resolve architecture " + rootModelName);
System.exit(1);
quitGeneration();
}
CNNArchCocos.checkAll(compilationUnit.get());
if (!supportCheck(compilationUnit.get().getArchitecture())){
quitGeneration();
}
try{
String confPath = getModelPath() + "/data_paths.txt";
String confPath = getModelsDirPath() + "/data_paths.txt";
DataPathConfigParser newParserConfig = new DataPathConfigParser(confPath);
String dataPath = newParserConfig.getDataPath(rootModelName);
compilationUnit.get().getArchitecture().setDataPath(dataPath);
compilationUnit.get().getArchitecture().setComponentName(rootModelName);
generateFiles(compilationUnit.get().getArchitecture());
}
catch (IOException e){
} catch (IOException e){
Log.error(e.toString());
}
}
......@@ -127,41 +125,7 @@ public class CNNArch2MxNet implements CNNArchGenerator {
return fileContentMap;
}
private void checkValidGeneration(ArchitectureSymbol architecture){
if (architecture.getInputs().size() > 1){
Log.error("This cnn architecture has multiple inputs, " +
"which is currently not supported by the mxnetgenerator. "
, architecture.getSourcePosition());
}
if (architecture.getOutputs().size() > 1){
Log.error("This cnn architecture has multiple outputs, " +
"which is currently not supported by the mxnetgenerator. "
, architecture.getSourcePosition());
}
if (architecture.getOutputs().get(0).getDefinition().getType().getWidth() != 1 ||
architecture.getOutputs().get(0).getDefinition().getType().getHeight() != 1){
Log.error("This cnn architecture has a multi-dimensional output, " +
"which is currently not supported by the mxnetgenerator."
, architecture.getSourcePosition());
}
}
//check cocos with CNNArchCocos.checkAll(architecture) before calling this method.
public void generateFiles(ArchitectureSymbol architecture) throws IOException{
Map<String, String> fileContentMap = generateStrings(architecture);
generateFromFilecontentsMap(fileContentMap);
}
public void generateCMake(String rootModelName) {
Map<String, String> fileContentMap = generateCMakeContent(rootModelName);
try {
generateFromFilecontentsMap(fileContentMap);
} catch (IOException e) {
e.printStackTrace();
}
}
private void generateFromFilecontentsMap(Map<String, String> fileContentMap) throws IOException {
public void generateFromFilecontentsMap(Map<String, String> fileContentMap) throws IOException {
GeneratorCPP genCPP = new GeneratorCPP();
genCPP.setGenerationTargetPath(getGenerationTargetPath());
for (String fileName : fileContentMap.keySet()){
......
......@@ -19,6 +19,7 @@
* *******************************************************************************
*/
package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import de.se_rwth.commons.logging.Log;
import org.apache.commons.cli.*;
......@@ -73,13 +74,18 @@ public class CNNArch2MxNetCli {
try {
cliArgs = parser.parse(options, args);
} catch (ParseException e) {
System.err.println("argument parsing exception: " + e.getMessage());
System.exit(1);
Log.error("argument parsing exception: " + e.getMessage());
quitGeneration();
return null;
}
return cliArgs;
}
private static void quitGeneration(){
Log.error("Code generation is aborted");
System.exit(1);
}
private static void runGenerator(CommandLine cliArgs) {
Path modelsDirPath = Paths.get(cliArgs.getOptionValue(OPTION_MODELS_PATH.getOpt()));
String rootModelName = cliArgs.getOptionValue(OPTION_ROOT_MODEL.getOpt());
......
......@@ -95,8 +95,7 @@ public class CNNArchTemplateController {
if (isSoftmaxOutput(layer) || isLogisticRegressionOutput(layer)){
inputNames = getLayerInputs(layer.getInputElement().get());
}
else {
} else {
for (ArchitectureElementSymbol input : layer.getPrevious()) {
if (input.getOutputTypes().size() == 1) {
inputNames.add(getName(input));
......@@ -146,12 +145,10 @@ public class CNNArchTemplateController {
if (ioElement.isAtomic()){
if (ioElement.isInput()){
include(TEMPLATE_ELEMENTS_DIR_PATH, "Input", writer);
}
else {
} else {
include(TEMPLATE_ELEMENTS_DIR_PATH, "Output", writer);
}
}
else {
} else {
include(ioElement.getResolvedThis().get(), writer);
}
......@@ -168,8 +165,7 @@ public class CNNArchTemplateController {
String templateName = layer.getDeclaration().getName();
include(TEMPLATE_ELEMENTS_DIR_PATH, templateName, writer);
}
}
else {
} else {
include(layer.getResolvedThis().get(), writer);
}
......@@ -190,11 +186,9 @@ public class CNNArchTemplateController {
public void include(ArchitectureElementSymbol architectureElement, Writer writer){
if (architectureElement instanceof CompositeElementSymbol){
include((CompositeElementSymbol) architectureElement, writer);
}
else if (architectureElement instanceof LayerSymbol){
} else if (architectureElement instanceof LayerSymbol){
include((LayerSymbol) architectureElement, writer);
}
else {
} else {
include((IOSymbol) architectureElement, writer);
}
}
......@@ -207,15 +201,15 @@ public class CNNArchTemplateController {
}
public Map.Entry<String,String> process(String templateNameWithoutEnding, Target targetLanguage){
StringWriter writer = new StringWriter();
StringWriter newWriter = new StringWriter();
this.mainTemplateNameWithoutEnding = templateNameWithoutEnding;
this.targetLanguage = targetLanguage;
this.writer = writer;
this.writer = newWriter;
include("", templateNameWithoutEnding, writer);
include("", templateNameWithoutEnding, newWriter);
String fileEnding = targetLanguage.toString();
String fileName = getFileNameWithoutEnding() + fileEnding;
Map.Entry<String,String> fileContent = new AbstractMap.SimpleEntry<>(fileName, writer.toString());
Map.Entry<String,String> fileContent = new AbstractMap.SimpleEntry<>(fileName, newWriter.toString());
this.mainTemplateNameWithoutEnding = null;
this.targetLanguage = null;
......@@ -259,12 +253,12 @@ public class CNNArchTemplateController {
}
private boolean isTOutput(Class inputPredefinedLayerClass, ArchitectureElementSymbol architectureElement){
if (architectureElement.isOutput()){
if (architectureElement.getInputElement().isPresent() && architectureElement.getInputElement().get() instanceof LayerSymbol){
LayerSymbol inputLayer = (LayerSymbol) architectureElement.getInputElement().get();
if (inputPredefinedLayerClass.isInstance(inputLayer.getDeclaration())){
return true;
}
if (architectureElement.isOutput()
&& architectureElement.getInputElement().isPresent()
&& architectureElement.getInputElement().get() instanceof LayerSymbol){
LayerSymbol inputLayer = (LayerSymbol) architectureElement.getInputElement().get();
if (inputPredefinedLayerClass.isInstance(inputLayer.getDeclaration())){
return true;
}
}
return false;
......
......@@ -37,7 +37,9 @@ public class CNNTrain2MxNet implements CNNTrainGenerator {
it = configuration.getEntryMap().keySet().iterator();
while (it.hasNext()) {
String key = it.next().toString();
if (funcChecker.getUnsupportedElemList().contains(key)) it.remove();
if (funcChecker.getUnsupportedElemList().contains(key)) {
it.remove();
}
}
}
......@@ -52,12 +54,19 @@ public class CNNTrain2MxNet implements CNNTrainGenerator {
Iterator it = configuration.getOptimizer().getOptimizerParamMap().keySet().iterator();
while (it.hasNext()) {
String key = it.next().toString();
if (funcChecker.getUnsupportedElemList().contains(key)) it.remove();
if (funcChecker.getUnsupportedElemList().contains(key)) {
it.remove();
}
}
}
}
}
private static void quitGeneration(){
Log.error("Code generation is aborted");
System.exit(1);
}
public CNNTrain2MxNet() {
setGenerationTargetPath("./target/generated-sources-cnnarch/");
}
......@@ -89,7 +98,7 @@ public class CNNTrain2MxNet implements CNNTrainGenerator {
Optional<CNNTrainCompilationUnitSymbol> compilationUnit = scope.resolve(rootModelName, CNNTrainCompilationUnitSymbol.KIND);
if (!compilationUnit.isPresent()) {
Log.error("could not resolve training configuration " + rootModelName);
System.exit(1);
quitGeneration();
}
setInstanceName(compilationUnit.get().getFullName());
CNNTrainCocos.checkAll(compilationUnit.get());
......@@ -107,7 +116,7 @@ public class CNNTrain2MxNet implements CNNTrainGenerator {
genCPP.generateFile(new FileContent(fileContents.get(fileName), fileName));
}
} catch (IOException e) {
e.printStackTrace();
Log.error("CNNTrainer file could not be generated" + e.getMessage());
}
}
......
......@@ -89,8 +89,7 @@ public class ConfigurationData {
Class realClass = entry.getValue().getValue().getValue().getClass();
if (realClass == Boolean.class) {
valueAsString = (Boolean) entry.getValue().getValue().getValue() ? "True" : "False";
}
else if (lrPolicyClasses.contains(realClass)) {
} else if (lrPolicyClasses.contains(realClass)) {
valueAsString = "'" + valueAsString + "'";
}
mapToStrings.put(paramName, valueAsString);
......
......@@ -47,17 +47,14 @@ public class LayerNameCreator {
protected int name(ArchitectureElementSymbol architectureElement, int stage, List<Integer> streamIndices){
if (architectureElement instanceof CompositeElementSymbol){
return nameComposite((CompositeElementSymbol) architectureElement, stage, streamIndices);
}
else{
} else{
if (architectureElement.isAtomic()){
if (architectureElement.getMaxSerialLength().get() > 0){
return add(architectureElement, stage, streamIndices);
}
else {
} else {
return stage;
}
}
else {
} else {
ArchitectureElementSymbol resolvedElement = architectureElement.getResolvedThis().get();
return name(resolvedElement, stage, streamIndices);
}
......@@ -78,8 +75,7 @@ public class LayerNameCreator {
streamIndices.remove(lastIndex);
return Collections.max(endStages) + 1;
}
else {
} else {
int endStage = stage;
for (ArchitectureElementSymbol subElement : compositeElement.getElements()){
endStage = name(subElement, endStage, streamIndices);
......@@ -113,8 +109,7 @@ public class LayerNameCreator {
name = name + "_" + arrayAccess + "_";
}
return name;
}
else {
} else {
return createBaseName(architectureElement) + stage + createStreamPostfix(streamIndices) + "_";
}
}
......@@ -132,11 +127,9 @@ public class LayerNameCreator {
} else {
return layerDeclaration.getName().toLowerCase();
}
}
else if (architectureElement instanceof CompositeElementSymbol){
} else if (architectureElement instanceof CompositeElementSymbol){
return "group";
}
else {
} else {
return architectureElement.getName();
}
}
......
package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import static de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers.*;
import java.util.ArrayList;
import java.util.List;
public class LayerSupportChecker {
private List<String> unsupportedLayerList = new ArrayList();
public LayerSupportChecker() {
//Set the unsupported layers for the backend
//this.unsupportedLayerList.add(PREDEFINED_LAYER_NAME);
}
public boolean isSupported(String element) {
return !this.unsupportedLayerList.contains(element);
}
}
......@@ -43,6 +43,11 @@ public class TemplateConfiguration {
configuration.setTemplateExceptionHandler(TemplateExceptionHandler.RETHROW_HANDLER);
}
private static void quitGeneration(){
Log.error("Code generation is aborted");
System.exit(1);
}
public Configuration getConfiguration() {
return configuration;
}
......@@ -58,14 +63,12 @@ public class TemplateConfiguration {
try{
Template template = TemplateConfiguration.get().getTemplate(templatePath);
template.process(ftlContext, writer);
}
catch (IOException e) {
} catch (IOException e) {
Log.error("Freemarker could not find template " + templatePath + " :\n" + e.getMessage());
System.exit(1);
}
catch (TemplateException e){
quitGeneration();
} catch (TemplateException e){
Log.error("An exception occured in template " + templatePath + " :\n" + e.getMessage());
System.exit(1);
quitGeneration();
}
}
......
......@@ -25,12 +25,14 @@ public class TrainParamSupportChecker implements CNNTrainVisitor {
public TrainParamSupportChecker() {
}
public String unsupportedOptFlag = "unsupported_optimizer";
public static final String unsupportedOptFlag = "unsupported_optimizer";
public List getUnsupportedElemList(){
return this.unsupportedElemList;
}
//Empty visit method denotes that the corresponding training parameter is supported.
//To set a training parameter as unsupported, add the corresponding node to the unsupportedElemList
public void visit(ASTNumEpochEntry node){}
public void visit(ASTBatchSizeEntry node){}
......
cmake_minimum_required(VERSION 3.5)
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD 14)
project(alexnet LANGUAGES CXX)
......@@ -16,8 +16,8 @@ set(LIBS ${LIBS} mxnet)
# create static library
include_directories(${INCLUDE_DIRS})
add_library(alexnet alexnet.h)
target_include_directories(alexnet PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
add_library(alexnet alexnet.cpp)
target_include_directories(alexnet PUBLIC ${CMAKE_CURRENT_SOURCE_DIR} ${INCLUDE_DIRS})
target_link_libraries(alexnet PUBLIC ${LIBS})
set_target_properties(alexnet PROPERTIES LINKER_LANGUAGE CXX)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment