Commit f756778c authored by nilsfreyer

fixed tests for cnnarch2caffe2 2.10-SNAPSHOT, added lmdb database for tests

parent f6a97730
@@ -18,7 +18,7 @@
<emadl.version>0.2.6</emadl.version>
<CNNTrain.version>0.2.6</CNNTrain.version>
<cnnarch-mxnet-generator.version>0.2.14-SNAPSHOT</cnnarch-mxnet-generator.version>
<cnnarch-caffe2-generator.version>0.2.9</cnnarch-caffe2-generator.version>
<cnnarch-caffe2-generator.version>0.2.10-SNAPSHOT</cnnarch-caffe2-generator.version>
<embedded-montiarc-math-opt-generator>0.1.4</embedded-montiarc-math-opt-generator>
<!-- .. Libraries .................................................. -->
......
@@ -41,6 +41,28 @@ import static org.junit.Assert.assertFalse;
public class IntegrationCaffe2Test extends AbstractSymtabTest {
private Path vggTrainingHashFile = Paths.get("./target/generated-sources-emadl/cNNCalculator/VGG16.training_hash");
private void createHashFile() {
try {
vggTrainingHashFile.toFile().getParentFile().mkdirs();
List<String> lines = Arrays.asList("7A7FBAC4E0AD84993C1C5F8B4F431055#D85A46E95F839BBEE22D9AC3E6A4BC5C#6BE4AED3D0DA1940B750FEA8088A7D21#6BE4AED3D0DA1940B750FEA8088A7D21");
Files.write(vggTrainingHashFile, lines, Charset.forName("UTF-8"));
}
catch(Exception e) {
assertFalse("Hash file could not be created", true);
}
}
private void deleteHashFile() {
try {
Files.delete(vggTrainingHashFile);
}
catch(Exception e) {
assertFalse("Could not delete hash file", true);
}
}
@Before
public void setUp() {
// ensure an empty log
@@ -48,20 +70,33 @@ public class IntegrationCaffe2Test extends AbstractSymtabTest {
Log.enableFailQuick(false);
}
@Test
public void testDontRetrain() {
public void testDontRetrain1() {
// The training hash is stored during the first training, so the second one is skipped
Log.getFindings().clear();
String[] args = {"-m", "src/test/resources/models/", "-r", "cNNCalculator.Network", "-b", "CAFFE2"};
String[] args = {"-m", "src/test/resources/models/", "-r", "cNNCalculator.VGG16", "-b", "CAFFE2"};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().isEmpty());
//assertTrue(!Log.getFindings().isEmpty());
Log.getFindings().clear();
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().size() == 1);
//assertTrue(Log.getFindings().size() == 1);
assertTrue(Log.getFindings().get(0).getMsg().contains("skipped"));
deleteHashFile();
}
@Test
public void testForceRetrain() {
// The training hash is written manually, but training is forced
Log.getFindings().clear();
createHashFile();
String[] args = {"-m", "src/test/resources/models/", "-r", "cNNCalculator.VGG16", "-b", "CAFFE2", "-f", "y"};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().isEmpty());
deleteHashFile();
}
}
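The two tests above exercise the retraining guard: `testDontRetrain1` relies on the hash that is stored during the first training run so that the second run is skipped, and `testForceRetrain` writes a hash file by hand and then overrides it with `-f y`. The diff does not show the generator's own hash logic, so the following is only a minimal, hypothetical sketch of such a check, assuming the `.training_hash` file holds the `#`-separated checksums written by `createHashFile()` and that a matching stored hash means training can be skipped.

```java
// Illustrative sketch only; the real EMADL generator's hash handling may differ.
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

public class TrainingHashCheck {

    /** Returns true if a stored hash exists and equals the freshly computed one. */
    static boolean canSkipTraining(Path hashFile, String currentHash, boolean forceRetrain) throws Exception {
        if (forceRetrain || !Files.exists(hashFile)) {
            return false;                                  // no stored hash, or retraining forced via "-f y"
        }
        String stored = new String(Files.readAllBytes(hashFile), StandardCharsets.UTF_8).trim();
        return stored.equals(currentHash);                 // unchanged model/training setup -> skip training
    }

    public static void main(String[] args) throws Exception {
        Path hashFile = Paths.get("./target/generated-sources-emadl/cNNCalculator/VGG16.training_hash");
        String hash = "7A7FBAC4E0AD84993C1C5F8B4F431055#D85A46E95F839BBEE22D9AC3E6A4BC5C"
                + "#6BE4AED3D0DA1940B750FEA8088A7D21#6BE4AED3D0DA1940B750FEA8088A7D21";
        System.out.println("skip training: " + canSkipTraining(hashFile, hash, false));
    }
}
```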
cifar10.CifarNetwork src/test/resources/training_data
cNNCalculator.Network src/test/resources/training_data
cNNCalculator.VGG16 src/test/resources/training_data
MultipleOutputs data/MultipleOutputs
InstanceTest.NetworkB data/InstanceTest.NetworkB
Alexnet data/Alexnet
......
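The listing above adds entries to the test data-path mapping, one component per line with its training-data directory (notably `cNNCalculator.VGG16`, which the new tests train). As a rough illustration of the two-column, whitespace-separated format, here is a small hypothetical reader; the file name used in `main` and the parsing details are assumptions, not taken from the generator's sources.

```java
// Hypothetical reader for the component-to-data-path mapping shown above.
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.HashMap;
import java.util.Map;

public class DataPathTable {

    static Map<String, String> load(String file) throws Exception {
        Map<String, String> paths = new HashMap<>();
        for (String line : Files.readAllLines(Paths.get(file))) {
            String[] cols = line.trim().split("\\s+");
            if (cols.length == 2) {
                paths.put(cols[0], cols[1]);  // e.g. "cNNCalculator.VGG16" -> "src/test/resources/training_data"
            }
        }
        return paths;
    }

    public static void main(String[] args) throws Exception {
        System.out.println(load("src/test/resources/models/data_paths.txt"));  // file name is an assumption
    }
}
```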
@@ -13,8 +13,8 @@ class CNNCreator_mnist_mnistClassifier_net:
module = None
_current_dir_ = os.path.join('./')
_data_dir_ = os.path.join(_current_dir_, 'data', 'mnist_mnistClassifier_net')
_model_dir_ = os.path.join(_current_dir_, 'model', 'mnist_mnistClassifier_net')
_data_dir_ = os.path.join(_current_dir_, 'data/mnist.LeNetNetwork')
_model_dir_ = os.path.join(_current_dir_, 'model', 'mnist.LeNetNetwork')
_init_net_ = os.path.join(_model_dir_, 'init_net.pb')
_predict_net_ = os.path.join(_model_dir_, 'predict_net.pb')
......
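This hunk changes the generated creator's directory convention: the data and model paths are now derived from the network type's fully qualified name (`mnist.LeNetNetwork`) instead of the component instance name (`mnist_mnistClassifier_net`). The helper below only illustrates that naming idea in Java; it is not part of the generator.

```java
// Sketch of the directory naming suggested by the hunk above (assumption, not generator code).
import java.nio.file.Path;
import java.nio.file.Paths;

public class DataDirResolver {

    static Path dataDirFor(String fullyQualifiedNetworkName) {
        return Paths.get(".", "data", fullyQualifiedNetworkName);  // e.g. ./data/mnist.LeNetNetwork
    }

    public static void main(String[] args) {
        System.out.println(dataDirFor("mnist.LeNetNetwork"));
    }
}
```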
@@ -18,8 +18,8 @@
#include <iostream>
#include <map>
CAFFE2_DEFINE_string(init_net, "./model/mnist_mnistClassifier_net/init_net.pb", "The given path to the init protobuffer.");
CAFFE2_DEFINE_string(predict_net, "./model/mnist_mnistClassifier_net/predict_net.pb", "The given path to the predict protobuffer.");
CAFFE2_DEFINE_string(init_net_CNNPredictor_mnist_mnistClassifier_net, "./model/mnist.LeNetNetwork/init_net.pb", "The given path to the init protobuffer.");
CAFFE2_DEFINE_string(predict_net_CNNPredictor_mnist_mnistClassifier_net, "./model/mnist.LeNetNetwork/predict_net.pb", "The given path to the predict protobuffer.");
using namespace caffe2;
@@ -43,21 +43,21 @@ class CNNPredictor_mnist_mnistClassifier_net{
char **a[1];
caffe2::GlobalInit(&n, a);
if (!std::ifstream(FLAGS_init_net).good()) {
std::cerr << "\nNetwork loading failure, init_net file '" << FLAGS_init_net << "' does not exist." << std::endl;
if (!std::ifstream(FLAGS_init_net_CNNPredictor_mnist_mnistClassifier_net).good()) {
std::cerr << "\nNetwork loading failure, init_net file '" << FLAGS_init_net_CNNPredictor_mnist_mnistClassifier_net << "' does not exist." << std::endl;
exit(1);
}
if (!std::ifstream(FLAGS_predict_net).good()) {
std::cerr << "\nNetwork loading failure, predict_net file '" << FLAGS_predict_net << "' does not exist." << std::endl;
if (!std::ifstream(FLAGS_predict_net_CNNPredictor_mnist_mnistClassifier_net).good()) {
std::cerr << "\nNetwork loading failure, predict_net file '" << FLAGS_predict_net_CNNPredictor_mnist_mnistClassifier_net << "' does not exist." << std::endl;
exit(1);
}
std::cout << "\nLoading network..." << std::endl;
// Read protobuf
CAFFE_ENFORCE(ReadProtoFromFile(FLAGS_init_net, &initNet));
CAFFE_ENFORCE(ReadProtoFromFile(FLAGS_predict_net, &predictNet));
CAFFE_ENFORCE(ReadProtoFromFile(FLAGS_init_net_CNNPredictor_mnist_mnistClassifier_net, &initNet));
CAFFE_ENFORCE(ReadProtoFromFile(FLAGS_predict_net_CNNPredictor_mnist_mnistClassifier_net, &predictNet));
// Set device type
#ifdef USE_GPU
......
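The C++ hunks above rename the Caffe2 command-line flags so that each generated predictor declares its own `init_net_...`/`predict_net_...` flags, prefixed with the predictor class name, rather than the shared `init_net`/`predict_net` names; this avoids clashing flag definitions when several generated predictors are linked into one binary. The snippet below is only a guess at how such flag names could be derived from a component instance name; the naming helper and its behavior are assumptions.

```java
// Hypothetical derivation of the per-predictor flag names used in the generated C++ code.
public class FlagNaming {

    static String flagName(String prefix, String componentInstance) {
        // e.g. "mnist.mnistClassifier.net" -> "init_net_CNNPredictor_mnist_mnistClassifier_net"
        return prefix + "_CNNPredictor_" + componentInstance.replace('.', '_');
    }

    public static void main(String[] args) {
        System.out.println(flagName("init_net", "mnist.mnistClassifier.net"));
        System.out.println(flagName("predict_net", "mnist.mnistClassifier.net"));
    }
}
```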
(Two binary files with no preview, presumably the LMDB test database added by this commit.)