...
 
Commits (42)
# (c) https://github.com/MontiCore/monticore
# (c) https://github.com/MontiCore/monticore
stages:
- windows
#- windows
- linux
- deploy
......@@ -9,7 +9,7 @@ stages:
git masterJobLinux:
stage: deploy
image: maven:3-jdk-8
script:
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B -U clean deploy --settings settings.xml -DskipTests
# - cat target/site/jacoco/index.html
# - mvn package sonar:sonar -s settings.xml
......@@ -19,7 +19,7 @@ git masterJobLinux:
integrationMXNetJobLinux:
stage: linux
image: registry.git.rwth-aachen.de/monticore/embeddedmontiarc/generators/emadl2cpp/integrationtests/mxnet:v0.0.4
image: registry.git.rwth-aachen.de/monticore/embeddedmontiarc/applications/gans/mnist-infogan/gans_mxnet:latest
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B -U clean install --settings settings.xml -Dtest=IntegrationMXNetTest
......@@ -33,7 +33,7 @@ integrationCaffe2JobLinux:
integrationGluonJobLinux:
stage: linux
image: registry.git.rwth-aachen.de/monticore/embeddedmontiarc/generators/emadl2cpp/integrationtests/mxnet:v0.0.4
image: registry.git.rwth-aachen.de/monticore/embeddedmontiarc/applications/gans/mnist-infogan/gans_mxnet:latest
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B -U clean install --settings settings.xml -Dtest=IntegrationGluonTest
......@@ -51,19 +51,19 @@ integrationPythonWrapperTest:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B -U clean install --settings settings.xml -Dtest=IntegrationPythonWrapperTest
masterJobWindows:
stage: windows
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B -U clean install --settings settings.xml -Dtest="GenerationTest,SymtabTest"
tags:
- Windows10
#masterJobWindows:
# stage: windows
# script:
# - mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B -U clean install --settings settings.xml -Dtest="GenerationTest,SymtabTest"
# tags:
# - Windows10
UnitTestJobLinux:
stage: linux
image: registry.git.rwth-aachen.de/monticore/embeddedmontiarc/applications/gans/mnist-infogan/gans_mxnet:latest
# image: maven:3-jdk-8
image: registry.git.rwth-aachen.de/monticore/embeddedmontiarc/generators/emadl2cpp/integrationtests/mxnet:v0.0.4
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B -U clean install sonar:sonar --settings settings.xml -Dtest="GenerationTest,SymtabTest*"
script:
- mvn -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=warn -B -U clean install sonar:sonar --settings settings.xml -Dtest="GenerationTest,SymtabTest*"
# - cat target/site/jacoco/index.html
This diff is collapsed.
......@@ -9,7 +9,7 @@
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>embedded-montiarc-emadl-generator</artifactId>
<version>0.3.8-SNAPSHOT</version>
<version>0.3.9-SNAPSHOT</version>
<!-- == PROJECT DEPENDENCIES ============================================= -->
......@@ -17,11 +17,11 @@
<!-- .. SE-Libraries .................................................. -->
<emadl.version>0.2.11-SNAPSHOT</emadl.version>
<CNNTrain.version>0.3.9-SNAPSHOT</CNNTrain.version>
<cnnarch-generator.version>0.0.5-SNAPSHOT</cnnarch-generator.version>
<CNNTrain.version>0.3.10-SNAPSHOT</CNNTrain.version>
<cnnarch-generator.version>0.0.6-SNAPSHOT</cnnarch-generator.version>
<cnnarch-mxnet-generator.version>0.2.17-SNAPSHOT</cnnarch-mxnet-generator.version>
<cnnarch-caffe2-generator.version>0.2.14-SNAPSHOT</cnnarch-caffe2-generator.version>
<cnnarch-gluon-generator.version>0.2.10-SNAPSHOT</cnnarch-gluon-generator.version>
<cnnarch-gluon-generator.version>0.2.11-SNAPSHOT</cnnarch-gluon-generator.version>
<cnnarch-tensorflow-generator.version>0.1.0-SNAPSHOT</cnnarch-tensorflow-generator.version>
<Common-MontiCar.version>0.0.19-SNAPSHOT</Common-MontiCar.version>
<embedded-montiarc-math-opt-generator>0.1.6</embedded-montiarc-math-opt-generator>
......@@ -94,7 +94,7 @@
<artifactId>common-monticar</artifactId>
<version>${Common-MontiCar.version}</version>
</dependency>
<dependency>
<groupId>de.monticore.lang.monticar</groupId>
<artifactId>cnnarch-tensorflow-generator</artifactId>
......@@ -246,12 +246,13 @@
</execution>
</executions>
</plugin>
<plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<version>2.19.1</version>
<configuration>
<useSystemClassLoader>false</useSystemClassLoader>
<argLine>-Xmx1024m -XX:MaxPermSize=256m</argLine>
</configuration>
</plugin>
<plugin>
......
......@@ -14,10 +14,15 @@ import de.monticore.lang.monticar.cnnarch._symboltable.NetworkInstructionSymbol;
import de.monticore.lang.monticar.cnnarch.generator.CNNArchGenerator;
import de.monticore.lang.monticar.cnnarch.generator.CNNTrainGenerator;
import de.monticore.lang.monticar.cnnarch.generator.DataPathConfigParser;
import de.monticore.lang.monticar.cnnarch.generator.WeightsPathConfigParser;
import de.monticore.lang.monticar.cnnarch.gluongenerator.CNNTrain2Gluon;
import de.monticore.lang.monticar.cnnarch.gluongenerator.annotations.ArchitectureAdapter;
import de.monticore.lang.monticar.cnnarch.gluongenerator.preprocessing.PreprocessingComponentParameterAdapter;
import de.monticore.lang.monticar.cnnarch.gluongenerator.preprocessing.PreprocessingPortChecker;
import de.monticore.lang.monticar.cnntrain._cocos.CNNTrainCocos;
import de.monticore.lang.monticar.cnntrain._symboltable.ConfigurationSymbol;
import de.monticore.lang.monticar.cnntrain._symboltable.LearningMethod;
import de.monticore.lang.monticar.cnntrain._symboltable.PreprocessingComponentSymbol;
import de.monticore.lang.monticar.emadl._cocos.DataPathCocos;
import de.monticore.lang.monticar.emadl._cocos.EMADLCocos;
import de.monticore.lang.monticar.emadl.tagging.dltag.DataPathSymbol;
......@@ -30,6 +35,7 @@ import de.monticore.lang.monticar.generator.pythonwrapper.GeneratorPythonWrapper
import de.monticore.lang.monticar.generator.cpp.converter.TypeConverter;
import de.monticore.lang.monticar.generator.pythonwrapper.GeneratorPythonWrapperFactory;
import de.monticore.lang.monticar.generator.pythonwrapper.GeneratorPythonWrapperStandaloneApi;
import de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.ComponentPortInformation;
import de.monticore.lang.tagging._symboltable.TagSymbol;
import de.monticore.lang.tagging._symboltable.TaggingResolver;
import de.monticore.symboltable.Scope;
......@@ -241,7 +247,7 @@ public class EMADLGenerator {
String b = backend.getBackendString(backend);
String trainingDataHash = "";
String testDataHash = "";
if (architecture.get().getDataPath() != null) {
if (b.equals("CAFFE2")) {
trainingDataHash = getChecksumForLargerFile(architecture.get().getDataPath() + "/train_lmdb/data.mdb");
......@@ -405,6 +411,21 @@ public class EMADLGenerator {
return dataPath;
}
protected String getWeightsPath(EMAComponentSymbol component, EMAComponentInstanceSymbol instance){
String weightsPath;
Path weightsPathDefinition = Paths.get(getModelsPath(), "weights_paths.txt");
if (weightsPathDefinition.toFile().exists()) {
WeightsPathConfigParser newParserConfig = new WeightsPathConfigParser(getModelsPath() + "weights_paths.txt");
weightsPath = newParserConfig.getWeightsPath(component.getFullName());
} else {
Log.info("No weights path definition found in " + weightsPathDefinition + ": "
+ "No pretrained weights will be loaded.", "EMADLGenerator");
weightsPath = null;
}
return weightsPath;
}
protected void generateComponent(List<FileContent> fileContents,
Set<EMAComponentInstanceSymbol> allInstances,
TaggingResolver taggingResolver,
......@@ -426,7 +447,9 @@ public class EMADLGenerator {
if (architecture.isPresent()){
cnnArchGenerator.check(architecture.get());
String dPath = getDataPath(taggingResolver, EMAComponentSymbol, componentInstanceSymbol);
String wPath = getWeightsPath(EMAComponentSymbol, componentInstanceSymbol);
architecture.get().setDataPath(dPath);
architecture.get().setWeightsPath(wPath);
architecture.get().setComponentName(EMAComponentSymbol.getFullName());
generateCNN(fileContents, taggingResolver, componentInstanceSymbol, architecture.get());
if (processedArchitecture != null) {
......@@ -621,7 +644,6 @@ public class EMADLGenerator {
}
discriminator.get().setComponentName(fullDiscriminatorName);
configuration.setDiscriminatorNetwork(new ArchitectureAdapter(fullDiscriminatorName, discriminator.get()));
//CNNTrainCocos.checkCriticCocos(configuration);
}
// Resolve QNetwork if present
......@@ -643,11 +665,16 @@ public class EMADLGenerator {
}
qnetwork.get().setComponentName(fullQNetworkName);
configuration.setQNetwork(new ArchitectureAdapter(fullQNetworkName, qnetwork.get()));
//CNNTrainCocos.checkCriticCocos(configuration);
}
if (configuration.getLearningMethod() == LearningMethod.GAN)
CNNTrainCocos.checkGANCocos(configuration);
if (configuration.hasPreprocessor()) {
String fullPreprocessorName = configuration.getPreprocessingName().get();
PreprocessingComponentSymbol preprocessingSymbol = configuration.getPreprocessingComponent().get();
List<String> fullNameOfComponent = preprocessingSymbol.getPreprocessingComponentName();
String fullPreprocessorName = String.join(".", fullNameOfComponent);
int indexOfFirstNameCharacter = fullPreprocessorName.lastIndexOf('.') + 1;
fullPreprocessorName = fullPreprocessorName.substring(0, indexOfFirstNameCharacter)
+ fullPreprocessorName.substring(indexOfFirstNameCharacter, indexOfFirstNameCharacter + 1).toUpperCase()
......@@ -665,13 +692,16 @@ public class EMADLGenerator {
try {
emamGen.generateFile(fileContent);
} catch (IOException e) {
//todo: fancy error message
e.printStackTrace();
}
}
String targetPath = getGenerationTargetPath();
pythonWrapper.generateAndTryBuilding(processor_instance, targetPath + "/pythonWrapper", targetPath);
ComponentPortInformation componentPortInformation;
componentPortInformation = pythonWrapper.generateAndTryBuilding(processor_instance, targetPath + "/pythonWrapper", targetPath);
PreprocessingComponentParameterAdapter componentParameter = new PreprocessingComponentParameterAdapter(componentPortInformation);
PreprocessingPortChecker.check(componentParameter);
preprocessingSymbol.setPreprocessingComponentParameter(componentParameter);
}
cnnTrainGenerator.setInstanceName(componentInstance.getFullName().replaceAll("\\.", "_"));
......
......@@ -22,7 +22,7 @@ public class RewardFunctionCppGenerator implements RewardFunctionSourceGenerator
.<EMAComponentInstanceSymbol>resolve(rootModel, EMAComponentInstanceSymbol.KIND);
if (!instanceSymbol.isPresent()) {
Log.error("Generation of reward function is not possible: Cannot resolve component instance "
Log.error("Generation of reward is not possible: Cannot resolve component instance "
+ rootModel);
}
......
......@@ -82,7 +82,7 @@ public class GenerationTest extends AbstractSymtabTest {
assertTrue(Log.getFindings().isEmpty());
}
/*@Test
@Test
public void testThreeInputGeneration() throws IOException, TemplateException {
Log.getFindings().clear();
String[] args = {"-m", "src/test/resources/models/", "-r", "ThreeInputCNN_M14", "-b", "MXNET", "-f", "n", "-c", "n"};
......@@ -96,7 +96,7 @@ public class GenerationTest extends AbstractSymtabTest {
String[] args = {"-m", "src/test/resources/models/", "-r", "MultipleOutputs", "-b", "MXNET", "-f", "n", "-c", "n"};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().size() == 1);
}*/
}
@Test
public void testVGGGeneration() throws IOException, TemplateException {
......@@ -163,7 +163,7 @@ public class GenerationTest extends AbstractSymtabTest {
"mnist_mnistClassifier_calculateClass.h",
"CNNTrainer_mnist_mnistClassifier_net.py"));
}
@Test
public void testMnistClassifierForGluon() throws IOException, TemplateException {
Log.getFindings().clear();
......@@ -237,7 +237,7 @@ public class GenerationTest extends AbstractSymtabTest {
@Test
public void testHashFunction() {
EMADLGenerator tester = new EMADLGenerator(Backend.MXNET);
try{
tester.getChecksumForFile("invalid Path!");
assertTrue("Hash method should throw IOException on invalid path", false);
......@@ -281,6 +281,78 @@ public class GenerationTest extends AbstractSymtabTest {
);
}
@Test
public void testGluonDefaultGANGeneration() {
Log.getFindings().clear();
String[] args = {"-m", "src/test/resources/models/ganModel", "-r", "defaultGAN.DefaultGANConnector", "-b", "GLUON", "-f", "n", "-c", "n"};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().stream().filter(Finding::isError).collect(Collectors.toList()).isEmpty());
checkFilesAreEqual(
Paths.get("./target/generated-sources-emadl"),
Paths.get("./src/test/resources/target_code/gluon/ganModel/defaultGAN"),
Arrays.asList(
"gan/CNNCreator_defaultGAN_defaultGANDiscriminator.py",
"gan/CNNNet_defaultGAN_defaultGANDiscriminator.py",
"CNNCreator_defaultGAN_defaultGANConnector_predictor.py",
"CNNGanTrainer_defaultGAN_defaultGANConnector_predictor.py",
"CNNNet_defaultGAN_defaultGANConnector_predictor.py",
"CNNPredictor_defaultGAN_defaultGANConnector_predictor.h",
"CNNTrainer_defaultGAN_defaultGANConnector_predictor.py",
"defaultGAN_defaultGANConnector.cpp",
"defaultGAN_defaultGANConnector.h",
"defaultGAN_defaultGANConnector_predictor.h",
"defaultGAN_defaultGANConnector.cpp",
"defaultGAN_defaultGANConnector.h",
"defaultGAN_defaultGANConnector_predictor.h"
)
);
}
@Test
public void testGluonInfoGANGeneration() {
Log.getFindings().clear();
String[] args = {"-m", "src/test/resources/models/ganModel", "-r", "infoGAN.InfoGANConnector", "-b", "GLUON", "-f", "n", "-c", "n"};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().stream().filter(Finding::isError).collect(Collectors.toList()).isEmpty());
checkFilesAreEqual(
Paths.get("./target/generated-sources-emadl"),
Paths.get("./src/test/resources/target_code/gluon/ganModel/infoGAN"),
Arrays.asList(
"gan/CNNCreator_infoGAN_infoGANDiscriminator.py",
"gan/CNNNet_infoGAN_infoGANDiscriminator.py",
"gan/CNNCreator_infoGAN_infoGANQNetwork.py",
"gan/CNNNet_infoGAN_infoGANQNetwork.py",
"CNNCreator_infoGAN_infoGANConnector_predictor.py",
"CNNDataLoader_infoGAN_infoGANConnector_predictor.py",
"CNNGanTrainer_infoGAN_infoGANConnector_predictor.py",
"CNNNet_infoGAN_infoGANConnector_predictor.py",
"CNNPredictor_infoGAN_infoGANConnector_predictor.h",
"CNNTrainer_infoGAN_infoGANConnector_predictor.py",
"infoGAN_infoGANConnector.cpp",
"infoGAN_infoGANConnector.h",
"infoGAN_infoGANConnector_predictor.h"
)
);
}
@Test
public void testGluonPreprocessingWithSupervised() throws IOException, TemplateException {
Log.getFindings().clear();
String[] args = {"-m", "src/test/resources/models/", "-r", "PreprocessingNetwork", "-b", "GLUON", "-f", "n", "-c", "n"};
EMADLGeneratorCli.main(args);
Log.info(Log.getFindings().toString(), "testGluonPreprocessinWithSupervised");
assertTrue(Log.getFindings().size() == 0);
}
@Test
public void testGluonPreprocessingWithGAN() throws IOException, TemplateException {
Log.getFindings().clear();
String[] args = {"-m", "src/test/resources/models/ganModel", "-r", "defaultGANPreprocessing.GeneratorWithPreprocessing", "-b", "GLUON", "-f", "n", "-c", "n"};
EMADLGeneratorCli.main(args);
Log.info(Log.getFindings().toString(), "testGluonPreprocessingWithGAN");
assertTrue(Log.getFindings().size() == 0);
}
@Test
public void testAlexNetTagging() {
Log.getFindings().clear();
......
......@@ -70,6 +70,25 @@ public class IntegrationGluonTest extends IntegrationTest {
assertTrue(Log.getFindings().isEmpty());
}
@Test
public void testGluonPreprocessingWithSupervised() {
Log.getFindings().clear();
deleteHashFile(Paths.get("./target/generated-sources-emadl/PreprocessingNetwork.training_hash"));
String[] args = {"-m", "src/test/resources/models/", "-r", "PreprocessingNetwork", "-b", "GLUON"};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().toString(),Log.getFindings().size() == 0);
}
@Test
public void testGluonPreprocessingWithGAN() {
Log.getFindings().clear();
deleteHashFile(Paths.get("./target/generated-sources-emadl/defaultGANPreprocessing/GeneratorWithPreprocessing.training_hash"));
String[] args = {"-m", "src/test/resources/models/ganModel", "-r", "defaultGANPreprocessing.GeneratorWithPreprocessing", "-b", "GLUON"};
EMADLGeneratorCli.main(args);
assertTrue(Log.getFindings().toString(), Log.getFindings().size() == 0);
}
private void deleteHashFile(Path hashFile) {
try {
Files.delete(hashFile);
......
/* (c) https://github.com/MontiCore/monticore */
configuration PreprocessingNetwork{
num_epoch:1
batch_size:1
log_period: 1
normalize:false
preprocessing_name: PreprocessingProcessing
context:cpu
load_checkpoint:false
optimizer:sgd{
learning_rate:0.1
learning_rate_decay:0.85
step_size:1000
weight_decay:0.0
}
}
/* (c) https://github.com/MontiCore/monticore */
component PreprocessingNetwork {
ports in Z(0:255)^{3, 32, 32} data,
out Q(0:1)^{10} softmax;
implementation CNN {
def conv(channels, kernel=1, stride=1){
Convolution(kernel=(kernel,kernel),channels=channels) ->
Relu() ->
Pooling(pool_type="max", kernel=(2,2), stride=(stride,stride))
}
data ->
conv(kernel=5, channels=20, stride=2) ->
conv(kernel=5, channels=50, stride=2) ->
FullyConnected(units=500) ->
Relu() ->
Dropout() ->
FullyConnected(units=10) ->
Softmax() ->
softmax;
}
}
/* (c) https://github.com/MontiCore/monticore */
component PreprocessingProcessing
{
ports in Q(-oo:oo)^{3,32,32} data,
in Q(0:1) softmax_label,
out Q(-1:1)^{3,32,32} data_out,
out Q(0:1) softmax_label_out;
implementation Math
{
data = data * 2;
data_out = data - 1;
softmax_label_out = softmax_label;
}
}
/* (c) https://github.com/MontiCore/monticore */
package cNNCalculator;
component Network{
ports in Z(0:255)^{1, 28, 28} image,
out Q(0:1)^{10} predictions;
......
cifar10.CifarNetwork src/test/resources/training_data/Cifar
simpleCifar10.CifarNetwork src/test/resources/training_data/Cifar
cNNCalculator.Network src/test/resources/training_data/Cifar
PreprocessingNetwork src/test/resources/training_data/Cifar
InstanceTest.NetworkB data/InstanceTest.NetworkB
Alexnet data/Alexnet
ThreeInputCNN_M14 data/ThreeInputCNN_M14
......
defaultGANPreprocessing.GeneratorWithPreprocessing src/test/resources/training_data/Cifar
defaultGAN.DefaultGANGenerator src/test/resources/training_data/Cifar
infoGAN.InfoGANGenerator src/test/resources/training_data/Cifar
/* (c) https://github.com/MontiCore/monticore */
package defaultGAN;
component DefaultGANConnector {
ports in Q(0:1)^{100} noise,
out Q(0:1)^{1, 64, 64} res;
instance DefaultGANGenerator predictor;
connect noise -> predictor.noise;
connect predictor.data -> res;
}
/* (c) https://github.com/MontiCore/monticore */
package defaultGAN;
component DefaultGANDiscriminator{
ports in Q(-1:1)^{1, 64, 64} data,
out Q(-oo:oo)^{1} dis;
implementation CNN {
data ->
Convolution(kernel=(4,4),channels=64, stride=(2,2)) ->
LeakyRelu(alpha=0.2) ->
Convolution(kernel=(4,4),channels=128, stride=(2,2)) ->
BatchNorm() ->
LeakyRelu(alpha=0.2) ->
Convolution(kernel=(4,4),channels=256, stride=(2,2)) ->
BatchNorm() ->
LeakyRelu(alpha=0.2) ->
Convolution(kernel=(4,4),channels=512, stride=(2,2)) ->
BatchNorm() ->
LeakyRelu(alpha=0.2) ->
Convolution(kernel=(4,4),channels=1, stride=(1,1)) ->
Sigmoid() ->
dis;
}
}
/* (c) https://github.com/MontiCore/monticore */
configuration DefaultGANGenerator{
learning_method:gan
discriminator_name: defaultGAN.DefaultGANDiscriminator
num_epoch:10
batch_size:64
normalize:false
context:cpu
noise_input: "noise"
print_images: true
log_period: 10
load_checkpoint:false
optimizer:adam{
learning_rate:0.0002
beta1:0.5
}
discriminator_optimizer:adam{
learning_rate:0.0002
beta1:0.5
}
noise_distribution:gaussian{
mean_value:0
spread_value:1
}
}
/* (c) https://github.com/MontiCore/monticore */
package defaultGAN;
component DefaultGANGenerator{
ports in Q(0:1)^{100} noise,
out Q(-1:1)^{1, 64, 64} data;
implementation CNN {
noise ->
Reshape(shape=(100,1,1)) ->
UpConvolution(kernel=(4,4), channels=512, stride=(1,1), padding="valid", no_bias=true) ->
BatchNorm() ->
Relu() ->
UpConvolution(kernel=(4,4), channels=256, stride=(2,2), no_bias=true) ->
BatchNorm() ->
Relu() ->
UpConvolution(kernel=(4,4), channels=128, stride=(2,2), no_bias=true) ->
BatchNorm() ->
Relu() ->
UpConvolution(kernel=(4,4), channels=64, stride=(2,2), no_bias=true) ->
BatchNorm() ->
Relu() ->
UpConvolution(kernel=(4,4), channels=1, stride=(2,2), no_bias=true) ->
Tanh() ->
data;
}
}
/* (c) https://github.com/MontiCore/monticore */
package defaultGANPreprocessing;
component DiscriminatorWithPreprocessing{
ports in Q(-1:1)^{3, 64, 64} data,
out Q(-oo:oo)^{1} dis;
implementation CNN {
data ->
Convolution(kernel=(4,4),channels=64, stride=(2,2)) ->
LeakyRelu(alpha=0.2) ->
Convolution(kernel=(4,4),channels=128, stride=(2,2)) ->
BatchNorm() ->
LeakyRelu(alpha=0.2) ->
Convolution(kernel=(4,4),channels=256, stride=(2,2)) ->
BatchNorm() ->
LeakyRelu(alpha=0.2) ->
Convolution(kernel=(4,4),channels=512, stride=(2,2)) ->
BatchNorm() ->
LeakyRelu(alpha=0.2) ->
Convolution(kernel=(4,4),channels=1, stride=(1,1)) ->
Sigmoid() ->
dis;
}
}
/* (c) https://github.com/MontiCore/monticore */
configuration GeneratorWithPreprocessing{
learning_method:gan
discriminator_name: defaultGANPreprocessing.DiscriminatorWithPreprocessing
num_epoch:1
batch_size:1
normalize:false
preprocessing_name: defaultGANPreprocessing.ProcessingWithPreprocessing
context:cpu
noise_input: "noise"
print_images: false
log_period: 1
load_checkpoint:false
optimizer:adam{
learning_rate:0.0002
beta1:0.5
}
discriminator_optimizer:adam{
learning_rate:0.0002
beta1:0.5
}
noise_distribution:gaussian{
mean_value:0
spread_value:1
}
}
/* (c) https://github.com/MontiCore/monticore */
package defaultGANPreprocessing;
component GeneratorWithPreprocessing{
ports in Q(0:1)^{100} noise,
out Q(-1:1)^{3, 64, 64} data;
implementation CNN {
noise ->
Reshape(shape=(100,1,1)) ->
UpConvolution(kernel=(4,4), channels=512, stride=(1,1), padding="valid", no_bias=true) ->
BatchNorm() ->
Relu() ->
UpConvolution(kernel=(4,4), channels=256, stride=(2,2), no_bias=true) ->
BatchNorm() ->
Relu() ->
UpConvolution(kernel=(4,4), channels=128, stride=(2,2), no_bias=true) ->
BatchNorm() ->
Relu() ->
UpConvolution(kernel=(4,4), channels=64, stride=(2,2), no_bias=true) ->
BatchNorm() ->
Relu() ->
UpConvolution(kernel=(4,4), channels=3, stride=(2,2), no_bias=true) ->
Tanh() ->
data;
}
}
/* (c) https://github.com/MontiCore/monticore */
package defaultGANPreprocessing;
component ProcessingWithPreprocessing
{
ports in Q(-oo:oo)^{3,32,32} data,
in Q(0:1) softmax_label,
out Q(-1:1)^{3,64,64} data_out,
out Q(0:1) softmax_label_out;
implementation Math
{
data = data * 2;
data = data - 1;
data_out = scaleCube(data, 0, 64, 64);
softmax_label_out = softmax_label;
}
}
/* (c) https://github.com/MontiCore/monticore */
package infoGAN;
component InfoGANConnector {
ports in Q(0:1)^{62} noise,
in Z(0:9)^{10} c1,
out Q(0:1)^{1, 64, 64} res;
instance InfoGANGenerator predictor;
connect noise -> predictor.noise;
connect c1 -> predictor.c1;
connect predictor.data -> res;
}
/* (c) https://github.com/MontiCore/monticore */
package infoGAN;
component InfoGANDiscriminator{
ports in Q(-1:1)^{1, 28, 28} data,
out Q(-oo:oo)^{1024} features,
out Q(-oo:oo)^{1} dis;
implementation CNN {
data ->
Convolution(kernel=(4,4),channels=64, stride=(2,2)) ->
LeakyRelu(alpha=0.2) ->
Convolution(kernel=(4,4),channels=128, stride=(2,2)) ->
BatchNorm() ->
LeakyRelu(alpha=0.2) ->
Convolution(kernel=(4,4),channels=256, stride=(2,2)) ->
BatchNorm() ->
LeakyRelu(alpha=0.2) ->
Convolution(kernel=(4,4),channels=512, stride=(2,2)) ->
BatchNorm() ->
LeakyRelu(alpha=0.2) ->
(
Convolution(kernel=(4,4),channels=1,stride=(1,1)) ->
Sigmoid() ->
dis
|
features
);
}
}
/* (c) https://github.com/MontiCore/monticore */
configuration InfoGANGenerator{
learning_method:gan
discriminator_name: infoGAN.InfoGANDiscriminator
qnet_name: infoGAN.InfoGANQNetwork
num_epoch: 5
batch_size: 64
normalize: false
noise_input: "noise"
context: cpu
load_checkpoint: false
optimizer:adam{
learning_rate:0.0002
beta1:0.5
}
discriminator_optimizer:adam{
learning_rate:0.0002
beta1:0.5
}
noise_distribution:gaussian{
mean_value:0
spread_value:1
}
log_period: 10
print_images: true
}
/* (c) https://github.com/MontiCore/monticore */
package infoGAN;
component InfoGANGenerator{
ports in Q(0:1)^{62} noise,
in Z(0:9)^{10} c1,
out Q(-1:1)^{1, 64, 64} data;
implementation CNN {
(
noise
|
c1
) ->
Concatenate() ->
Reshape(shape=(72,1,1)) ->
UpConvolution(kernel=(4,4), channels=512, stride=(1,1), padding="valid", no_bias=true) ->
BatchNorm() ->
Relu() ->
UpConvolution(kernel=(4,4), channels=256, stride=(2,2), no_bias=true) ->
BatchNorm() ->
Relu() ->
UpConvolution(kernel=(4,4), channels=128, stride=(2,2), no_bias=true) ->
BatchNorm() ->
Relu() ->
UpConvolution(kernel=(4,4), channels=64, stride=(2,2), no_bias=true) ->
BatchNorm() ->
Relu() ->
UpConvolution(kernel=(4,4), channels=1, stride=(2,2), no_bias=true) ->
Tanh() ->
data;
}
}
/* (c) https://github.com/MontiCore/monticore */
package infoGAN;
component InfoGANQNetwork{
ports in Q(-oo:oo)^{512, 4, 4} features,
out Q(-oo:oo)^{10} c1;
implementation CNN {
features ->
FullyConnected(units=128, no_bias=true) ->
BatchNorm() ->
Relu() ->
FullyConnected(units=10, no_bias=true) ->
Softmax() ->
c1;
}
}
......@@ -50,13 +50,42 @@ class SoftmaxCrossEntropyLossIgnoreIndices(gluon.loss.Loss):
if self._sparse_label:
loss = -pick(pred, label, axis=self._axis, keepdims=True)
else:
label = _reshape_like(F, label, pred)
label = gluon.loss._reshape_like(F, label, pred)
loss = -(pred * label).sum(axis=self._axis, keepdims=True)
# ignore some indices for loss, e.g. <pad> tokens in NLP applications
for i in self._ignore_indices:
loss = loss * mx.nd.logical_not(mx.nd.equal(mx.nd.argmax(pred, axis=1), mx.nd.ones_like(mx.nd.argmax(pred, axis=1))*i))
loss = loss * mx.nd.logical_not(mx.nd.equal(mx.nd.argmax(pred, axis=1), mx.nd.ones_like(mx.nd.argmax(pred, axis=1))*i) * mx.nd.equal(mx.nd.argmax(pred, axis=1), label))
return loss.mean(axis=self._batch_axis, exclude=True)
class DiceLoss(gluon.loss.Loss):
def __init__(self, axis=-1, sparse_label=True, from_logits=False, weight=None,
batch_axis=0, **kwargs):
super(DiceLoss, self).__init__(weight, batch_axis, **kwargs)
self._axis = axis
self._sparse_label = sparse_label
self._from_logits = from_logits
def dice_loss(self, F, pred, label):
smooth = 1.
pred_y = F.argmax(pred, axis = self._axis)
intersection = pred_y * label
score = (2 * F.mean(intersection, axis=self._batch_axis, exclude=True) + smooth) \
/ (F.mean(label, axis=self._batch_axis, exclude=True) + F.mean(pred_y, axis=self._batch_axis, exclude=True) + smooth)
return - F.log(score)
def hybrid_forward(self, F, pred, label, sample_weight=None):
if not self._from_logits:
pred = F.log_softmax(pred, self._axis)
if self._sparse_label:
loss = -F.pick(pred, label, axis=self._axis, keepdims=True)
else:
label = gluon.loss._reshape_like(F, label, pred)
loss = -F.sum(pred*label, axis=self._axis, keepdims=True)
loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)
diceloss = self.dice_loss(F, pred, label)
return F.mean(loss, axis=self._batch_axis, exclude=True) + diceloss
@mx.metric.register
class BLEU(mx.metric.EvalMetric):
N = 4
......@@ -244,14 +273,17 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
ignore_indices = [loss_params['ignore_indices']] if 'ignore_indices' in loss_params else []
if loss == 'softmax_cross_entropy':
fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(from_logits=fromLogits, sparse_label=sparseLabel)
loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(axis=loss_axis, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis)
elif loss == 'softmax_cross_entropy_ignore_indices':
fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
loss_function = SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel)
loss_function = SoftmaxCrossEntropyLossIgnoreIndices(axis=loss_axis, ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis)
elif loss == 'sigmoid_binary_cross_entropy':
loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
elif loss == 'cross_entropy':
loss_function = CrossEntropyLoss(sparse_label=sparseLabel)
loss_function = CrossEntropyLoss(axis=loss_axis, sparse_label=sparseLabel, batch_axis=batch_axis)
elif loss == 'dice_loss':
loss_weight = loss_params['loss_weight'] if 'loss_weight' in loss_params else None
loss_function = DiceLoss(axis=loss_axis, weight=loss_weight, sparse_label=sparseLabel, batch_axis=batch_axis)
elif loss == 'l2':
loss_function = mx.gluon.loss.L2Loss()
elif loss == 'l1':
......@@ -323,7 +355,7 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
train_test_iter.reset()
metric = mx.metric.create(eval_metric, **eval_metric_params)
for batch_i, batch in enumerate(train_test_iter):
if True:
if True:
labels = [batch.label[i].as_in_context(mx_context) for i in range(1)]
image_ = batch.data[0].as_in_context(mx_context)
......@@ -394,7 +426,7 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
test_iter.reset()
metric = mx.metric.create(eval_metric, **eval_metric_params)
for batch_i, batch in enumerate(test_iter):
if True:
if True:
labels = [batch.label[i].as_in_context(mx_context) for i in range(1)]
image_ = batch.data[0].as_in_context(mx_context)
......
import mxnet as mx
import logging
import os
import shutil
from CNNNet_mnist_mnistClassifier_net import Net_0
......@@ -11,6 +12,7 @@ class CNNCreator_mnist_mnistClassifier_net:
def __init__(self):
self.weight_initializer = mx.init.Normal()
self.networks = {}
self._weights_dir_ = None
def load(self, context):
earliestLastEpoch = None
......@@ -47,6 +49,29 @@ class CNNCreator_mnist_mnistClassifier_net:
return earliestLastEpoch
def load_pretrained_weights(self, context):
if os.path.isdir(self._model_dir_):
shutil.rmtree(self._model_dir_)
if self._weights_dir_ is not None:
for i, network in self.networks.items():
# param_file = self._model_prefix_ + "_" + str(i) + "_newest-0000.params"
param_file = None
if os.path.isdir(self._weights_dir_):
lastEpoch = 0
for file in os.listdir(self._weights_dir_):
if ".params" in file and self._model_prefix_ + "_" + str(i) in file:
epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","")
epoch = int(epochStr)
if epoch > lastEpoch:
lastEpoch = epoch
param_file = file
logging.info("Loading pretrained weights: " + self._weights_dir_ + param_file)
network.load_parameters(self._weights_dir_ + param_file, allow_missing=True, ignore_extra=True)
else:
logging.info("No pretrained weights available at: " + self._weights_dir_ + param_file)
def construct(self, context, data_mean=None, data_std=None):
self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std)
self.networks[0].collect_params().initialize(self.weight_initializer, ctx=context)
......@@ -58,3 +83,17 @@ class CNNCreator_mnist_mnistClassifier_net:
for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
def getInputs(self):
inputs = {}
input_dimensions = (1,28,28,)
input_domains = (int,0.0,255.0,)
inputs["image_"] = input_domains + (input_dimensions,)
return inputs
def getOutputs(self):
outputs = {}
output_dimensions = (10,1,1,)
output_domains = (float,0.0,1.0,)
outputs["predictions_"] = output_domains + (output_dimensions,)
return outputs
......@@ -4,7 +4,6 @@ import mxnet as mx
import logging
import sys
import numpy as np
import cv2
import importlib
from mxnet import nd
......@@ -78,6 +77,7 @@ class CNNDataLoader_mnist_mnistClassifier_net:
train_label = {}
data_mean = {}
data_std = {}
train_images = {}
shape_output = self.preprocess_data(instance, inp, 0, train_h5)
train_len = len(train_h5[self._input_names_[0]])
......@@ -140,6 +140,7 @@ class CNNDataLoader_mnist_mnistClassifier_net:
for output_name in self._output_names_:
test_label[output_name][i] = getattr(shape_output, output_name + "_out")
test_images = {}
if 'images' in test_h5:
test_images = test_h5['images']
......@@ -151,7 +152,7 @@ class CNNDataLoader_mnist_mnistClassifier_net:
def preprocess_data(self, instance_wrapper, input_wrapper, index, data_h5):
for input_name in self._input_names_:
data = data_h5[input_name][0]
data = data_h5[input_name][index]
attr = getattr(input_wrapper, input_name)
if (type(data)) == np.ndarray:
data = np.asfortranarray(data).astype(attr.dtype)
......@@ -159,7 +160,7 @@ class CNNDataLoader_mnist_mnistClassifier_net:
data = type(attr)(data)
setattr(input_wrapper, input_name, data)
for output_name in self._output_names_:
data = data_h5[output_name][0]
data = data_h5[output_name][index]
attr = getattr(input_wrapper, output_name)
if (type(data)) == np.ndarray:
data = np.asfortranarray(data).astype(attr.dtype)
......
......@@ -148,16 +148,3 @@ class Net_0(gluon.HybridBlock):
return predictions_
def getInputs(self):
inputs = {}
input_dimensions = (1,28,28)
input_domains = (int,0.0,255.0)
inputs["image_"] = input_domains + (input_dimensions,)
return inputs
def getOutputs(self):
outputs = {}
output_dimensions = (10,1,1)
output_domains = (float,0.0,1.0)
outputs["predictions_"] = output_domains + (output_dimensions,)
return outputs
......@@ -50,13 +50,89 @@ class SoftmaxCrossEntropyLossIgnoreIndices(gluon.loss.Loss):
if self._sparse_label:
loss = -pick(pred, label, axis=self._axis, keepdims=True)
else:
label = _reshape_like(F, label, pred)
label = gluon.loss._reshape_like(F, label, pred)
loss = -(pred * label).sum(axis=self._axis, keepdims=True)
# ignore some indices for loss, e.g. <pad> tokens in NLP applications
for i in self._ignore_indices:
loss = loss * mx.nd.logical_not(mx.nd.equal(mx.nd.argmax(pred, axis=1), mx.nd.ones_like(mx.nd.argmax(pred, axis=1))*i) * mx.nd.equal(mx.nd.argmax(pred, axis=1), label))
return loss.mean(axis=self._batch_axis, exclude=True)
class DiceLoss(gluon.loss.Loss):
    """Cross-entropy loss plus a smoothed dice term.

    hybrid_forward returns the per-sample mean (softmax) cross-entropy
    (sparse or dense labels, see `sparse_label`) plus the value of
    dice_loss. Constructor arguments mirror
    gluon.loss.SoftmaxCrossEntropyLoss.
    """

    def __init__(self, axis=-1, sparse_label=True, from_logits=False, weight=None,
                 batch_axis=0, **kwargs):
        super(DiceLoss, self).__init__(weight, batch_axis, **kwargs)
        self._axis = axis                  # class axis of `pred`
        self._sparse_label = sparse_label  # labels are class indices, not one-hot
        self._from_logits = from_logits    # skip log_softmax when True

    def dice_loss(self, F, pred, label):
        """Negative log of a smoothed dice score between argmax(pred) and label.

        NOTE(review): built on F.argmax, which has zero gradient, so this
        term presumably contributes no gradient signal — confirm that this
        matches the generator's intent.
        """
        smooth = 1.  # additive smoothing; keeps the score away from 0/0
        pred_y = F.argmax(pred, axis = self._axis)
        intersection = pred_y * label
        score = (2 * F.mean(intersection, axis=self._batch_axis, exclude=True) + smooth) \
            / (F.mean(label, axis=self._batch_axis, exclude=True) + F.mean(pred_y, axis=self._batch_axis, exclude=True) + smooth)

        return - F.log(score)

    def hybrid_forward(self, F, pred, label, sample_weight=None):
        # Standard cross-entropy: pick the log-probability of the target class
        # (sparse labels) or sum over a dense/one-hot label distribution.
        if not self._from_logits:
            pred = F.log_softmax(pred, self._axis)
        if self._sparse_label:
            loss = -F.pick(pred, label, axis=self._axis, keepdims=True)
        else:
            label = gluon.loss._reshape_like(F, label, pred)
            loss = -F.sum(pred*label, axis=self._axis, keepdims=True)
        loss = gluon.loss._apply_weighting(F, loss, self._weight, sample_weight)

        # Dice term is computed on `pred` AFTER the log_softmax above
        # (argmax is unaffected by the monotone transform).
        diceloss = self.dice_loss(F, pred, label)
        return F.mean(loss, axis=self._batch_axis, exclude=True) + diceloss
class SoftmaxCrossEntropyLossIgnoreLabel(gluon.loss.Loss):
    """Softmax cross-entropy that excludes positions labelled `ignore_label`.

    The loss is summed over all valid positions and normalised by their
    count (not by batch size), so ignored positions contribute nothing.
    """

    def __init__(self, axis=-1, from_logits=False, weight=None,
                 batch_axis=0, ignore_label=255, **kwargs):
        super(SoftmaxCrossEntropyLossIgnoreLabel, self).__init__(weight, batch_axis, **kwargs)
        # Class axis of `output`, whether inputs are already log-probabilities,
        # and the label value to be excluded from the loss.
        self._axis = axis
        self._from_logits = from_logits
        self._ignore_label = ignore_label

    def hybrid_forward(self, F, output, label, sample_weight=None):
        # Mask of positions that participate in the loss (1 where valid).
        valid_mask = (label != self._ignore_label)
        if not self._from_logits:
            output = F.log_softmax(output, axis=self._axis)
        # Negative log-likelihood of the target class, zeroed at ignored spots.
        picked = F.pick(output, label, axis=self._axis, keepdims=True)
        masked_loss = -(picked * valid_mask)
        weighted = gluon.loss._apply_weighting(F, masked_loss, self._weight, sample_weight)
        # Average over valid positions only.
        return F.sum(weighted) / F.sum(valid_mask)
@mx.metric.register
class ACCURACY_IGNORE_LABEL(mx.metric.EvalMetric):
    """Accuracy metric that skips positions whose label equals
    `metric_ignore_label` (e.g. padding / void pixels).

    Registered with mx.metric so it can be selected by name.
    """
    def __init__(self, axis=1, metric_ignore_label=255, name='accuracy',
                 output_names=None, label_names=None):
        super(ACCURACY_IGNORE_LABEL, self).__init__(
            name, axis=axis,
            output_names=output_names, label_names=label_names)
        self.axis = axis                        # class axis used for argmax
        self.ignore_label = metric_ignore_label  # label value excluded from the count

    def update(self, labels, preds):
        """Accumulate correct/total counts over all (label, prediction) pairs.

        Predictions that still carry a class axis are reduced via argmax
        before comparison; ignored positions count toward neither
        `sum_metric` nor `num_inst`.
        """
        mx.metric.check_label_shapes(labels, preds)

        for label, pred_label in zip(labels, preds):
            # Reduce raw scores to class indices only when shapes differ
            # (i.e. argmax was not already applied upstream).
            if pred_label.shape != label.shape:
                pred_label = mx.nd.argmax(pred_label, axis=self.axis, keepdims=True)
            label = label.astype('int32')
            pred_label = pred_label.astype('int32').as_in_context(label.context)

            mx.metric.check_label_shapes(label, pred_label)

            # Count matches only where the label is not the ignored value.
            correct = mx.nd.sum( (label == pred_label) * (label != self.ignore_label) ).asscalar()
            total = mx.nd.sum( (label != self.ignore_label) ).asscalar()

            self.sum_metric += correct
            self.num_inst += total
@mx.metric.register
class BLEU(mx.metric.EvalMetric):
N = 4
......@@ -192,6 +268,7 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
optimizer_params=(('learning_rate', 0.001),),
load_checkpoint=True,
checkpoint_period=5,
load_pretrained=False,
log_period=50,
context='gpu',
save_attention_image=False,
......@@ -236,6 +313,8 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
begin_epoch = 0
if load_checkpoint:
begin_epoch = self._net_creator.load(mx_context)
elif load_pretrained:
self._net_creator.load_pretrained_weights(mx_context)
else:
if os.path.isdir(self._net_creator._model_dir_):
shutil.rmtree(self._net_creator._model_dir_)
......@@ -253,16 +332,25 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
margin = loss_params['margin'] if 'margin' in loss_params else 1.0
sparseLabel = loss_params['sparse_label'] if 'sparse_label' in loss_params else True
ignore_indices = [loss_params['ignore_indices']] if 'ignore_indices' in loss_params else []
loss_axis = loss_params['loss_axis'] if 'loss_axis' in loss_params else -1
batch_axis = loss_params['batch_axis'] if 'batch_axis' in loss_params else 0
if loss == 'softmax_cross_entropy':
fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(from_logits=fromLogits, sparse_label=sparseLabel)
loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss(axis=loss_axis, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis)
elif loss == 'softmax_cross_entropy_ignore_indices':
fromLogits = loss_params['from_logits'] if 'from_logits' in loss_params else False
loss_function = SoftmaxCrossEntropyLossIgnoreIndices(ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel)
loss_function = SoftmaxCrossEntropyLossIgnoreIndices(axis=loss_axis, ignore_indices=ignore_indices, from_logits=fromLogits, sparse_label=sparseLabel, batch_axis=batch_axis)
elif loss == 'sigmoid_binary_cross_entropy':
loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
elif loss == 'cross_entropy':
loss_function = CrossEntropyLoss(sparse_label=sparseLabel)
loss_function = CrossEntropyLoss(axis=loss_axis, sparse_label=sparseLabel, batch_axis=batch_axis)
elif loss == 'dice_loss':
loss_weight = loss_params['loss_weight'] if 'loss_weight' in loss_params else None
loss_function = DiceLoss(axis=loss_axis, weight=loss_weight, sparse_label=sparseLabel, batch_axis=batch_axis)
elif loss == 'softmax_cross_entropy_ignore_label':
loss_weight = loss_params['loss_weight'] if 'loss_weight' in loss_params else None
loss_ignore_label = loss_params['loss_ignore_label'] if 'loss_ignore_label' in loss_params else None
loss_function = SoftmaxCrossEntropyLossIgnoreLabel(axis=loss_axis, ignore_label=loss_ignore_label, weight=loss_weight, batch_axis=batch_axis)
elif loss == 'l2':
loss_function = mx.gluon.loss.L2Loss()
elif loss == 'l1':
......@@ -510,11 +598,7 @@ class CNNSupervisedTrainer_mnist_mnistClassifier_net:
predictions = []
for output_name in outputs:
if mx.nd.shape_array(mx.nd.squeeze(output_name)).size > 1:
predictions.append(mx.nd.argmax(output_name, axis=1))
#ArgMax already applied
else:
predictions.append(output_name)
predictions.append(output_name)
metric.update(preds=predictions, labels=labels)
test_metric_score = metric.get()[1]
......
......@@ -21,6 +21,7 @@ if __name__ == "__main__":
batch_size=64,
num_epoch=11,
context='gpu',
preprocessing=False,
eval_metric='accuracy',
eval_metric_params={
},
......
import mxnet as mx
import logging
import os
import shutil
from CNNNet_defaultGAN_defaultGANConnector_predictor import Net_0
class CNNCreator_defaultGAN_defaultGANConnector_predictor:
_model_dir_ = "model/defaultGAN.DefaultGANGenerator/"
_model_prefix_ = "model"
def __init__(self):
self.weight_initializer = mx.init.Normal()
self.networks = {}
self._weights_dir_ = None
def load(self, context):
earliestLastEpoch = None
for i, network in self.networks.items():
lastEpoch = 0
param_file = None
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest-0000.params")
except OSError:
pass
try:
os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest-symbol.json")
except OSError:
pass
if os.path.isdir(self._model_dir_):
for file in os.listdir(self._model_dir_):
if ".params" in file and self._model_prefix_ + "_" + str(i) in file:
epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","")
epoch = int(epochStr)
if epoch > lastEpoch:
lastEpoch = epoch
param_file = file
if param_file is None:
earliestLastEpoch = 0
else:
logging.info("Loading checkpoint: " + param_file)
network.load_parameters(self._model_dir_ + param_file)
if earliestLastEpoch == None or lastEpoch < earliestLastEpoch:
earliestLastEpoch = lastEpoch
return earliestLastEpoch
def load_pretrained_weights(self, context):
if os.path.isdir(self._model_dir_):
shutil.rmtree(self._model_dir_)
if self._weights_dir_ is not None:
for i, network in self.networks.items():
# param_file = self._model_prefix_ + "_" + str(i) + "_newest-0000.params"
param_file = None
if os.path.isdir(self._weights_dir_):
lastEpoch = 0
for file in os.listdir(self._weights_dir_):
if ".params" in file and self._model_prefix_ + "_" + str(i) in file:
epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","")
epoch = int(epochStr)
if epoch > lastEpoch:
lastEpoch = epoch
param_file = file
logging.info("Loading pretrained weights: " + self._weights_dir_ + param_file)
network.load_parameters(self._weights_dir_ + param_file, allow_missing=True, ignore_extra=True)
else:
logging.info("No pretrained weights available at: " + self._weights_dir_ + param_file)
def construct(self, context, data_mean=None, data_std=None):
self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std)
self.networks[0].collect_params().initialize(self.weight_initializer, ctx=context)
self.networks[0].hybridize()
self.networks[0](mx.nd.zeros((1, 100,), ctx=context))
if not os.path.exists(self._model_dir_):