Commit 7a65ba70 authored by Evgeny Kusmenko's avatar Evgeny Kusmenko

Merge branch 'unsupported_layer_check' into 'master'

Unsupported layer check

See merge request !19
parents f426dffd 23c1cad8
Pipeline #99689 passed with stages
in 6 minutes and 8 seconds
......@@ -101,6 +101,11 @@
<version>${junit.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.github.stefanbirkner</groupId>
<artifactId>system-rules</artifactId>
<version>1.3.0</version>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
......
......@@ -24,6 +24,8 @@ import de.monticore.io.paths.ModelPath;
import de.monticore.lang.monticar.cnnarch.CNNArchGenerator;
import de.monticore.lang.monticar.cnnarch._cocos.CNNArchCocos;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchCompilationUnitSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchLanguage;
import de.monticore.lang.monticar.generator.FileContent;
......@@ -44,6 +46,16 @@ public class CNNArch2Caffe2 implements CNNArchGenerator{
private String generationTargetPath;
private void supportCheck(ArchitectureSymbol architecture){
LayerSupportChecker layerChecker = new LayerSupportChecker();
for (ArchitectureElementSymbol element : ((CompositeElementSymbol)architecture.getBody()).getElements()){
if (!layerChecker.isSupported(element.toString())) {
Log.error("Unsupported layer " + "'" + element.getName() + "'" + " for the backend CAFFE2. Code generation aborted.");
System.exit(1);
}
}
}
public CNNArch2Caffe2() {
setGenerationTargetPath("./target/generated-sources-cnnarch/");
}
......@@ -78,6 +90,7 @@ public class CNNArch2Caffe2 implements CNNArchGenerator{
}
CNNArchCocos.checkAll(compilationUnit.get());
supportCheck(compilationUnit.get().getArchitecture());
try{
generateFiles(compilationUnit.get().getArchitecture());
......
package de.monticore.lang.monticar.cnnarch.caffe2generator;
import static de.monticore.lang.monticar.cnnarch.predefined.AllPredefinedLayers.*;
import java.util.ArrayList;
import java.util.List;
public class LayerSupportChecker {
private List<String> unsupportedLayerList = new ArrayList();
public LayerSupportChecker() {
//Set the unsupported layers for the backend
this.unsupportedLayerList.add(ADD_NAME);
this.unsupportedLayerList.add(SPLIT_NAME);
this.unsupportedLayerList.add(GET_NAME);
this.unsupportedLayerList.add(CONCATENATE_NAME);
this.unsupportedLayerList.add(BATCHNORM_NAME);
}
public boolean isSupported(String element) {
return !this.unsupportedLayerList.contains(element);
}
}
......@@ -23,6 +23,7 @@ package de.monticore.lang.monticar.cnnarch.caffe2generator;
import de.se_rwth.commons.logging.Log;
import freemarker.template.TemplateException;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import java.io.IOException;
......@@ -30,9 +31,12 @@ import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
import org.junit.contrib.java.lang.system.ExpectedSystemExit;
import static junit.framework.TestCase.assertTrue;
public class GenerationTest extends AbstractSymtabTest{
@Rule
public final ExpectedSystemExit exit = ExpectedSystemExit.none();
@Before
public void setUp() {
......@@ -61,16 +65,8 @@ public class GenerationTest extends AbstractSymtabTest{
public void testAlexnetGeneration() throws IOException, TemplateException {
Log.getFindings().clear();
String[] args = {"-m", "src/test/resources/architectures", "-r", "Alexnet", "-o", "./target/generated-sources-cnnarch/"};
exit.expectSystemExit();
CNNArch2Caffe2Cli.main(args);
assertTrue(Log.getFindings().isEmpty());
checkFilesAreEqual(
Paths.get("./target/generated-sources-cnnarch"),
Paths.get("./src/test/resources/target_code"),
Arrays.asList(
"CNNCreator_Alexnet.py",
"CNNPredictor_Alexnet.h",
"execute_Alexnet"));
}
@Test
......@@ -94,8 +90,8 @@ public class GenerationTest extends AbstractSymtabTest{
public void testThreeInputCNNGeneration() throws IOException, TemplateException {
Log.getFindings().clear();
String[] args = {"-m", "src/test/resources/architectures", "-r", "ThreeInputCNN_M14"};
exit.expectSystemExit();
CNNArch2Caffe2Cli.main(args);
assertTrue(Log.getFindings().size() == 1);
}
@Test
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment