Commit 20d0b24c authored by Evgeny Kusmenko's avatar Evgeny Kusmenko

Merge branch 'refactor-architecture-creation' into 'master'

Refactor architecture creation

See merge request !25
parents 6bf25ee8 6a467f0f
Pipeline #148379 passed with stages
in 4 minutes and 25 seconds
......@@ -8,7 +8,7 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-mxnet-generator</artifactId>
<version>0.2.14-SNAPSHOT</version>
<version>0.2.15-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
......
......@@ -21,13 +21,9 @@
package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import de.monticore.lang.monticar.cnnarch.CNNArchGenerator;
import de.monticore.lang.monticar.cnnarch._cocos.CNNArchCocos;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchCompilationUnitSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchLanguage;
import de.monticore.lang.monticar.cnnarch.DataPathConfigParser;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.checker.AllowAllLayerSupportChecker;
import de.monticore.lang.monticar.generator.FileContent;
import de.monticore.lang.monticar.generator.cmake.CMakeConfig;
import de.monticore.lang.monticar.generator.cmake.CMakeFindModule;
......@@ -36,67 +32,25 @@ import de.monticore.symboltable.Scope;
import de.se_rwth.commons.logging.Log;
import java.io.IOException;
import java.lang.System;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.List;
public class CNNArch2MxNet extends CNNArchGenerator {
private boolean isSupportedLayer(ArchitectureElementSymbol element, LayerSupportChecker layerChecker){
List<ArchitectureElementSymbol> constructLayerElemList;
if (element.getResolvedThis().get() instanceof CompositeElementSymbol) {
constructLayerElemList = ((CompositeElementSymbol)element.getResolvedThis().get()).getElements();
for (ArchitectureElementSymbol constructedLayerElement : constructLayerElemList) {
if (!isSupportedLayer(constructedLayerElement, layerChecker)) {
return false;
}
}
}
if (!layerChecker.isSupported(element.toString())) {
Log.error("Unsupported layer " + "'" + element.getName() + "'" + " for the backend MXNET.");
return false;
} else {
return true;
}
}
private boolean supportCheck(ArchitectureSymbol architecture){
LayerSupportChecker layerChecker = new LayerSupportChecker();
for (ArchitectureElementSymbol element : ((CompositeElementSymbol)architecture.getBody()).getElements()){
if(!isSupportedLayer(element, layerChecker)) {
return false;
}
}
return true;
}
public CNNArch2MxNet() {
setGenerationTargetPath("./target/generated-sources-cnnarch/");
}
public void generate(Scope scope, String rootModelName){
Optional<CNNArchCompilationUnitSymbol> compilationUnit = scope.resolve(rootModelName, CNNArchCompilationUnitSymbol.KIND);
if (!compilationUnit.isPresent()){
Log.error("could not resolve architecture " + rootModelName);
quitGeneration();
}
CNNArchCocos.checkAll(compilationUnit.get());
if (!supportCheck(compilationUnit.get().getArchitecture())){
quitGeneration();
}
CNNArchSymbolCompiler symbolCompiler = new CNNArchSymbolCompiler(new AllowAllLayerSupportChecker());
ArchitectureSymbol architectureSymbol = symbolCompiler.compileArchitectureSymbol(scope, rootModelName);
try{
String confPath = getModelsDirPath() + "/data_paths.txt";
DataPathConfigParser newParserConfig = new DataPathConfigParser(confPath);
String dataPath = newParserConfig.getDataPath(rootModelName);
compilationUnit.get().getArchitecture().setDataPath(dataPath);
compilationUnit.get().getArchitecture().setComponentName(rootModelName);
generateFiles(compilationUnit.get().getArchitecture());
architectureSymbol.setDataPath(dataPath);
architectureSymbol.setComponentName(rootModelName);
generateFiles(architectureSymbol);
} catch (IOException e){
Log.error(e.toString());
}
......
package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import de.monticore.io.paths.ModelPath;
import de.monticore.lang.monticar.cnnarch._cocos.CNNArchCocos;
import de.monticore.lang.monticar.cnnarch._symboltable.*;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.checker.LayerSupportChecker;
import de.monticore.symboltable.GlobalScope;
import de.monticore.symboltable.Scope;
import de.se_rwth.commons.logging.Log;
import java.nio.file.Path;
import java.util.List;
import java.util.Optional;
public class CNNArchSymbolCompiler {
private final LayerSupportChecker layerChecker;
public CNNArchSymbolCompiler(final LayerSupportChecker layerChecker) {
this.layerChecker = layerChecker;
}
public ArchitectureSymbol compileArchitectureSymbolFromModelsDir(
final Path modelsDirPath, final String rootModel) {
ModelPath mp = new ModelPath(modelsDirPath);
GlobalScope scope = new GlobalScope(mp, new CNNArchLanguage());
return compileArchitectureSymbol(scope, rootModel);
}
public ArchitectureSymbol compileArchitectureSymbol(Scope scope, String rootModelName) {
Optional<CNNArchCompilationUnitSymbol> compilationUnit = scope.resolve(rootModelName, CNNArchCompilationUnitSymbol.KIND);
if (!compilationUnit.isPresent()){
failWithMessage("Could not resolve architecture " + rootModelName);
}
CNNArchCocos.checkAll(compilationUnit.get());
if (!supportCheck(compilationUnit.get().getArchitecture())){
failWithMessage("Architecture not supported by generator");
}
return compilationUnit.get().getArchitecture();
}
private void failWithMessage(final String message) {
Log.error(message);
System.exit(1);
}
private boolean supportCheck(ArchitectureSymbol architecture){
for (ArchitectureElementSymbol element : ((CompositeElementSymbol)architecture.getBody()).getElements()){
if(!isSupportedLayer(element, layerChecker)) {
return false;
}
}
return true;
}
private boolean isSupportedLayer(ArchitectureElementSymbol element, LayerSupportChecker layerChecker){
List<ArchitectureElementSymbol> constructLayerElemList;
if (element.getResolvedThis().get() instanceof CompositeElementSymbol) {
constructLayerElemList = ((CompositeElementSymbol)element.getResolvedThis().get()).getElements();
for (ArchitectureElementSymbol constructedLayerElement : constructLayerElemList) {
if (!isSupportedLayer(constructedLayerElement, layerChecker)) {
return false;
}
}
}
if (!layerChecker.isSupported(element.toString())) {
Log.error("Unsupported layer " + "'" + element.getName() + "'" + " for the backend.");
return false;
} else {
return true;
}
}
}
package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
package de.monticore.lang.monticar.cnnarch.mxnetgenerator.checker;
import de.monticore.lang.monticar.cnnarch.mxnetgenerator.checker.LayerSupportChecker;
import java.util.ArrayList;
import java.util.List;
public class LayerSupportChecker {
public class AllowAllLayerSupportChecker implements LayerSupportChecker {
private List<String> unsupportedLayerList = new ArrayList<>();
public LayerSupportChecker() {
public AllowAllLayerSupportChecker() {
//Set the unsupported layers for the backend
//this.unsupportedLayerList.add(PREDEFINED_LAYER_NAME);
}
@Override
public boolean isSupported(String element) {
return !this.unsupportedLayerList.contains(element);
}
......
package de.monticore.lang.monticar.cnnarch.mxnetgenerator.checker;
public interface LayerSupportChecker {
boolean isSupported(String element);
}
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment