Added tests for training/hashing, fixed some bugs, removed debug output

parent 12c352dd
Pipeline #104498 failed
@@ -210,7 +210,7 @@ public class EMADLGenerator {
continue;
}
if(forced == "n") {
if(forced.equals("n")) {
continue;
}
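The hunk above swaps Java's reference comparison for a value comparison: forced is parsed from the command line at runtime, so == against the literal "n" checks object identity and can be false even when the characters match. A minimal standalone illustration of the difference (not part of the commit):

    public class StringEqualityDemo {
        public static void main(String[] args) {
            String forced = new String("n"); // a distinct object, as a runtime-parsed value would be
            System.out.println(forced == "n");      // false: compares references
            System.out.println(forced.equals("n")); // true: compares character content
        }
    }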
@@ -237,11 +237,10 @@ public class EMADLGenerator {
String trainingHash = emadlHash + "#" + cnntHash + "#" + trainingDataHash + "#" + testDataHash;
boolean alreadyTrained = newHashes.contains(trainingHash) || isAlreadyTrained(trainingHash, componentInstance);
-if(alreadyTrained && forced !="y") {
-System.out.println("Already trained");
+if(alreadyTrained && !forced.equals("y")) {
+Log.warn("Training of model " + componentInstance.getFullName() + " skipped");
}
else {
-System.out.println("Not trained yet");
String parsedFullName = componentInstance.getFullName().substring(0, 1).toLowerCase() + componentInstance.getFullName().substring(1).replaceAll("\\.", "_");
String trainerScriptName = "CNNTrainer_" + parsedFullName + ".py";
String trainingPath = getGenerationTargetPath() + trainerScriptName;
@@ -267,7 +266,8 @@
fileContentsTrainingHashes.add(new FileContent(trainingHash, componentConfigFilename + ".training_hash"));
newHashes.add(trainingHash);
-}else{
+}
+else{
System.out.println("Trainingfile " + trainingPath + " not found.");
}
}
......
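As the hunks above show (in part), the generator decides whether to retrain by joining four per-artifact hashes with '#' into trainingHash and checking whether that combined hash was already recorded. A rough sketch of the idea, assuming MD5 over file contents (the 32-digit hex values in the test fixture further down suggest MD5, but the hashing routine itself is not in this excerpt, and all file names below are placeholders):

    import java.nio.file.Files;
    import java.nio.file.Path;
    import java.nio.file.Paths;
    import java.security.MessageDigest;

    public class TrainingHashSketch {
        // hex-encoded MD5 of a file's bytes (assumed hashing scheme)
        static String md5(Path p) throws Exception {
            byte[] digest = MessageDigest.getInstance("MD5").digest(Files.readAllBytes(p));
            StringBuilder sb = new StringBuilder();
            for (byte b : digest) sb.append(String.format("%02X", b));
            return sb.toString();
        }

        public static void main(String[] args) throws Exception {
            // the generator joins the individual hashes with '#'
            String trainingHash = md5(Paths.get("Cifar10Classifier.emadl")) + "#"
                    + md5(Paths.get("CifarNetwork.cnnt")) + "#"
                    + md5(Paths.get("train.h5")) + "#"
                    + md5(Paths.get("test.h5"));
            Path hashFile = Paths.get("target/CifarNetwork.training_hash");
            boolean alreadyTrained = Files.exists(hashFile)
                    && Files.readAllLines(hashFile).contains(trainingHash);
            System.out.println(alreadyTrained ? "training skipped" : "training required");
        }
    }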
@@ -127,21 +127,16 @@ public class EMADLGeneratorCli {
backend = Backend.getBackendFromString(DEFAULT_BACKEND);
}
-if(pythonPath == null) {
+if (pythonPath == null) {
pythonPath = "/usr/bin/python";
}
if (forced == null) {
Log.warn("forced not specified. forced set to default value" + DEFAULT_FORCED);
forced = DEFAULT_FORCED;
}else if (forced == "y") {
Log.warn("training with enforcement");
/**/
}else if (forced == "n") {
Log.warn("no training with enforcement");
}else{
Log.error("no such parameter" + forced);
System.exit(1);
}
else if (!forced.equals("y") && !forced.equals("n")) {
Log.error("specified setting ("+forced+") for forcing/preventing training not supported. set to default value " + DEFAULT_FORCED);
forced = DEFAULT_FORCED;
}
EMADLGenerator generator = new EMADLGenerator(backend.get());
......
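The net effect of the CLI change above: an unsupported -f value no longer aborts the generator via System.exit(1) but logs an error finding and falls back to the default. A condensed sketch of the resulting control flow (the actual value of DEFAULT_FORCED is not visible in this excerpt and is assumed to be "n" here):

    public class ForcedOptionSketch {
        static final String DEFAULT_FORCED = "n"; // assumption, not shown in the diff

        static String resolveForced(String forced) {
            if (forced == null || (!forced.equals("y") && !forced.equals("n"))) {
                // warn and continue instead of exiting the JVM
                System.err.println("forced setting unusable (" + forced + "), falling back to " + DEFAULT_FORCED);
                return DEFAULT_FORCED;
            }
            return forced;
        }

        public static void main(String[] args) {
            System.out.println(resolveForced(null)); // n
            System.out.println(resolveForced("x"));  // n
            System.out.println(resolveForced("y"));  // y
        }
    }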
@@ -29,13 +29,40 @@ import org.junit.Before;
import org.junit.Test;
import java.io.IOException;
+import java.nio.charset.Charset;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.Arrays;
+import java.util.List;
import static junit.framework.TestCase.assertTrue;
+import static org.junit.Assert.assertFalse;
public class GenerationTest extends AbstractSymtabTest {
+private Path cifarTrainingHashFile = Paths.get("./target/generated-sources-emadl/cifar10/CifarNetwork.training_hash");
+private void createHashFile() {
+try {
+cifarTrainingHashFile.toFile().getParentFile().mkdirs();
+List<String> lines = Arrays.asList("AF9A637D700CB002266D20BF242F4A59#B87F2C80B19CABE0899C30FA66763A47#C4C23549E737A759721D6694C75D9771#5AF0CE68E408E8C1F000E49D72AC214A");
+Files.write(cifarTrainingHashFile, lines, Charset.forName("UTF-8"));
+}
+catch(Exception e) {
+assertFalse("Hash file could not be created", true);
+}
+}
+private void deleteHashFile() {
+try {
+Files.delete(cifarTrainingHashFile);
+}
+catch(Exception e) {
+assertFalse("Could not delete hash file", true);
+}
+}
@Before
public void setUp() {
// ensure an empty log
@@ -113,7 +140,7 @@ public class GenerationTest extends AbstractSymtabTest {
}
@Test
-public void tesVGGGeneration() throws IOException, TemplateException {
+public void testVGGGeneration() throws IOException, TemplateException {
Log.getFindings().clear();
String[] args = {"-m", "src/test/resources/models/", "-r", "VGG16", "-b", "MXNET", "-f", "n"};
EMADLGeneratorCli.main(args);
@@ -122,10 +149,15 @@
@Test
public void testMultipleInstances() throws IOException, TemplateException {
-Log.getFindings().clear();
-String[] args = {"-m", "src/test/resources/models/", "-r", "InstanceTest.MainB", "-b", "MXNET", "-f", "n"};
-EMADLGeneratorCli.main(args);
-assertTrue(Log.getFindings().isEmpty());
+try {
+Log.getFindings().clear();
+String[] args = {"-m", "src/test/resources/models/", "-r", "InstanceTest.MainB", "-b", "MXNET", "-f", "n"};
+EMADLGeneratorCli.main(args);
+assertTrue(Log.getFindings().isEmpty());
+}
+catch(Exception e) {
+e.printStackTrace();
+}
}
@Test
@@ -138,4 +170,57 @@
} catch(IOException e){
}
}
+@Test
+public void testDontRetrain1() {
+// The training hash is stored during the first training, so the second one is skipped
+Log.getFindings().clear();
+String[] args = {"-m", "src/test/resources/models/", "-r", "cifar10.Cifar10Classifier", "-b", "MXNET", "-p", "/home/christopher/anaconda3/bin/python"};
+EMADLGeneratorCli.main(args);
+assertTrue(Log.getFindings().isEmpty());
+Log.getFindings().clear();
+EMADLGeneratorCli.main(args);
+assertTrue(Log.getFindings().size() == 1);
+assertTrue(Log.getFindings().get(0).getMsg().contains("skipped"));
+deleteHashFile();
+}
+@Test
+public void testDontRetrain2() {
+// The training hash is written manually, so even the first training should be skipped
+Log.getFindings().clear();
+createHashFile();
+String[] args = {"-m", "src/test/resources/models/", "-r", "cifar10.Cifar10Classifier", "-b", "MXNET", "-p", "/home/christopher/anaconda3/bin/python"};
+EMADLGeneratorCli.main(args);
+assertTrue(Log.getFindings().size() == 1);
+assertTrue(Log.getFindings().get(0).getMsg().contains("skipped"));
+deleteHashFile();
+}
+@Test
+public void testDontRetrain3() {
+// Multiple instances of the first NN are used. Only the first one should cause a training
+Log.getFindings().clear();
+String[] args = {"-m", "src/test/resources/models/", "-r", "instanceTestCifar.MainC", "-b", "MXNET", "-p", "/home/christopher/anaconda3/bin/python"};
+EMADLGeneratorCli.main(args);
+assertTrue(Log.getFindings().size() == 1);
+assertTrue(Log.getFindings().get(0).getMsg().contains("skipped"));
+}
+@Test
+public void testForceRetrain() {
+// The training hash is written manually, but training is forced
+Log.getFindings().clear();
+createHashFile();
+String[] args = {"-m", "src/test/resources/models/", "-r", "cifar10.Cifar10Classifier", "-b", "MXNET", "-f", "y", "-p", "/home/christopher/anaconda3/bin/python"};
+EMADLGeneratorCli.main(args);
+assertTrue(Log.getFindings().isEmpty());
+deleteHashFile();
+}
}
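A side note on the fixture helpers near the top of this test class: assertFalse("message", true) is an assertion that always fails, used here to turn any exception into a test failure. JUnit's Assert.fail states that intent directly; an equivalent formulation (a sketch, not the committed code):

    import static org.junit.Assert.fail;

    public class HashFileFixtureSketch {
        void createHashFile() {
            try {
                // Files.write(...) as in the diff above
            } catch (Exception e) {
                fail("Hash file could not be created"); // same effect as assertFalse(msg, true)
            }
        }
    }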
configuration CifarNetwork{
num_epoch:10
-batch_size:64
+batch_size:5
normalize:true
context:gpu
load_checkpoint:false
......
-cifar10.CifarNetwork data/cifar10_cifar10Classifier_net
+cifar10.CifarNetwork src/test/resources/training_data
MultipleOutputs data/MultipleOutputs
-InstanceTest.NetworkB data/InstanceTest.NetworkB
+instanceTest.NetworkB data/InstanceTest.NetworkB
Alexnet data/Alexnet
ThreeInputCNN_M14 data/ThreeInputCNN_M14
VGG16 data/VGG16
-ResNeXt50 data/ResNeXt50
\ No newline at end of file
+ResNeXt50 data/ResNeXt50
+instanceTestCifar.CifarNetwork src/test/resources/training_data
\ No newline at end of file
package instanceTestCifar;
component ArgMax<Z(1:oo) n = 2>{
ports in Q^{n} inputVector,
out Z(0:oo) maxIndex;
implementation Math{
maxIndex = 0;
Q maxValue = inputVector(1);
for i = 2:n
if inputVector(i) > maxValue
maxIndex = i - 1;
maxValue = inputVector(i);
end
end
}
}
\ No newline at end of file
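For readers unfamiliar with the embedded Math language used in ArgMax above: its vectors are 1-indexed, so the loop runs from 2 to n and stores i - 1 to emit a 0-based class index. A plain Java rendering of the same argmax (illustrative only):

    public class ArgMaxDemo {
        static int argMax(double[] inputVector) {
            int maxIndex = 0;
            double maxValue = inputVector[0];
            for (int i = 1; i < inputVector.length; i++) {
                if (inputVector[i] > maxValue) {
                    maxIndex = i; // Java arrays are already 0-based, no offset needed
                    maxValue = inputVector[i];
                }
            }
            return maxIndex;
        }

        public static void main(String[] args) {
            System.out.println(argMax(new double[]{0.1, 0.7, 0.2})); // prints 1
        }
    }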
package instanceTestCifar;
import Network;
import CalculateClass;
component Cifar10Classifier{
ports in Z(0:255)^{3, 32, 32} image,
out Z(0:9) classIndex;
instance CifarNetwork<10> net;
instance ArgMax<10> calculateClass;
connect image -> net.data;
connect net.softmax -> calculateClass.inputVector;
connect calculateClass.maxIndex -> classIndex;
}
\ No newline at end of file
configuration CifarNetwork{
num_epoch:10
batch_size:5
normalize:true
context:gpu
load_checkpoint:false
optimizer:adam{
learning_rate:0.01
learning_rate_decay:0.8
step_size:1000
weight_decay:0.0001
}
}
package instanceTestCifar;
component CifarNetwork<Z(2:oo) classes = 10>{
ports in Z(0:255)^{3, 32, 32} data,
out Q(0:1)^{classes} softmax;
implementation CNN {
def conv(kernel, channels, stride=1, act=true){
Convolution(kernel=(kernel,kernel), channels=channels, stride=(stride,stride)) ->
BatchNorm() ->
Relu(?=act)
}
def resLayer(channels, stride=1, addSkipConv=false){
(
conv(kernel=3, channels=channels, stride=stride) ->
conv(kernel=3, channels=channels, act=false)
|
conv(kernel=1, channels=channels, stride=stride, act=false, ? = addSkipConv)
) ->
Add() ->
Relu()
}
data ->
resLayer(channels=8, addSkipConv=true) ->
resLayer(channels=16, stride=2, addSkipConv=true) ->
resLayer(channels=16, ->=2) ->
resLayer(channels=32, stride=2, addSkipConv=true) ->
resLayer(channels=32, ->=2) ->
resLayer(channels=64, stride=2, addSkipConv=true) ->
resLayer(channels=64, ->=2) ->
GlobalPooling(pool_type="avg") ->
FullyConnected(units=128) ->
Dropout()->
FullyConnected(units=classes) ->
Softmax() ->
softmax
}
}
\ No newline at end of file
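A note on the network syntax above: ->=2 in resLayer(channels=16, ->=2) is not garbled text. In the CNN implementation language used here, -> is the serial-repetition argument, so each of these resLayer calls is instantiated twice in sequence; likewise ?=act and ?=addSkipConv toggle a layer on or off, as done for the skip convolution.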
package instanceTestCifar;
import NetworkB;
import CalculateClassB;
import ArgMax;
import ResultAdder;
component MainC{
ports in Z(0:255)^{3, 32, 32} image1,
in Z(0:255)^{3, 32, 32} image2,
out Z(0:20) result;
instance CifarNetwork<10> net1;
instance CifarNetwork<10> net2;
instance ArgMax<10> calculateClass1;
instance ArgMax<10> calculateClass2;
instance ResultAdder adder;
connect image1 -> net1.data;
connect image2 -> net2.data;
connect net1.softmax -> calculateClass1.inputVector;
connect net2.softmax -> calculateClass2.inputVector;
connect calculateClass1.maxIndex -> adder.number1;
connect calculateClass2.maxIndex -> adder.number1;
connect adder.sum -> result;
}
\ No newline at end of file
package instanceTestCifar;
component ResultAdder{
ports in Z(0:oo) number1,
in Z(0:oo) number2,
out Z(0:oo) sum;
implementation Math{
sum = number1 + number2;
}
}
\ No newline at end of file
@@ -18,7 +18,7 @@ class MyConstant(mx.init.Initializer):
class CNNCreator_cifar10_cifar10Classifier_net:
module = None
_data_dir_ = "data/cifar10_cifar10Classifier_net/"
_data_dir_ = "src/test/resources/training_data/"
_model_dir_ = "model/cifar10.CifarNetwork/"
_model_prefix_ = "model"
_input_names_ = ['data']
......
@@ -10,7 +10,7 @@ if __name__ == "__main__":
cifar10_cifar10Classifier_net = CNNCreator_cifar10_cifar10Classifier_net.CNNCreator_cifar10_cifar10Classifier_net()
cifar10_cifar10Classifier_net.train(
-batch_size=64,
+batch_size=5,
num_epoch=10,
load_checkpoint=False,
context='gpu',
......