Commit 89f8039f authored by nilsfreyer

simpleCifar10 tests added

parent e519db57
Pipeline #110021 failed in 2 minutes and 14 seconds
@@ -74,7 +74,7 @@ public class IntegrationCaffe2Test extends AbstractSymtabTest {
public void testDontRetrain1() {
// The training hash is stored during the first training, so the second one is skipped
Log.getFindings().clear();
String[] args = {"-m", "src/test/resources/models/", "-r", "simpleCifar10.Cifar10Classifier", "-b", "CAFFE2"};
String[] args = {"-m", "src/test/resources/models/", "-r", "simplesimpleCifar10.Cifar10Classifier", "-b", "CAFFE2"};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().isEmpty());
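The comment above describes the generator's hash-based skip logic. As a rough sketch only — the generator's real implementation is not part of this diff, and the method name and signature here are hypothetical — the decision these tests exercise amounts to:

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;

public class TrainingHashCheck {
    // Hypothetical sketch of the skip decision; not the generator's actual code.
    public static boolean shouldTrain(Path hashFile, String currentHash, boolean forceRetrain) throws IOException {
        if (forceRetrain) {
            return true; // "-f y": retrain regardless of any stored hash
        }
        if (!Files.exists(hashFile)) {
            return true; // no hash stored yet, so this is the first training
        }
        String storedHash = new String(Files.readAllBytes(hashFile), StandardCharsets.UTF_8).trim();
        return !storedHash.equals(currentHash); // skip only when the model is unchanged
    }
}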
@@ -86,19 +86,42 @@ public class IntegrationCaffe2Test extends AbstractSymtabTest {
deleteHashFile();
}
// @Test
// public void testForceRetrain() {
// // The training hash is written manually, but training is forced
// Log.getFindings().clear();
// createHashFile();
//
// String[] args = {"-m", "src/test/resources/models/", "-r", "cNNCalculator.Network", "-b", "CAFFE2", "-f", "y"};
// EMADLGeneratorCli.main(args);
// assertTrue(Log.getFindings().isEmpty());
//
// deleteHashFile();
// }
@Test
public void testDontRetrain2() {
// The training hash is written manually, so even the first training should be skipped
Log.getFindings().clear();
createHashFile();
String[] args = {"-m", "src/test/resources/models/", "-r", "simpleCifar10.Cifar10Classifier", "-b", "CAFFE2"};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().size() == 1);
assertTrue(Log.getFindings().get(0).getMsg().contains("skipped"));
deleteHashFile();
}
@Test
public void testDontRetrain3() {
// Multiple instances of the first NN are used. Only the first one should cause a training
Log.getFindings().clear();
String[] args = {"-m", "src/test/resources/models/", "-r", "instanceTestCifar.MainC", "-b", "CAFFE2"};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().size() == 1);
assertTrue(Log.getFindings().get(0).getMsg().contains("skipped"));
}
@Test
public void testForceRetrain() {
// The training hash is written manually, but training is forced
Log.getFindings().clear();
createHashFile();
String[] args = {"-m", "src/test/resources/models/", "-r", "simpleCifar10.Cifar10Classifier", "-b", "CAFFE2", "-f", "y"};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().isEmpty());
deleteHashFile();
}
......
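The createHashFile() and deleteHashFile() helpers used by these tests sit outside the shown hunks. A minimal sketch of what such helpers could look like, assuming the training-hash path declared in the MXNet test below (the Caffe2 variant would point at its own generated file, and the file content here is a placeholder, not the generator's real hash format):

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

public class HashFileHelpers {
    // Path as declared in the MXNet test below.
    private static final Path HASH_FILE =
            Paths.get("./target/generated-sources-emadl/simpleCifar10/CifarNetwork.training_hash");

    // Hypothetical: pre-create a hash file so the generator treats training as already done.
    static void createHashFile() throws IOException {
        Files.createDirectories(HASH_FILE.getParent());
        Files.write(HASH_FILE, "dummy-hash".getBytes(StandardCharsets.UTF_8));
    }

    // Hypothetical: remove the hash file so later tests start from a clean state.
    static void deleteHashFile() throws IOException {
        Files.deleteIfExists(HASH_FILE);
    }
}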
@@ -41,7 +41,7 @@ import static org.junit.Assert.assertFalse;
public class IntegrationMXNetTest extends AbstractSymtabTest {
- private Path cifarTrainingHashFile = Paths.get("./target/generated-sources-emadl/cifar10/CifarNetwork.training_hash");
+ private Path cifarTrainingHashFile = Paths.get("./target/generated-sources-emadl/simpleCifar10/CifarNetwork.training_hash");
private void createHashFile() {
try {
@@ -76,7 +76,7 @@ public class IntegrationMXNetTest extends AbstractSymtabTest {
public void testDontRetrain1() {
// The training hash is stored during the first training, so the second one is skipped
Log.getFindings().clear();
String[] args = {"-m", "src/test/resources/models/", "-r", "cifar10.Cifar10Classifier", "-b", "MXNET"};
String[] args = {"-m", "src/test/resources/models/", "-r", "simpleCifar10.Cifar10Classifier", "-b", "MXNET"};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().isEmpty());
@@ -94,7 +94,7 @@ public class IntegrationMXNetTest extends AbstractSymtabTest {
Log.getFindings().clear();
createHashFile();
String[] args = {"-m", "src/test/resources/models/", "-r", "cifar10.Cifar10Classifier", "-b", "MXNET"};
String[] args = {"-m", "src/test/resources/models/", "-r", "simpleCifar10.Cifar10Classifier", "-b", "MXNET"};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().size() == 1);
assertTrue(Log.getFindings().get(0).getMsg().contains("skipped"));
@@ -118,7 +118,7 @@ public class IntegrationMXNetTest extends AbstractSymtabTest {
Log.getFindings().clear();
createHashFile();
String[] args = {"-m", "src/test/resources/models/", "-r", "cifar10.Cifar10Classifier", "-b", "MXNET", "-f", "y"};
String[] args = {"-m", "src/test/resources/models/", "-r", "simpleCifar10.Cifar10Classifier", "-b", "MXNET", "-f", "y"};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().isEmpty());
......
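Across both test classes the CLI arguments follow one pattern (meanings inferred from how the tests use them): -m sets the model search path, -r names the root component to generate and train, -b selects the backend (CAFFE2 or MXNET), and -f y forces retraining even when a stored training hash matches.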
configuration CifarNetwork{
num_epoch:10
batch_size:5
normalize:true
context:cpu
load_checkpoint:false
optimizer:adam{
learning_rate:0.01
learning_rate_decay:0.8
......
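The optimizer block above combines a base learning rate with a multiplicative decay. Purely as an illustration of the resulting schedule — when and how often the decay is actually applied depends on the backend, and this is not generator code:

public class LrDecayDemo {
    public static void main(String[] args) {
        double lr = 0.01;         // learning_rate from the configuration above
        final double decay = 0.8; // learning_rate_decay
        for (int epoch = 0; epoch < 10; epoch++) { // num_epoch:10
            System.out.printf("epoch %d: lr = %.6f%n", epoch, lr);
            lr *= decay;
        }
    }
}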
@@ -5,31 +5,9 @@ component CifarNetwork<Z(2:oo) classes = 10>{
out Q(0:1)^{classes} softmax;
implementation CNN {
- def conv(kernel, channels, stride=1, act=true){
- Convolution(kernel=(kernel,kernel), channels=channels, stride=(stride,stride)) ->
- BatchNorm() ->
- Relu(?=act)
- }
- def resLayer(channels, stride=1, addSkipConv=false){
- (
- conv(kernel=3, channels=channels, stride=stride) ->
- conv(kernel=3, channels=channels, act=false)
- |
- conv(kernel=1, channels=channels, stride=stride, act=false, ? = addSkipConv)
- ) ->
- Add() ->
- Relu()
- }
data ->
- resLayer(channels=8, addSkipConv=true) ->
- resLayer(channels=16, stride=2, addSkipConv=true) ->
- resLayer(channels=16, ->=2) ->
- resLayer(channels=32, stride=2, addSkipConv=true) ->
- resLayer(channels=32, ->=2) ->
- resLayer(channels=64, stride=2, addSkipConv=true) ->
- resLayer(channels=64, ->=2) ->
- GlobalPooling(pool_type="avg") ->
+ Convolution(kernel=(5,5), channels=8, padding="valid") ->
+ Convolution(kernel=(5,5), channels=8, padding="valid") ->
FullyConnected(units=128) ->
Dropout()->
FullyConnected(units=classes) ->
......
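Two reading aids for the removed CNNArchLang code above: an argument written ?=act or ?=addSkipConv includes the layer only when the flag is true, and ->=2 repeats a layer twice in series, so resLayer(channels=16, ->=2) stands for two consecutive resLayer(channels=16) applications. The commit thus swaps the residual network for the much smaller two-convolution model kept above.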