Commit e785fbc2 authored by Evgeny Kusmenko's avatar Evgeny Kusmenko

Merge branch 'oneclick_nn_training' into 'master'

Oneclick nn training

Closes #7

See merge request !21
parents abbd6de0 59972bb9
Pipeline #110151 failed with stages
in 7 minutes and 38 seconds
......@@ -22,29 +22,44 @@
stages:
- windows
- linux
- deploy
masterJobLinux:
stage: linux
stage: deploy
image: maven:3-jdk-8
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean deploy --settings settings.xml
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean deploy --settings settings.xml -DskipTests
- cat target/site/jacoco/index.html
- mvn package sonar:sonar -s settings.xml
only:
- master
integrationMXNetJobLinux:
stage: linux
image: registry.git.rwth-aachen.de/monticore/embeddedmontiarc/generators/emadl2cpp/integrationtests/mxnet:v0.0.3
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml -Dtest=IntegrationMXNetTest
integrationCaffe2JobLinux:
stage: linux
image: registry.git.rwth-aachen.de/monticore/embeddedmontiarc/generators/emadl2cpp/integrationtests/caffe2:v0.0.3
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml -Dtest=IntegrationCaffe2Test
masterJobWindows:
stage: windows
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml -Dtest="GenerationTest,SymtabTest"
tags:
- Windows10
BranchJobLinux:
UnitTestJobLinux:
stage: linux
image: maven:3-jdk-8
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B clean install --settings settings.xml -Dtest="GenerationTest,SymtabTest*"
- cat target/site/jacoco/index.html
except:
- master
{
"configurations": [
{
"type": "java",
"name": "CodeLens (Launch) - EMADLGeneratorCli",
"request": "launch",
"mainClass": "de.monticore.lang.monticar.emadl.generator.EMADLGeneratorCli",
"projectName": "embedded-montiarc-emadl-generator"
}
]
}
\ No newline at end of file
......@@ -2,4 +2,84 @@
![coverage](https://git.rwth-aachen.de/monticore/EmbeddedMontiArc/generators/EMADL2CPP/badges/master/coverage.svg)
# EMADL2CPP
Generates CPP/Python code for EmbeddedMontiArcDL.
See example project [EMADL-Demo](https://git.rwth-aachen.de/thomas.timmermanns/EMADL-Demo) for more information on how the generated code can be used.
\ No newline at end of file
See example project [EMADL-Demo](https://git.rwth-aachen.de/thomas.timmermanns/EMADL-Demo) for more information on how the generated code can be used.
[ How to develop and train a CNN component using EMADL2CPP](#nn)
<a name="nn"></a>
# Development and training of a CNN component using EMADL2CPP
## Prerequisites
* Linux. Ubuntu Linux 16.04 and 18.04 were used during testing.
* Deep learning backend:
* MXNet
* training - generated is Python code. Required is Python 2.7 or higher, Python packages `h5py`, `mxnet` (for training on CPU) or e.g. `mxnet-cu75` for CUDA 7.5 (for training on GPU with CUDA, concrete package should be selected according to CUDA version). Follow [official instructions on MXNet site](https://mxnet.incubator.apache.org/install/index.html?platform=Linux&language=Python&processor=CPU)
* prediction - generated code is C++. Install MXNet using [official instructions on MXNet site](https://mxnet.incubator.apache.org) for C++.
* Caffe2
* training - generated is Python code. Follow [ official instructions on Caffe2 site ](https://caffe2.ai/docs/getting-started.html?platform=ubuntu&configuration=prebuilt)
* Gluon
### HowTo
1. Define a EMADL component containing architecture of a neural network and save it in a `.emadl` file. For more information on architecture language please refer to [CNNArchLang project](https://git.rwth-aachen.de/monticore/EmbeddedMontiArc/languages/CNNArchLang). An example of NN architecture:
```
component VGG16{
ports in Z(0:255)^{3, 224, 224} image,
out Q(0:1)^{1000} predictions;
implementation CNN {
def conv(filter, channels){
Convolution(kernel=(filter,filter), channels=channels) ->
Relu()
}
def fc(){
FullyConnected(units=4096) ->
Relu() ->
Dropout(p=0.5)
}
image ->
conv(filter=3, channels=64, ->=2) ->
Pooling(pool_type="max", kernel=(2,2), stride=(2,2)) ->
conv(filter=3, channels=128, ->=2) ->
Pooling(pool_type="max", kernel=(2,2), stride=(2,2)) ->
conv(filter=3, channels=256, ->=3) ->
Pooling(pool_type="max", kernel=(2,2), stride=(2,2)) ->
conv(filter=3, channels=512, ->=3) ->
Pooling(pool_type="max", kernel=(2,2), stride=(2,2)) ->
conv(filter=3, channels=512, ->=3) ->
Pooling(pool_type="max", kernel=(2,2), stride=(2,2)) ->
fc() ->
fc() ->
FullyConnected(units=1000) ->
Softmax() ->
predictions
}
}
```
2. Define a training configuration for this network and store it in a `.cnnt file`, the name of the file should be the same as that of the corresponding architecture (e.g. `VGG16.emadl` and `VGG16.cnnt`). For more information on architecture language please refer to [CNNTrainLang project](https://git.rwth-aachen.de/monticore/EmbeddedMontiArc/languages/CNNTrainLang). An example of a training configuration:
```
configuration VGG16{
num_epoch:10
batch_size:64
normalize:true
load_checkpoint:false
optimizer:adam{
learning_rate:0.01
learning_rate_decay:0.8
step_size:1000
}
}
```
3. Generate C code which uses neural networks that were trained using the specified deep learning backend. The generator receives the following command line parameters:
* `-m` path to directory with EMADL models
* `-r` name of the root model
* `-o` output path
* `-b` backend
* `-p` path to python (Not mandatory; Default is `/usr/bin/python`)
* `-f` forced training (Not mandatory; values can be `y` for a forced training and `n` for a skip (a forced no-training)). By default, the hash value (from the training and test data, the structure of the model (.emadl) and the training parameters (.cnnt) of the model) will be compared. The model is retrained only if the hash changes. This can be used to distribute trained models, by distributing the corresponding `.training_hash` file as well, which will prevent a retraining.
Assuming both the architecture definition `VGG16.emadl` and the corresponding training configuration `VGG16.cnnt` are located in a folder `models` and the target code should be generated in a `target` folder using the `MXNet` backend, an example of a command is then:
```java -jar embedded-montiarc-emadl-generator-0.2.10-jar-with-dependencies.jar -m models -r VGG16 -o target -b MXNET```
......@@ -8,7 +8,7 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>embedded-montiarc-emadl-generator</artifactId>
<version>0.2.8</version>
<version>0.2.10</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
......@@ -17,8 +17,8 @@
<!-- .. SE-Libraries .................................................. -->
<emadl.version>0.2.6</emadl.version>
<CNNTrain.version>0.2.6</CNNTrain.version>
<cnnarch-mxnet-generator.version>0.2.12</cnnarch-mxnet-generator.version>
<cnnarch-caffe2-generator.version>0.2.9</cnnarch-caffe2-generator.version>
<cnnarch-mxnet-generator.version>0.2.14-SNAPSHOT</cnnarch-mxnet-generator.version>
<cnnarch-caffe2-generator.version>0.2.10-SNAPSHOT</cnnarch-caffe2-generator.version>
<embedded-montiarc-math-opt-generator>0.1.4</embedded-montiarc-math-opt-generator>
<!-- .. Libraries .................................................. -->
......
......@@ -39,7 +39,7 @@
<id>se-nexus</id>
<username>cibuild</username>
<password>${env.cibuild}</password>
</server>
</server>
<server>
<id>github</id>
......@@ -53,7 +53,7 @@
<id>se-nexus</id>
<mirrorOf>external:*</mirrorOf>
<url>https://nexus.se.rwth-aachen.de/content/groups/public</url>
</mirror>
</mirror>
</mirrors>
<profiles>
......@@ -99,5 +99,5 @@
<activeProfiles>
<activeProfile>se-nexus</activeProfile>
</activeProfiles>
</settings>
\ No newline at end of file
</activeProfiles>
</settings>
......@@ -47,4 +47,13 @@ public enum Backend {
return Optional.empty();
}
}
public static String getBackendString(Backend backend){
switch (backend){
case CAFFE2:
return "CAFFE2";
default:
return "MXNET";
}
}
}
......@@ -20,11 +20,11 @@
*/
package de.monticore.lang.monticar.emadl.generator;
import de.se_rwth.commons.logging.Log;
import freemarker.template.TemplateException;
import org.apache.commons.cli.*;
import java.io.File;
import java.io.IOException;
import java.util.Optional;
......@@ -57,6 +57,21 @@ public class EMADLGeneratorCli {
.required(false)
.build();
public static final Option OPTION_TRAINING_PYTHON_PATH = Option.builder("p")
.longOpt("python")
.desc("path to python. Default is /usr/bin/python")
.hasArg(true)
.required(false)
.build();
public static final Option OPTION_RESTRAINED_TRAINING = Option.builder("f")
.longOpt("forced")
.desc("no training or a forced training. Options: y (a forced training), n (no training)")
.hasArg(true)
.required(false)
.build();
private EMADLGeneratorCli() {
}
......@@ -75,6 +90,8 @@ public class EMADLGeneratorCli {
options.addOption(OPTION_ROOT_MODEL);
options.addOption(OPTION_OUTPUT_PATH);
options.addOption(OPTION_BACKEND);
options.addOption(OPTION_RESTRAINED_TRAINING);
options.addOption(OPTION_TRAINING_PYTHON_PATH);
return options;
}
......@@ -94,7 +111,10 @@ public class EMADLGeneratorCli {
String rootModelName = cliArgs.getOptionValue(OPTION_ROOT_MODEL.getOpt());
String outputPath = cliArgs.getOptionValue(OPTION_OUTPUT_PATH.getOpt());
String backendString = cliArgs.getOptionValue(OPTION_BACKEND.getOpt());
String forced = cliArgs.getOptionValue(OPTION_RESTRAINED_TRAINING.getOpt());
String pythonPath = cliArgs.getOptionValue(OPTION_TRAINING_PYTHON_PATH.getOpt());
final String DEFAULT_BACKEND = "MXNET";
final String DEFAULT_FORCED = "UNSET";
if (backendString == null) {
Log.warn("backend not specified. backend set to default value " + DEFAULT_BACKEND);
......@@ -106,13 +126,25 @@ public class EMADLGeneratorCli {
Log.warn("specified backend " + backendString + " not supported. backend set to default value " + DEFAULT_BACKEND);
backend = Backend.getBackendFromString(DEFAULT_BACKEND);
}
if (pythonPath == null) {
pythonPath = "/usr/bin/python";
}
if (forced == null) {
forced = DEFAULT_FORCED;
}
else if (!forced.equals("y") && !forced.equals("n")) {
Log.error("specified setting ("+forced+") for forcing/preventing training not supported. set to default value " + DEFAULT_FORCED);
forced = DEFAULT_FORCED;
}
EMADLGenerator generator = new EMADLGenerator(backend.get());
if (outputPath != null){
generator.setGenerationTargetPath(outputPath);
}
try{
generator.generate(cliArgs.getOptionValue(OPTION_MODELS_PATH.getOpt()), rootModelName);
generator.generate(cliArgs.getOptionValue(OPTION_MODELS_PATH.getOpt()), rootModelName, pythonPath, forced);
}
catch (IOException e){
Log.error("io error during generation", e);
......
......@@ -20,6 +20,8 @@
*/
package de.monticore.lang.monticar.emadl;
import de.monticore.lang.monticar.emadl.generator.Backend;
import de.monticore.lang.monticar.emadl.generator.EMADLGenerator;
import de.monticore.lang.monticar.emadl.generator.EMADLGeneratorCli;
import de.se_rwth.commons.logging.Log;
import freemarker.template.TemplateException;
......@@ -27,10 +29,15 @@ import org.junit.Before;
import org.junit.Test;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import static junit.framework.TestCase.assertTrue;
import static org.junit.Assert.assertFalse;
public class GenerationTest extends AbstractSymtabTest {
......@@ -44,7 +51,7 @@ public class GenerationTest extends AbstractSymtabTest {
@Test
public void testCifar10Generation() throws IOException, TemplateException {
Log.getFindings().clear();
String[] args = {"-m", "src/test/resources/models/", "-r", "cifar10.Cifar10Classifier", "-b", "MXNET"};
String[] args = {"-m", "src/test/resources/models/", "-r", "cifar10.Cifar10Classifier", "-b", "MXNET", "-f", "n"};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().isEmpty());
......@@ -66,7 +73,7 @@ public class GenerationTest extends AbstractSymtabTest {
@Test
public void testSimulatorGeneration() throws IOException, TemplateException {
Log.getFindings().clear();
String[] args = {"-m", "src/test/resources/models/", "-r", "simulator.MainController", "-b", "MXNET"};
String[] args = {"-m", "src/test/resources/models/", "-r", "simulator.MainController", "-b", "MXNET", "-f", "n"};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().isEmpty());
}
......@@ -74,7 +81,7 @@ public class GenerationTest extends AbstractSymtabTest {
@Test
public void testAddGeneration() throws IOException, TemplateException {
Log.getFindings().clear();
String[] args = {"-m", "src/test/resources/models/", "-r", "Add", "-b", "MXNET"};
String[] args = {"-m", "src/test/resources/models/", "-r", "Add", "-b", "MXNET", "-f", "n"};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().isEmpty());
}
......@@ -82,7 +89,7 @@ public class GenerationTest extends AbstractSymtabTest {
@Test
public void testAlexnetGeneration() throws IOException, TemplateException {
Log.getFindings().clear();
String[] args = {"-m", "src/test/resources/models/", "-r", "Alexnet", "-b", "MXNET"};
String[] args = {"-m", "src/test/resources/models/", "-r", "Alexnet", "-b", "MXNET", "-f", "n"};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().isEmpty());
}
......@@ -90,7 +97,7 @@ public class GenerationTest extends AbstractSymtabTest {
@Test
public void testResNeXtGeneration() throws IOException, TemplateException {
Log.getFindings().clear();
String[] args = {"-m", "src/test/resources/models/", "-r", "ResNeXt50", "-b", "MXNET"};
String[] args = {"-m", "src/test/resources/models/", "-r", "ResNeXt50", "-b", "MXNET", "-f", "n"};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().isEmpty());
}
......@@ -98,7 +105,7 @@ public class GenerationTest extends AbstractSymtabTest {
@Test
public void testThreeInputGeneration() throws IOException, TemplateException {
Log.getFindings().clear();
String[] args = {"-m", "src/test/resources/models/", "-r", "ThreeInputCNN_M14", "-b", "MXNET"};
String[] args = {"-m", "src/test/resources/models/", "-r", "ThreeInputCNN_M14", "-b", "MXNET", "-f", "n"};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().size() == 1);
}
......@@ -106,31 +113,36 @@ public class GenerationTest extends AbstractSymtabTest {
@Test
public void testMultipleOutputsGeneration() throws IOException, TemplateException {
Log.getFindings().clear();
String[] args = {"-m", "src/test/resources/models/", "-r", "MultipleOutputs", "-b", "MXNET"};
String[] args = {"-m", "src/test/resources/models/", "-r", "MultipleOutputs", "-b", "MXNET", "-f", "n"};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().size() == 1);
}
@Test
public void tesVGGGeneration() throws IOException, TemplateException {
public void testVGGGeneration() throws IOException, TemplateException {
Log.getFindings().clear();
String[] args = {"-m", "src/test/resources/models/", "-r", "VGG16", "-b", "MXNET"};
String[] args = {"-m", "src/test/resources/models/", "-r", "VGG16", "-b", "MXNET", "-f", "n"};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().isEmpty());
}
@Test
public void testMultipleInstances() throws IOException, TemplateException {
Log.getFindings().clear();
String[] args = {"-m", "src/test/resources/models/", "-r", "InstanceTest.MainB", "-b", "MXNET"};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().isEmpty());
try {
Log.getFindings().clear();
String[] args = {"-m", "src/test/resources/models/", "-r", "InstanceTest.MainB", "-b", "MXNET", "-f", "n"};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().isEmpty());
}
catch(Exception e) {
e.printStackTrace();
}
}
@Test
public void testMnistClassifier() throws IOException, TemplateException {
Log.getFindings().clear();
String[] args = {"-m", "src/test/resources/models/", "-r", "mnist.MnistClassifier", "-b", "CAFFE2"};
String[] args = {"-m", "src/test/resources/models/", "-r", "mnist.MnistClassifier", "-b", "CAFFE2", "-f", "n"};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().isEmpty());
......@@ -147,4 +159,15 @@ public class GenerationTest extends AbstractSymtabTest {
"mnist_mnistClassifier_calculateClass.h",
"CNNTrainer_mnist_mnistClassifier_net.py"));
}
@Test
public void testHashFunction() {
EMADLGenerator tester = new EMADLGenerator(Backend.MXNET);
try{
tester.getChecksumForFile("invalid Path!");
assertTrue("Hash method should throw IOException on invalid path", false);
} catch(IOException e){
}
}
}
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.emadl;
import de.monticore.lang.monticar.emadl.generator.Backend;
import de.monticore.lang.monticar.emadl.generator.EMADLGenerator;
import de.monticore.lang.monticar.emadl.generator.EMADLGeneratorCli;
import de.se_rwth.commons.logging.Log;
import freemarker.template.TemplateException;
import org.junit.Before;
import org.junit.Test;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import static junit.framework.TestCase.assertTrue;
import static org.junit.Assert.assertFalse;
public class IntegrationCaffe2Test extends IntegrationTest {
public IntegrationCaffe2Test() {
super("CAFFE2", "39253EC049D4A4E5FA0536AD34874B9D#1DBAEE1B1BD83FB7CB5F70AE91B29638#13D139510DC5681639AA91D7250288D3#1A42D4842D0664937A9F6B727BD60CEF");
}
}
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.emadl;
import de.monticore.lang.monticar.emadl.generator.Backend;
import de.monticore.lang.monticar.emadl.generator.EMADLGenerator;
import de.monticore.lang.monticar.emadl.generator.EMADLGeneratorCli;
import de.se_rwth.commons.logging.Log;
import freemarker.template.TemplateException;
import org.junit.Before;
import org.junit.Test;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import static junit.framework.TestCase.assertTrue;
import static org.junit.Assert.assertFalse;
public class IntegrationMXNetTest extends IntegrationTest {
public IntegrationMXNetTest() {
super("MXNET", "39253EC049D4A4E5FA0536AD34874B9D#1DBAEE1B1BD83FB7CB5F70AE91B29638#C4C23549E737A759721D6694C75D9771#5AF0CE68E408E8C1F000E49D72AC214A");
}
}
/**
*
* ******************************************************************************
* MontiCAR Modeling Family, www.se-rwth.de
* Copyright (c) 2017, Software Engineering Group at RWTH Aachen,
* All rights reserved.
*
* This project is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 3.0 of the License, or (at your option) any later version.
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this project. If not, see <http://www.gnu.org/licenses/>.
* *******************************************************************************
*/
package de.monticore.lang.monticar.emadl;
import de.monticore.lang.monticar.emadl.generator.Backend;
import de.monticore.lang.monticar.emadl.generator.EMADLGenerator;
import de.monticore.lang.monticar.emadl.generator.EMADLGeneratorCli;
import de.se_rwth.commons.logging.Log;
import freemarker.template.TemplateException;
import org.junit.Before;
import org.junit.Test;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import static junit.framework.TestCase.assertTrue;
import static org.junit.Assert.assertFalse;
public abstract class IntegrationTest extends AbstractSymtabTest {
private String backend;
private String trainingHash;
public IntegrationTest(String backend, String trainingHash) {
this.backend = backend;
this.trainingHash = trainingHash;
}
private Path netTrainingHashFile = Paths.get("./target/generated-sources-emadl/simpleCifar10/CifarNetwork.training_hash");
private void createHashFile() {
try {
netTrainingHashFile.toFile().getParentFile().mkdirs();
List<String> lines = Arrays.asList(this.trainingHash);
Files.write(netTrainingHashFile, lines, Charset.forName("UTF-8"));
}
catch(Exception e) {
assertFalse("Hash file could not be created", true);
}
}
private void deleteHashFile() {
try {
Files.delete(netTrainingHashFile);
}
catch(Exception e) {
assertFalse("Could not delete hash file", true);
}
}
@Before
public void setUp() {
// ensure an empty log
Log.getFindings().clear();
Log.enableFailQuick(false);
}
@Test
public void testDontRetrain1() {
// The training hash is stored during the first training, so the second one is skipped
Log.getFindings().clear();
String[] args = {"-m", "src/test/resources/models/", "-r", "simpleCifar10.Cifar10Classifier", "-b", this.backend};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().isEmpty());
Log.getFindings().clear();
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().size() == 1);
assertTrue(Log.getFindings().get(0).getMsg().contains("skipped"));
deleteHashFile();
}
@Test
public void testDontRetrain2() {
// The training hash is written manually, so even the first training should be skipped
Log.getFindings().clear();
createHashFile();
String[] args = {"-m", "src/test/resources/models/", "-r", "simpleCifar10.Cifar10Classifier", "-b", this.backend};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().size() == 1);
assertTrue(Log.getFindings().get(0).getMsg().contains("skipped"));
deleteHashFile();
}
@Test
public void testDontRetrain3() {
// Multiple instances of the first NN are used. Only the first one should cause a training
Log.getFindings().clear();
String[] args = {"-m", "src/test/resources/models/", "-r", "instanceTestCifar.MainC", "-b", this.backend};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().size() == 1);
assertTrue(Log.getFindings().get(0).getMsg().contains("skipped"));