Commit 65ed59d4 authored by Sebastian Nickels's avatar Sebastian Nickels

Enable more in-depth generator-specific compability testing of architectures before generating them

parent cf4423f2
Pipeline #142105 failed with stages
in 3 minutes and 33 seconds
package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.se_rwth.commons.logging.Log;
public class ArchitectureSupportChecker {
public ArchitectureSupportChecker() {}
// Overload functions returning always true to enable the features
protected boolean checkMultipleStreams(ArchitectureSymbol architecture) {
if (architecture.getStreams().size() != 1) {
Log.error("This cnn architecture has multiple instructions, " +
"which is currently not supported by the code generator. "
, architecture.getSourcePosition());
return false;
}
return true;
}
protected boolean checkMultipleInputs(ArchitectureSymbol architecture) {
if (architecture.getInputs().size() > 1) {
Log.error("This cnn architecture has multiple inputs, " +
"which is currently not supported by the code generator. "
, architecture.getSourcePosition());
return false;
}
return true;
}
protected boolean checkMultipleOutputs(ArchitectureSymbol architecture) {
if (architecture.getOutputs().size() > 1) {
Log.error("This cnn architecture has multiple outputs, " +
"which is currently not supported by the code generator. "
, architecture.getSourcePosition());
return false;
}
return true;
}
protected boolean checkMultiDimensionalOutput(ArchitectureSymbol architecture) {
if (architecture.getOutputs().get(0).getDefinition().getType().getWidth() != 1 ||
architecture.getOutputs().get(0).getDefinition().getType().getHeight() != 1) {
Log.error("This cnn architecture has a multi-dimensional output, " +
"which is currently not supported by the code generator."
, architecture.getSourcePosition());
return false;
}
return true;
}
public boolean check(ArchitectureSymbol architecture) {
return checkMultipleStreams(architecture)
&& checkMultipleInputs(architecture)
&& checkMultipleOutputs(architecture)
&& checkMultiDimensionalOutput(architecture);
}
}
......@@ -45,59 +45,33 @@ import java.util.List;
public class CNNArch2MxNet extends CNNArchGenerator {
private boolean isSupportedLayer(ArchitectureElementSymbol element, LayerSupportChecker layerChecker){
List<ArchitectureElementSymbol> constructLayerElemList;
if (element.getResolvedThis().get() instanceof CompositeElementSymbol) {
constructLayerElemList = ((CompositeElementSymbol)element.getResolvedThis().get()).getElements();
for (ArchitectureElementSymbol constructedLayerElement : constructLayerElemList) {
if (!isSupportedLayer(constructedLayerElement, layerChecker)) {
return false;
}
}
}
if (!layerChecker.isSupported(element.toString())) {
Log.error("Unsupported layer " + "'" + element.getName() + "'" + " for the backend MXNET.");
return false;
} else {
return true;
}
}
private boolean supportCheck(ArchitectureSymbol architecture){
List<CompositeElementSymbol> streams = architecture.getStreams();
// This generator only supports one stream
if (streams.size() != 1)
{
return false;
}
LayerSupportChecker layerChecker = new LayerSupportChecker();
for (CompositeElementSymbol stream : streams) {
for (ArchitectureElementSymbol element : stream.getElements()) {
if (!isSupportedLayer(element, layerChecker)) {
return false;
}
}
}
return true;
}
protected ArchitectureSupportChecker architectureSupportChecker;
protected LayerSupportChecker layerSupportChecker;
public CNNArch2MxNet() {
architectureSupportChecker = new ArchitectureSupportChecker();
layerSupportChecker = new LayerSupportChecker();
setGenerationTargetPath("./target/generated-sources-cnnarch/");
}
public void generate(Scope scope, String rootModelName){
public boolean generate(Scope scope, String rootModelName) {
Optional<CNNArchCompilationUnitSymbol> compilationUnit = scope.resolve(rootModelName, CNNArchCompilationUnitSymbol.KIND);
if (!compilationUnit.isPresent()){
Log.error("could not resolve architecture " + rootModelName);
quitGeneration();
return false;
}
CNNArchCocos.checkAll(compilationUnit.get());
if (!supportCheck(compilationUnit.get().getArchitecture())){
quitGeneration();
ArchitectureSymbol architecture = compilationUnit.get().getArchitecture();
if (!architectureSupportChecker.check(architecture)) {
return false;
}
if (!layerSupportChecker.check(architecture)) {
return false;
}
try{
......@@ -109,7 +83,10 @@ public class CNNArch2MxNet extends CNNArchGenerator {
generateFiles(compilationUnit.get().getArchitecture());
} catch (IOException e){
Log.error(e.toString());
return false;
}
return true;
}
//check cocos with CNNArchCocos.checkAll(architecture) before calling this method.
......@@ -133,8 +110,6 @@ public class CNNArch2MxNet extends CNNArchGenerator {
temp = archTc.process("CNNBufferFile", Target.CPP);
fileContentMap.put("CNNBufferFile.h", temp.getValue());
checkValidGeneration(architecture);
return fileContentMap;
}
......
......@@ -79,6 +79,9 @@ public class GenericCNNArchCli {
if (outputPath != null){
cnnArchGenerator.setGenerationTargetPath(outputPath);
}
cnnArchGenerator.generate(modelsDirPath, rootModelName);
if (!cnnArchGenerator.generate(modelsDirPath, rootModelName)) {
Log.error("Code generation failed");
}
}
}
\ No newline at end of file
package de.monticore.lang.monticar.cnnarch.mxnetgenerator;
import de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerDeclarationSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.LayerSymbol;
import de.se_rwth.commons.logging.Log;
import java.util.ArrayList;
import java.util.List;
public class LayerSupportChecker {
private List<String> unsupportedLayerList = new ArrayList<>();
protected List<String> supportedLayerList = new ArrayList<>();
public LayerSupportChecker() {
//Set the unsupported layers for the backend
//this.unsupportedLayerList.add(PREDEFINED_LAYER_NAME);
supportedLayerList.add(AllPredefinedLayers.FULLY_CONNECTED_NAME);
supportedLayerList.add(AllPredefinedLayers.CONVOLUTION_NAME);
supportedLayerList.add(AllPredefinedLayers.SOFTMAX_NAME);
supportedLayerList.add(AllPredefinedLayers.SIGMOID_NAME);
supportedLayerList.add(AllPredefinedLayers.TANH_NAME);
supportedLayerList.add(AllPredefinedLayers.RELU_NAME);
supportedLayerList.add(AllPredefinedLayers.DROPOUT_NAME);
supportedLayerList.add(AllPredefinedLayers.POOLING_NAME);
supportedLayerList.add(AllPredefinedLayers.GLOBAL_POOLING_NAME);
supportedLayerList.add(AllPredefinedLayers.LRN_NAME);
supportedLayerList.add(AllPredefinedLayers.BATCHNORM_NAME);
supportedLayerList.add(AllPredefinedLayers.SPLIT_NAME);
supportedLayerList.add(AllPredefinedLayers.GET_NAME);
supportedLayerList.add(AllPredefinedLayers.ADD_NAME);
supportedLayerList.add(AllPredefinedLayers.CONCATENATE_NAME);
supportedLayerList.add(AllPredefinedLayers.FLATTEN_NAME);
}
private boolean isSupportedLayer(ArchitectureElementSymbol element){
ArchitectureElementSymbol resolvedElement = element.getResolvedThis().get();
List<ArchitectureElementSymbol> constructLayerElemList;
if (resolvedElement instanceof CompositeElementSymbol) {
constructLayerElemList = ((CompositeElementSymbol) resolvedElement).getElements();
for (ArchitectureElementSymbol constructedLayerElement : constructLayerElemList) {
if (!isSupportedLayer(constructedLayerElement)) {
return false;
}
}
return true;
}
// Support all inputs and outputs
if (resolvedElement.isInput() || resolvedElement.isOutput()) {
return true;
}
// Support all layer declarations
if (resolvedElement instanceof LayerSymbol) {
if (!((LayerSymbol) resolvedElement).getDeclaration().isPredefined()) {
return true;
}
}
if (!supportedLayerList.contains(element.toString())) {
Log.error("Unsupported layer " + "'" + element.getName() + "'" + " for the current backend.");
return false;
} else {
return true;
}
}
public boolean isSupported(String element) {
return !this.unsupportedLayerList.contains(element);
public boolean check(ArchitectureSymbol architecture) {
for (CompositeElementSymbol stream : architecture.getStreams()) {
for (ArchitectureElementSymbol element : stream.getElements()) {
if (!isSupportedLayer(element)) {
return false;
}
}
}
return true;
}
}
......@@ -96,7 +96,7 @@ public class GenerationTest extends AbstractSymtabTest{
Log.getFindings().clear();
String[] args = {"-m", "src/test/resources/architectures", "-r", "ThreeInputCNN_M14"};
CNNArch2MxNetCli.main(args);
assertTrue(Log.getFindings().size() == 1);
assertTrue(Log.getFindings().size() == 2);
}
@Test
......@@ -107,22 +107,20 @@ public class GenerationTest extends AbstractSymtabTest{
assertTrue(Log.getFindings().isEmpty());
}
/* TODO: Change quitGeneration() call and maybe add Exception?
@Test
public void testMultipleStreams() throws IOException, TemplateException {
Log.getFindings().clear();
String[] args = {"-m", "src/test/resources/invalid_tests", "-r", "MultipleStreams"};
CNNArch2MxNetCli.main(args);
//assertTrue(Log.getFindings().isEmpty());
assertTrue(Log.getFindings().size() == 2);
}
*/
@Test
public void testMultipleOutputs() throws IOException, TemplateException {
Log.getFindings().clear();
String[] args = {"-m", "src/test/resources/valid_tests", "-r", "MultipleOutputs"};
String[] args = {"-m", "src/test/resources/invalid_tests", "-r", "MultipleOutputs"};
CNNArch2MxNetCli.main(args);
assertTrue(Log.getFindings().size() == 3);
assertTrue(Log.getFindings().size() == 2);
}
@Test
......
MultipleStreams data/MultipleStreams
\ No newline at end of file
MultipleStreams data/MultipleStreams
MultipleOutputs data/MultipleOutputs
\ No newline at end of file
CifarClassifierNetwork data/CifarClassifierNetwork
MultipleOutputs data/MultipleOutputs
\ No newline at end of file
CifarClassifierNetwork data/CifarClassifierNetwork
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment