Extended the layer support checker to also analyze the elements of constructed layers

parent e7606f12
Pipeline #101444 passed in 4 minutes and 2 seconds
@@ -26,6 +26,7 @@ import de.monticore.lang.monticar.cnnarch._cocos.CNNArchCocos;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.ArchitectureElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CompositeElementSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.IOSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchCompilationUnitSymbol;
import de.monticore.lang.monticar.cnnarch._symboltable.CNNArchLanguage;
import de.monticore.lang.monticar.generator.FileContent;
@@ -41,19 +42,37 @@ import java.nio.file.Path;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.List;
public class CNNArch2Caffe2 implements CNNArchGenerator{
    private String generationTargetPath;

    private void supportCheck(ArchitectureSymbol architecture){
    private boolean isSupportedLayer(ArchitectureElementSymbol element, LayerSupportChecker layerChecker){
        List<ArchitectureElementSymbol> constructLayerElemList;
        if (!(element instanceof IOSymbol) && (element.getResolvedThis().get() instanceof CompositeElementSymbol))
        {
            constructLayerElemList = ((CompositeElementSymbol)element.getResolvedThis().get()).getElements();
            for (ArchitectureElementSymbol constructedLayerElement : constructLayerElemList) {
                if (!isSupportedLayer(constructedLayerElement, layerChecker)) return false;
            }
        }
        if (!layerChecker.isSupported(element.toString())) {
            Log.error("Unsupported layer " + "'" + element.getName() + "'" + " for the backend CAFFE2.");
            return false;
        }
        else {
            return true;
        }
    }

    private boolean supportCheck(ArchitectureSymbol architecture){
        LayerSupportChecker layerChecker = new LayerSupportChecker();
        for (ArchitectureElementSymbol element : ((CompositeElementSymbol)architecture.getBody()).getElements()){
            if (!layerChecker.isSupported(element.toString())) {
                Log.error("Unsupported layer " + "'" + element.getName() + "'" + " for the backend CAFFE2. Code generation aborted.");
                System.exit(1);
            }
            if(!isSupportedLayer(element, layerChecker)) return false;
        }
        return true;
    }

    public CNNArch2Caffe2() {
@@ -90,7 +109,10 @@ public class CNNArch2Caffe2 implements CNNArchGenerator{
        }

        CNNArchCocos.checkAll(compilationUnit.get());
        supportCheck(compilationUnit.get().getArchitecture());
        if (!supportCheck(compilationUnit.get().getArchitecture())){
            Log.error("Code generation aborted.");
            System.exit(1);
        }

        try{
            generateFiles(compilationUnit.get().getArchitecture());
......
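The core of this change is the recursion in `isSupportedLayer`: before an element is checked against the backend's layer whitelist, every element of a constructed (composite) layer is checked first, so an unsupported layer is rejected even when it is nested inside a user-defined construct. The following self-contained sketch illustrates that traversal with simplified stand-in types; `Element`, the `SUPPORTED` whitelist, and the class name are invented for illustration and are not part of the CNNArch symbol table API.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

// Simplified stand-in for an architecture element: a layer name plus the
// elements a constructed layer is built from (empty for a predefined layer).
class Element {
    final String name;
    final List<Element> children;

    Element(String name, List<Element> children) {
        this.name = name;
        this.children = children;
    }

    Element(String name) {
        this(name, Collections.<Element>emptyList());
    }
}

public class SupportCheckSketch {

    // Hypothetical whitelist standing in for LayerSupportChecker.
    static final List<String> SUPPORTED = Arrays.asList("Convolution", "FullyConnected", "Softmax");

    // Mirrors the idea of isSupportedLayer: reject an element as soon as any
    // element of its construct is unsupported, then check the element itself.
    static boolean isSupported(Element element) {
        for (Element child : element.children) {
            if (!isSupported(child)) {
                return false;
            }
        }
        // Leaf layers are matched against the whitelist; a constructed layer is
        // accepted here because all of its elements already passed above.
        return !element.children.isEmpty() || SUPPORTED.contains(element.name);
    }

    public static void main(String[] args) {
        Element ok = new Element("Block",
                Arrays.asList(new Element("Convolution"), new Element("FullyConnected")));
        Element bad = new Element("Block",
                Arrays.asList(new Element("Convolution"), new Element("CustomOp")));
        System.out.println(isSupported(ok));  // true
        System.out.println(isSupported(bad)); // false: "CustomOp" is not whitelisted
    }
}
```

In the generator itself, the whitelist lookup is `layerChecker.isSupported(element.toString())`, the descent into sub-elements is skipped for `IOSymbol` inputs/outputs, and `supportCheck` now returns `false` instead of calling `System.exit` directly, leaving the abort decision to the caller shown above.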
@@ -46,9 +46,9 @@ public class GenerationTest extends AbstractSymtabTest{
    }

    @Test
    public void testCifar10Classifier() throws IOException, TemplateException {
    public void testLeNetGeneration() throws IOException, TemplateException {
        Log.getFindings().clear();
        String[] args = {"-m", "src/test/resources/valid_tests", "-r", "CifarClassifierNetwork", "-o", "./target/generated-sources-cnnarch/"};
        String[] args = {"-m", "src/test/resources/architectures", "-r", "LeNet", "-o", "./target/generated-sources-cnnarch/"};
        CNNArch2Caffe2Cli.main(args);
        assertTrue(Log.getFindings().isEmpty());
@@ -56,9 +56,18 @@ public class GenerationTest extends AbstractSymtabTest{
            Paths.get("./target/generated-sources-cnnarch"),
            Paths.get("./src/test/resources/target_code"),
            Arrays.asList(
                "CNNCreator_CifarClassifierNetwork.py",
                "CNNPredictor_CifarClassifierNetwork.h",
                "execute_CifarClassifierNetwork"));
                "CNNCreator_LeNet.py",
                "CNNPredictor_LeNet.h",
                "execute_LeNet"));
    }

    @Test
    public void testCifar10Classifier() throws IOException, TemplateException {
        Log.getFindings().clear();
        String[] args = {"-m", "src/test/resources/valid_tests", "-r", "CifarClassifierNetwork", "-o", "./target/generated-sources-cnnarch/"};
        exit.expectSystemExit();
        CNNArch2Caffe2Cli.main(args);
        assertTrue(Log.getFindings().size() == 2);
    }

    @Test
@@ -67,6 +76,8 @@ public class GenerationTest extends AbstractSymtabTest{
        String[] args = {"-m", "src/test/resources/architectures", "-r", "Alexnet", "-o", "./target/generated-sources-cnnarch/"};
        exit.expectSystemExit();
        CNNArch2Caffe2Cli.main(args);
        assertTrue(Log.getFindings().size() == 2);
    }

    @Test
@@ -92,14 +103,16 @@ public class GenerationTest extends AbstractSymtabTest{
        String[] args = {"-m", "src/test/resources/architectures", "-r", "ThreeInputCNN_M14"};
        exit.expectSystemExit();
        CNNArch2Caffe2Cli.main(args);
        assertTrue(Log.getFindings().size() == 2);
    }

    @Test
    public void testResNeXtGeneration() throws IOException, TemplateException {
        Log.getFindings().clear();
        String[] args = {"-m", "src/test/resources/architectures", "-r", "ResNeXt50"};
        exit.expectSystemExit();
        CNNArch2Caffe2Cli.main(args);
        assertTrue(Log.getFindings().isEmpty());
        assertTrue(Log.getFindings().size() == 2);
    }

    @Test
......
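The `exit.expectSystemExit()` calls in these tests rely on a JUnit rule that intercepts `System.exit`; the rule itself is declared outside the hunks shown here. Assuming the project uses the system-rules library, whose `ExpectedSystemExit` rule provides exactly this method, a minimal standalone test would look like the sketch below; the class and test names are placeholders.

```java
import org.junit.Rule;
import org.junit.Test;
import org.junit.contrib.java.lang.system.ExpectedSystemExit;

public class SystemExitRuleSketch {

    // Intercepts System.exit so the JVM running the test suite is not terminated.
    @Rule
    public final ExpectedSystemExit exit = ExpectedSystemExit.none();

    @Test
    public void abortsViaSystemExit() {
        // From this point on, the test passes only if System.exit is actually called,
        // e.g. by a CLI run that rejects an architecture with an unsupported layer.
        exit.expectSystemExit();
        System.exit(1);
    }
}
```

The `assertTrue(Log.getFindings().size() == 2)` assertions then verify that exactly two findings are logged for an unsupported architecture: the "Unsupported layer ..." error and the "Code generation aborted." error added in this commit.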
#ifndef CNNPREDICTOR_LENET
#define CNNPREDICTOR_LENET
#include "caffe2/core/common.h"
#include "caffe2/utils/proto_utils.h"
#include "caffe2/core/workspace.h"
#include "caffe2/core/tensor.h"
#include "caffe2/core/init.h"
// Define USE_GPU for GPU computation. Default is CPU computation.
//#define USE_GPU
#ifdef USE_GPU
#include "caffe2/core/context_gpu.h"
#endif
#include <string>
#include <fstream>
#include <iostream>
#include <map>
#include <vector>
CAFFE2_DEFINE_string(init_net, "./model/LeNet/init_net.pb", "The given path to the init protobuffer.");
CAFFE2_DEFINE_string(predict_net, "./model/LeNet/predict_net.pb", "The given path to the predict protobuffer.");
using namespace caffe2;
class CNNPredictor_LeNet{
    private:
        TensorCPU input;
        Workspace workSpace;
        NetDef initNet, predictNet;

    public:
        const std::vector<TIndex> input_shapes = {{1,1,28,28}};

        explicit CNNPredictor_LeNet(){
            init(input_shapes);
        }

        ~CNNPredictor_LeNet(){};

        void init(const std::vector<TIndex> &input_shapes){
            int n = 0;
            char **a[1];
            caffe2::GlobalInit(&n, a);

            if (!std::ifstream(FLAGS_init_net).good()) {
                std::cerr << "\nNetwork loading failure, init_net file '" << FLAGS_init_net << "' does not exist." << std::endl;
                exit(1);
            }
            if (!std::ifstream(FLAGS_predict_net).good()) {
                std::cerr << "\nNetwork loading failure, predict_net file '" << FLAGS_predict_net << "' does not exist." << std::endl;
                exit(1);
            }
            std::cout << "\nLoading network..." << std::endl;

            // Read protobuf
            CAFFE_ENFORCE(ReadProtoFromFile(FLAGS_init_net, &initNet));
            CAFFE_ENFORCE(ReadProtoFromFile(FLAGS_predict_net, &predictNet));

            // Set device type
#ifdef USE_GPU
            predictNet.mutable_device_option()->set_device_type(CUDA);
            initNet.mutable_device_option()->set_device_type(CUDA);
            std::cout << "== GPU mode selected " << " ==" << std::endl;
#else
            predictNet.mutable_device_option()->set_device_type(CPU);
            initNet.mutable_device_option()->set_device_type(CPU);
            for(int i = 0; i < predictNet.op_size(); ++i){
                predictNet.mutable_op(i)->mutable_device_option()->set_device_type(CPU);
            }
            for(int i = 0; i < initNet.op_size(); ++i){
                initNet.mutable_op(i)->mutable_device_option()->set_device_type(CPU);
            }
            std::cout << "== CPU mode selected " << " ==" << std::endl;
#endif

            // Load network
            CAFFE_ENFORCE(workSpace.RunNetOnce(initNet));
            CAFFE_ENFORCE(workSpace.CreateNet(predictNet));
            std::cout << "== Network loaded " << " ==" << std::endl;

            input.Resize(input_shapes);
        }

        void predict(const std::vector<float> &image, std::vector<float> &predictions){
            //Note: ShareExternalPointer requires a float pointer.
            input.ShareExternalPointer((float *) image.data());

            // Get input blob
#ifdef USE_GPU
            auto dataBlob = workSpace.GetBlob("data")->GetMutable<TensorCUDA>();
#else
            auto dataBlob = workSpace.GetBlob("data")->GetMutable<TensorCPU>();
#endif

            // Copy from input data
            dataBlob->CopyFrom(input);

            // Forward
            workSpace.RunNet(predictNet.name());

            // Get output blob
#ifdef USE_GPU
            auto predictionsBlob = TensorCPU(workSpace.GetBlob("predictions")->Get<TensorCUDA>());
#else
            auto predictionsBlob = workSpace.GetBlob("predictions")->Get<TensorCPU>();
#endif
            predictions.assign(predictionsBlob.data<float>(), predictionsBlob.data<float>() + predictionsBlob.size());

            google::protobuf::ShutdownProtobufLibrary();
        }
};
#endif // CNNPREDICTOR_LENET
vector<float> CNN_predictions(10);

_cnn_.predict(CNNTranslator::translate(image),
              CNN_predictions);

predictions = CNNTranslator::translateToCol(CNN_predictions, std::vector<size_t> {10});