Commit 7a65ba70 authored by Evgeny Kusmenko's avatar Evgeny Kusmenko

Merge branch 'unsupported_layer_check' into 'master'

Unsupported layer check

See merge request !19
parents f426dffd 23c1cad8
Pipeline #99689 passed with stages
in 6 minutes and 8 seconds
...@@ -101,6 +101,11 @@ ...@@ -101,6 +101,11 @@
<version>${junit.version}</version> <version>${junit.version}</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency>
<groupId>com.github.stefanbirkner</groupId>
<artifactId>system-rules</artifactId>
<version>1.3.0</version>
</dependency>
<dependency> <dependency>
<groupId>ch.qos.logback</groupId> <groupId>ch.qos.logback</groupId>
......
...@@ -24,6 +24,8 @@ import de.monticore.io.paths.ModelPath; ...@@ -24,6 +24,8 @@ import de.monticore.io.paths.ModelPath;
import de.monticore.lang.monticar.cnnarch.CNNArchGenerator; import de.monticore.lang.monticar.cnnarch.CNNArchGenerator;
import de.monticore.lang.monticar.cnnarch._cocos.CNNArchCocos; import de.monticore.lang.monticar.cnnarch._cocos.CNNArchCocos;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol; import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchCompilationUnitSymbol; import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchCompilationUnitSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchLanguage; import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchLanguage;
import de.monticore.lang.monticar.generator.FileContent; import de.monticore.lang.monticar.generator.FileContent;
...@@ -44,6 +46,16 @@ public class CNNArch2Caffe2 implements CNNArchGenerator{ ...@@ -44,6 +46,16 @@ public class CNNArch2Caffe2 implements CNNArchGenerator{
private String generationTargetPath; private String generationTargetPath;
private void supportCheck(ArchitectureSymbol architecture){
LayerSupportChecker layerChecker = new LayerSupportChecker();
for (ArchitectureElementSymbol element : ((CompositeElementSymbol)architecture.getBody()).getElements()){
if (!layerChecker.isSupported(element.toString())) {
Log.error("Unsupported layer " + "'" + element.getName() + "'" + " for the backend CAFFE2. Code generation aborted.");
System.exit(1);
}
}
}
public CNNArch2Caffe2() { public CNNArch2Caffe2() {
setGenerationTargetPath("./target/generated-sources-cnnarch/"); setGenerationTargetPath("./target/generated-sources-cnnarch/");
} }
...@@ -78,6 +90,7 @@ public class CNNArch2Caffe2 implements CNNArchGenerator{ ...@@ -78,6 +90,7 @@ public class CNNArch2Caffe2 implements CNNArchGenerator{
} }
CNNArchCocos.checkAll(compilationUnit.get()); CNNArchCocos.checkAll(compilationUnit.get());
supportCheck(compilationUnit.get().getArchitecture());
try{ try{
generateFiles(compilationUnit.get().getArchitecture()); generateFiles(compilationUnit.get().getArchitecture());
......
package de.monticore.lang.monticar.cnnarch.caffe2generator;
import static de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers.*;
import java.util.ArrayList;
import java.util.List;
public class LayerSupportChecker {
private List<String> unsupportedLayerList = new ArrayList();
public LayerSupportChecker() {
//Set the unsupported layers for the backend
this.unsupportedLayerList.add(ADD_NAME);
this.unsupportedLayerList.add(SPLIT_NAME);
this.unsupportedLayerList.add(GET_NAME);
this.unsupportedLayerList.add(CONCATENATE_NAME);
this.unsupportedLayerList.add(BATCHNORM_NAME);
}
public boolean isSupported(String element) {
return !this.unsupportedLayerList.contains(element);
}
}
...@@ -23,6 +23,7 @@ package de.monticore.lang.monticar.cnnarch.caffe2generator; ...@@ -23,6 +23,7 @@ package de.monticore.lang.monticar.cnnarch.caffe2generator;
import de.se_rwth.commons.logging.Log; import de.se_rwth.commons.logging.Log;
import freemarker.template.TemplateException; import freemarker.template.TemplateException;
import org.junit.Before; import org.junit.Before;
import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import java.io.IOException; import java.io.IOException;
...@@ -30,9 +31,12 @@ import java.nio.file.Path; ...@@ -30,9 +31,12 @@ import java.nio.file.Path;
import java.nio.file.Paths; import java.nio.file.Paths;
import java.util.*; import java.util.*;
import org.junit.contrib.java.lang.system.ExpectedSystemExit;
import static junit.framework.TestCase.assertTrue; import static junit.framework.TestCase.assertTrue;
public class GenerationTest extends AbstractSymtabTest{ public class GenerationTest extends AbstractSymtabTest{
@Rule
public final ExpectedSystemExit exit = ExpectedSystemExit.none();
@Before @Before
public void setUp() { public void setUp() {
...@@ -61,16 +65,8 @@ public class GenerationTest extends AbstractSymtabTest{ ...@@ -61,16 +65,8 @@ public class GenerationTest extends AbstractSymtabTest{
public void testAlexnetGeneration() throws IOException, TemplateException { public void testAlexnetGeneration() throws IOException, TemplateException {
Log.getFindings().clear(); Log.getFindings().clear();
String[] args = {"-m", "src/test/resources/architectures", "-r", "Alexnet", "-o", "./target/generated-sources-cnnarch/"}; String[] args = {"-m", "src/test/resources/architectures", "-r", "Alexnet", "-o", "./target/generated-sources-cnnarch/"};
exit.expectSystemExit();
CNNArch2Caffe2Cli.main(args); CNNArch2Caffe2Cli.main(args);
assertTrue(Log.getFindings().isEmpty());
checkFilesAreEqual(
Paths.get("./target/generated-sources-cnnarch"),
Paths.get("./src/test/resources/target_code"),
Arrays.asList(
"CNNCreator_Alexnet.py",
"CNNPredictor_Alexnet.h",
"execute_Alexnet"));
} }
@Test @Test
...@@ -94,8 +90,8 @@ public class GenerationTest extends AbstractSymtabTest{ ...@@ -94,8 +90,8 @@ public class GenerationTest extends AbstractSymtabTest{
public void testThreeInputCNNGeneration() throws IOException, TemplateException { public void testThreeInputCNNGeneration() throws IOException, TemplateException {
Log.getFindings().clear(); Log.getFindings().clear();
String[] args = {"-m", "src/test/resources/architectures", "-r", "ThreeInputCNN_M14"}; String[] args = {"-m", "src/test/resources/architectures", "-r", "ThreeInputCNN_M14"};
exit.expectSystemExit();
CNNArch2Caffe2Cli.main(args); CNNArch2Caffe2Cli.main(args);
assertTrue(Log.getFindings().size() == 1);
} }
@Test @Test
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment