Commit 9b400095 authored by Julian Dierkes

added tests for GAN generation

parent 4460975d
Pipeline #265129 failed with stages
......@@ -101,6 +101,19 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
generateFilesFromConfigurationSymbol(configurationSymbol);
}
/**
 * Generates GAN training code for a generator, a discriminator and an
 * optional Q-network (InfoGAN setup).
 *
 * @param modelsDirPath        root directory containing the CNNTrain models
 * @param rootModelName        name of the root training configuration model
 * @param trainedArchitecture  generator architecture to be trained
 * @param discriminatorNetwork discriminator architecture
 * @param qNetwork             Q-network architecture; may be null for plain GANs
 */
public void generate(Path modelsDirPath,
                     String rootModelName,
                     NNArchitectureSymbol trainedArchitecture,
                     NNArchitectureSymbol discriminatorNetwork,
                     NNArchitectureSymbol qNetwork) {
    final ConfigurationSymbol config = getConfigurationSymbol(modelsDirPath, rootModelName);
    config.setTrainedArchitecture(trainedArchitecture);
    config.setDiscriminatorNetwork(discriminatorNetwork);
    config.setQNetwork(qNetwork);
    setRootProjectModelsDir(modelsDirPath.toString());
    generateFilesFromConfigurationSymbol(config);
}
/**
 * Convenience overload for training without an additional network;
 * delegates to the four-argument overload, passing null for the extra
 * architecture (presumably the discriminator — confirm against that overload).
 */
public void generate(Path modelsDirPath, String rootModelName, NNArchitectureSymbol trainedArchitecture) {
    generate(modelsDirPath, rootModelName, trainedArchitecture, null);
}
......
......@@ -368,4 +368,55 @@ public class GenerationTest extends AbstractSymtabTest {
);
}
@Test
public void testDefaultGANConfig() {
    Log.getFindings().clear();

    final Path modelPath = Paths.get("src/test/resources/valid_tests/default-gan");
    final Path archPath = Paths.get("./src/test/resources/valid_tests/default-gan/arc");

    // Mock the generator and discriminator architectures from their CNNArch models.
    final NNArchitectureSymbol generator =
            NNArchitectureMockFactory.createArchitectureSymbolByCNNArchModel(archPath, "DefaultGAN");
    final NNArchitectureSymbol discriminator =
            NNArchitectureMockFactory.createArchitectureSymbolByCNNArchModel(archPath, "Discriminator");

    // Plain GAN: no Q-network.
    final CNNTrain2Gluon trainGenerator = new CNNTrain2Gluon(rewardFunctionSourceGenerator);
    trainGenerator.generate(modelPath, "DefaultGAN", generator, discriminator, null);

    assertTrue(Log.getFindings().stream().noneMatch(Finding::isError));
    checkFilesAreEqual(
            Paths.get("./target/generated-sources-cnnarch"),
            Paths.get("./src/test/resources/target_code/default-gan"),
            Arrays.asList(
                    "gan/CNNCreator_Discriminator.py",
                    "gan/CNNNet_Discriminator.py",
                    "CNNTrainer_defaultGAN.py"));
}
@Test
public void testInfoGANConfig() {
    Log.getFindings().clear();

    final Path modelPath = Paths.get("src/test/resources/valid_tests/info-gan");
    final Path archPath = Paths.get("./src/test/resources/valid_tests/info-gan/arc");

    // InfoGAN needs three architectures: generator, discriminator and Q-network.
    final NNArchitectureSymbol generator =
            NNArchitectureMockFactory.createArchitectureSymbolByCNNArchModel(archPath, "InfoGAN");
    final NNArchitectureSymbol discriminator =
            NNArchitectureMockFactory.createArchitectureSymbolByCNNArchModel(archPath, "InfoDiscriminator");
    final NNArchitectureSymbol qNetwork =
            NNArchitectureMockFactory.createArchitectureSymbolByCNNArchModel(archPath, "InfoQNetwork");

    final CNNTrain2Gluon trainGenerator = new CNNTrain2Gluon(rewardFunctionSourceGenerator);
    trainGenerator.generate(modelPath, "InfoGAN", generator, discriminator, qNetwork);

    assertTrue(Log.getFindings().stream().noneMatch(Finding::isError));
    checkFilesAreEqual(
            Paths.get("./target/generated-sources-cnnarch"),
            Paths.get("./src/test/resources/target_code/info-gan"),
            Arrays.asList(
                    "gan/CNNCreator_InfoDiscriminator.py",
                    "gan/CNNNet_InfoDiscriminator.py",
                    "gan/CNNCreator_InfoQNetwork.py",
                    "gan/CNNNet_InfoQNetwork.py",
                    "CNNTrainer_infoGAN.py"));
}
}
package de.monticore.lang.monticar.cnnarch.gluongenerator.preprocessing;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.FunctionParameterChecker;
import de.monticore.lang.monticar.cnnarch.gluongenerator.reinforcement.RewardFunctionParameterAdapter;
import de.monticore.lang.monticar.cnntrain._symboltable.NNArchitectureSymbol;
import de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.ComponentPortInformation;
import de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.EmadlType;
import de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.PortDirection;
import de.monticore.lang.monticar.generator.pythonwrapper.symbolservices.data.PortVariable;
import de.se_rwth.commons.logging.Finding;
import de.se_rwth.commons.logging.Log;
import org.junit.Before;
import org.junit.Test;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
 * Tests for {@link PreprocessingPortChecker}: verifies that a preprocessing
 * component with matching input/output port pairs passes, and that wrong
 * output names or a wrong number of outputs are reported as errors.
 */
public class PreprocessingParameterCheckerTest {
    private static final PortVariable INPUT1_PORT = PortVariable.primitiveVariableFrom("port1", EmadlType.Q,
            PortDirection.INPUT);
    private static final PortVariable INPUT2_PORT = PortVariable.primitiveVariableFrom("port2", EmadlType.B,
            PortDirection.INPUT);
    // Valid outputs: same base names as the inputs, suffixed with "_out".
    private static final PortVariable OUTPUT1_PORT = PortVariable.primitiveVariableFrom("port1_out", EmadlType.Q,
            PortDirection.OUTPUT);
    private static final PortVariable OUTPUT2_PORT = PortVariable.primitiveVariableFrom("port2_out", EmadlType.Q,
            PortDirection.OUTPUT);
    // Invalidly named outputs: they reuse the plain input names "port1"/"port2".
    private static final PortVariable OUTPUT3_PORT = PortVariable.primitiveVariableFrom("port1", EmadlType.Q,
            PortDirection.OUTPUT);
    private static final PortVariable OUTPUT4_PORT = PortVariable.primitiveVariableFrom("port2", EmadlType.Q,
            PortDirection.OUTPUT);

    private static final String COMPONENT_NAME = "TestProcessingComponent";

    PreprocessingPortChecker uut = new PreprocessingPortChecker();

    @Before
    public void setup() {
        Log.getFindings().clear();
        Log.enableFailQuick(false);
    }

    @Test
    public void validProcessing() {
        // given
        PreprocessingComponentParameterAdapter adapter = getValidProcessingAdapter();

        // when
        uut.check(adapter);

        // then
        List<Finding> findings = Log.getFindings();
        assertEquals(0, findings.stream().filter(Finding::isError).count());
    }

    @Test
    public void invalidProcessingOutputNames() {
        // given
        PreprocessingComponentParameterAdapter adapter = getInvalidProcessingOutputNameAdapter();

        // when
        uut.check(adapter);

        // then
        List<Finding> findings = Log.getFindings();
        assertTrue(findings.stream().anyMatch(Finding::isError));
    }

    @Test
    public void invalidProcessingOutputNumber() {
        // given
        PreprocessingComponentParameterAdapter adapter = getInvalidProcessingOutputNumberAdapter();

        // when
        uut.check(adapter);

        // then
        List<Finding> findings = Log.getFindings();
        assertTrue(findings.stream().anyMatch(Finding::isError));
    }

    // NOTE(review): the two RewardFunction* helpers below are unused in this
    // class (likely copied from the reward-function checker tests) — confirm
    // and consider removing.
    private RewardFunctionParameterAdapter getComponentWithNonScalarOutput() {
        ComponentPortInformation componentPortInformation = new ComponentPortInformation(COMPONENT_NAME);
        componentPortInformation.addAllInputs(getValidInputPortVariables());
        List<PortVariable> outputs = Lists.newArrayList(PortVariable.multidimensionalVariableFrom(
                "output", EmadlType.Q, PortDirection.OUTPUT, Lists.newArrayList(2, 2)));
        componentPortInformation.addAllOutputs(outputs);
        return new RewardFunctionParameterAdapter(componentPortInformation);
    }

    private RewardFunctionParameterAdapter getComponentWithTwoOutputs() {
        ComponentPortInformation componentPortInformation
                = new ComponentPortInformation(COMPONENT_NAME);
        componentPortInformation.addAllInputs(getValidInputPortVariables());
        List<PortVariable> outputs = getValidOutputPorts();
        outputs.add(PortVariable.primitiveVariableFrom("output2", EmadlType.B, PortDirection.OUTPUT));
        componentPortInformation.addAllOutputs(outputs);
        return new RewardFunctionParameterAdapter(componentPortInformation);
    }

    /** Adapter with two inputs and correctly named "_out" outputs. */
    private PreprocessingComponentParameterAdapter getValidProcessingAdapter() {
        ComponentPortInformation componentPortInformation
                = new ComponentPortInformation(COMPONENT_NAME);
        componentPortInformation.addAllInputs(getValidInputPortVariables());
        componentPortInformation.addAllOutputs(getValidOutputPorts());
        return new PreprocessingComponentParameterAdapter(componentPortInformation);
    }

    /** Adapter whose outputs reuse the input names (missing "_out" suffix). */
    private PreprocessingComponentParameterAdapter getInvalidProcessingOutputNameAdapter() {
        ComponentPortInformation componentPortInformation
                = new ComponentPortInformation(COMPONENT_NAME);
        componentPortInformation.addAllInputs(getValidInputPortVariables());
        componentPortInformation.addAllOutputs(getInvalidOutputPorts());
        return new PreprocessingComponentParameterAdapter(componentPortInformation);
    }

    /** Adapter with two inputs but only one output (wrong output count). */
    private PreprocessingComponentParameterAdapter getInvalidProcessingOutputNumberAdapter() {
        ComponentPortInformation componentPortInformation
                = new ComponentPortInformation(COMPONENT_NAME);
        componentPortInformation.addAllInputs(getValidInputPortVariables());
        // BUG FIX: previously used getInvalidOutputPorts() (wrong NAMES, correct
        // count), so invalidProcessingOutputNumber never exercised the
        // output-count rule; use the single-output port list instead.
        componentPortInformation.addAllOutputs(getInvalidOutputNumberPorts());
        return new PreprocessingComponentParameterAdapter(componentPortInformation);
    }

    private List<PortVariable> getValidOutputPorts() {
        return Lists.newArrayList(OUTPUT1_PORT, OUTPUT2_PORT);
    }

    private List<PortVariable> getInvalidOutputPorts() {
        return Lists.newArrayList(OUTPUT3_PORT, OUTPUT4_PORT);
    }

    private List<PortVariable> getInvalidOutputNumberPorts() {
        return Lists.newArrayList(OUTPUT1_PORT);
    }

    private List<PortVariable> getValidInputPortVariables() {
        return Lists.newArrayList(INPUT1_PORT, INPUT2_PORT);
    }
}
\ No newline at end of file
import mxnet as mx
import logging
import os
import numpy as np
import time
import shutil
from mxnet import gluon, autograd, nd
import CNNCreator_defaultGAN
import CNNDataLoader_defaultGAN
import CNNGanTrainer_defaultGAN
from gan.CNNCreator_Discriminator import CNNCreator_Discriminator
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
handler = logging.FileHandler("train.log", "w", encoding=None, delay="true")
logger.addHandler(handler)
data_loader = CNNDataLoader_defaultGAN.CNNDataLoader_defaultGAN()
gen_creator = CNNCreator_defaultGAN.CNNCreator_defaultGAN()
dis_creator = CNNCreator_Discriminator()
defaultGAN_trainer = CNNGanTrainer_defaultGAN.CNNGanTrainer_defaultGAN(
data_loader,
gen_creator,
dis_creator,
)
defaultGAN_trainer.train(
batch_size=64,
num_epoch=10,
load_checkpoint=False,
context='cpu',
normalize=False,
preprocessing=False,
optimizer='adam',
optimizer_params={
'beta1': 0.5,
'learning_rate': 2.0E-4 },
discriminator_optimizer= 'adam',
discriminator_optimizer_params= {
'beta1': 0.5,
'learning_rate': 2.0E-4},
noise_distribution = 'gaussian',
noise_distribution_params = {
'mean_value': 0,
'spread_value': 1
},
k_value=1,
noise_input="noise",
gen_loss_weight=0.5,
dis_loss_weight=0.5,
log_period=10,
print_images=True,
)
import mxnet as mx
import logging
import os
from CNNNet_Discriminator import Net_0
class CNNCreator_Discriminator:
    """Creates, exports and reloads the generated Discriminator network(s)."""

    # Directory and file prefix used for exported symbols/parameters.
    _model_dir_ = "model/Discriminator/"
    _model_prefix_ = "model"

    def __init__(self):
        self.weight_initializer = mx.init.Normal()
        # Maps network index -> gluon network instance (filled by construct()).
        self.networks = {}

    def load(self, context):
        """Load the newest checkpoint for every constructed network.

        Returns the smallest "last epoch" across all networks (0 when any
        network has no checkpoint), i.e. an epoch from which every
        sub-network can resume.
        """
        earliestLastEpoch = None
        for i, network in self.networks.items():
            lastEpoch = 0
            param_file = None
            # Remove stale "newest" snapshot files; ignore if absent.
            try:
                os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest-0000.params")
            except OSError:
                pass
            try:
                os.remove(self._model_dir_ + self._model_prefix_ + "_" + str(i) + "_newest-symbol.json")
            except OSError:
                pass
            if os.path.isdir(self._model_dir_):
                # Scan for the .params file with the highest epoch number
                # belonging to network i.
                for file in os.listdir(self._model_dir_):
                    if ".params" in file and self._model_prefix_ + "_" + str(i) in file:
                        epochStr = file.replace(".params","").replace(self._model_prefix_ + "_" + str(i) + "-","")
                        epoch = int(epochStr)
                        if epoch > lastEpoch:
                            lastEpoch = epoch
                            param_file = file
            if param_file is None:
                earliestLastEpoch = 0
            else:
                logging.info("Loading checkpoint: " + param_file)
                network.load_parameters(self._model_dir_ + param_file)
                # (style note: "is None" would be the idiomatic check here)
                if earliestLastEpoch == None or lastEpoch < earliestLastEpoch:
                    earliestLastEpoch = lastEpoch
        return earliestLastEpoch

    def construct(self, context, data_mean=None, data_std=None):
        """Build, initialize, hybridize and export network 0 on `context`."""
        self.networks[0] = Net_0(data_mean=data_mean, data_std=data_std)
        self.networks[0].collect_params().initialize(self.weight_initializer, ctx=context)
        self.networks[0].hybridize()
        # One forward pass with a dummy (1,1,64,64) input so the symbolic
        # graph exists and can be exported below.
        self.networks[0](mx.nd.zeros((1, 1,64,64,), ctx=context))
        if not os.path.exists(self._model_dir_):
            os.makedirs(self._model_dir_)
        for i, network in self.networks.items():
            network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)

    def getInputs(self):
        """Return {port name: (type, min, max, dimensions)} for each input."""
        inputs = {}
        input_dimensions = (1,64,64,)
        input_domains = (float,-1.0,1.0,)
        inputs["data_"] = input_domains + (input_dimensions,)
        return inputs

    def getOutputs(self):
        """Return {port name: (type, min, max, dimensions)} for each output."""
        outputs = {}
        output_dimensions = (1,4,4,)
        output_domains = (float,0.0,1.0,)
        outputs["dis_"] = output_domains + (output_dimensions,)
        return outputs
import mxnet as mx
import numpy as np
import math
from mxnet import gluon
class ZScoreNormalization(gluon.HybridBlock):
    """Normalizes the input by subtracting a fixed mean and dividing by a
    fixed standard deviation, both stored as non-trainable parameters.
    """
    def __init__(self, data_mean, data_std, **kwargs):
        super(ZScoreNormalization, self).__init__(**kwargs)
        with self.name_scope():
            self.data_mean = self.params.get('data_mean', shape=data_mean.shape,
                init=mx.init.Constant(data_mean.asnumpy().tolist()), differentiable=False)
            # BUG FIX: the std parameter was declared with shape=data_mean.shape;
            # it must use data_std's own shape so a mismatched std array fails
            # loudly instead of producing a wrongly-shaped constant.
            self.data_std = self.params.get('data_std', shape=data_std.shape,
                init=mx.init.Constant(data_std.asnumpy().tolist()), differentiable=False)

    def hybrid_forward(self, F, x, data_mean, data_std):
        # Broadcast ops so the per-channel statistics apply across the batch.
        x = F.broadcast_sub(x, data_mean)
        x = F.broadcast_div(x, data_std)
        return x
class Padding(gluon.HybridBlock):
    """Constant (zero) padding with a fixed pad_width specification."""
    def __init__(self, padding, **kwargs):
        super(Padding, self).__init__(**kwargs)
        with self.name_scope():
            self.pad_width = padding

    def hybrid_forward(self, F, x):
        # Single-expression form of the constant-pad call.
        return F.pad(data=x, mode='constant',
                     pad_width=self.pad_width, constant_value=0)
class NoNormalization(gluon.HybridBlock):
    """Identity block used when no input normalization is configured."""
    def __init__(self, **kwargs):
        super(NoNormalization, self).__init__(**kwargs)

    def hybrid_forward(self, F, x):
        # Pass-through: no normalization applied.
        return x
class Reshape(gluon.HybridBlock):
    """Reshapes its input to a fixed target shape."""
    def __init__(self, shape, **kwargs):
        super(Reshape, self).__init__(**kwargs)
        with self.name_scope():
            self.shape = shape

    def hybrid_forward(self, F, x):
        target_shape = self.shape
        return F.reshape(data=x, shape=target_shape)
class CustomRNN(gluon.HybridBlock):
    """Vanilla tanh RNN (layout NTC) whose hidden state is exchanged with the
    caller with its first two axes swapped.
    """
    def __init__(self, hidden_size, num_layers, dropout, bidirectional, **kwargs):
        super(CustomRNN, self).__init__(**kwargs)
        with self.name_scope():
            self.rnn = gluon.rnn.RNN(hidden_size=hidden_size, num_layers=num_layers,
                                     dropout=dropout, bidirectional=bidirectional,
                                     activation='tanh', layout='NTC')

    def hybrid_forward(self, F, data, state0):
        # Swap axes 0/1 of the state before the call and back afterwards.
        state_in = F.swapaxes(state0, 0, 1)
        output, states_out = self.rnn(data, [state_in])
        return output, F.swapaxes(states_out[0], 0, 1)
class CustomLSTM(gluon.HybridBlock):
    """LSTM (layout NTC) whose two states are exchanged with the caller with
    their first two axes swapped.
    """
    def __init__(self, hidden_size, num_layers, dropout, bidirectional, **kwargs):
        super(CustomLSTM, self).__init__(**kwargs)
        with self.name_scope():
            self.lstm = gluon.rnn.LSTM(hidden_size=hidden_size, num_layers=num_layers,
                                       dropout=dropout, bidirectional=bidirectional,
                                       layout='NTC')

    def hybrid_forward(self, F, data, state0, state1):
        # Swap axes 0/1 of both states before the call and back afterwards.
        states_in = [F.swapaxes(state0, 0, 1), F.swapaxes(state1, 0, 1)]
        output, states_out = self.lstm(data, states_in)
        return output, F.swapaxes(states_out[0], 0, 1), F.swapaxes(states_out[1], 0, 1)
class CustomGRU(gluon.HybridBlock):
    """GRU (layout NTC) whose hidden state is exchanged with the caller with
    its first two axes swapped.
    """
    def __init__(self, hidden_size, num_layers, dropout, bidirectional, **kwargs):
        super(CustomGRU, self).__init__(**kwargs)
        with self.name_scope():
            self.gru = gluon.rnn.GRU(hidden_size=hidden_size, num_layers=num_layers,
                                     dropout=dropout, bidirectional=bidirectional,
                                     layout='NTC')

    def hybrid_forward(self, F, data, state0):
        # Swap axes 0/1 of the state before the call and back afterwards.
        state_in = F.swapaxes(state0, 0, 1)
        output, states_out = self.gru(data, [state_in])
        return output, F.swapaxes(states_out[0], 0, 1)
class Net_0(gluon.HybridBlock):
    """Generated discriminator network: five strided 4x4 convolutions with
    BatchNorm/LeakyReLU, ending in a sigmoid over a (1,4,4) output map.
    """
    def __init__(self, data_mean=None, data_std=None, **kwargs):
        super(Net_0, self).__init__(**kwargs)
        with self.name_scope():
            # Z-score-normalize the input when statistics are supplied,
            # otherwise pass it through unchanged.
            if data_mean:
                assert(data_std)
                self.input_normalization_data_ = ZScoreNormalization(data_mean=data_mean['data_'],
                                                                     data_std=data_std['data_'])
            else:
                self.input_normalization_data_ = NoNormalization()

            self.conv1_padding = Padding(padding=(0,0,0,0,1,1,1,1))
            self.conv1_ = gluon.nn.Conv2D(channels=64,
                kernel_size=(4,4),
                strides=(2,2),
                use_bias=True)
            # conv1_, output shape: {[64,32,32]}

            self.leakyrelu1_ = gluon.nn.LeakyReLU(0.2)
            self.conv2_padding = Padding(padding=(0,0,0,0,1,1,1,1))
            self.conv2_ = gluon.nn.Conv2D(channels=128,
                kernel_size=(4,4),
                strides=(2,2),
                use_bias=True)
            # conv2_, output shape: {[128,16,16]}

            self.batchnorm2_ = gluon.nn.BatchNorm()
            # batchnorm2_, output shape: {[128,16,16]}

            self.leakyrelu2_ = gluon.nn.LeakyReLU(0.2)
            self.conv3_padding = Padding(padding=(0,0,0,0,1,1,1,1))
            self.conv3_ = gluon.nn.Conv2D(channels=256,
                kernel_size=(4,4),
                strides=(2,2),
                use_bias=True)
            # conv3_, output shape: {[256,8,8]}

            self.batchnorm3_ = gluon.nn.BatchNorm()
            # batchnorm3_, output shape: {[256,8,8]}

            self.leakyrelu3_ = gluon.nn.LeakyReLU(0.2)
            self.conv4_padding = Padding(padding=(0,0,0,0,1,1,1,1))
            self.conv4_ = gluon.nn.Conv2D(channels=512,
                kernel_size=(4,4),
                strides=(2,2),
                use_bias=True)
            # conv4_, output shape: {[512,4,4]}

            self.batchnorm4_ = gluon.nn.BatchNorm()
            # batchnorm4_, output shape: {[512,4,4]}

            self.leakyrelu4_ = gluon.nn.LeakyReLU(0.2)
            # Asymmetric padding before the final stride-1 convolution.
            self.conv5_padding = Padding(padding=(0,0,0,0,2,1,2,1))
            self.conv5_ = gluon.nn.Conv2D(channels=1,
                kernel_size=(4,4),
                strides=(1,1),
                use_bias=True)
            # conv5_, output shape: {[1,4,4]}

            self.sigmoid5_ = gluon.nn.Activation(activation='sigmoid')
            pass

    def hybrid_forward(self, F, data_):
        # Straight-line pipeline: normalize, then pad->conv(->bn)->leakyrelu
        # four times, then the final pad->conv->sigmoid.
        data_ = self.input_normalization_data_(data_)
        conv1_padding = self.conv1_padding(data_)
        conv1_ = self.conv1_(conv1_padding)
        leakyrelu1_ = self.leakyrelu1_(conv1_)
        conv2_padding = self.conv2_padding(leakyrelu1_)
        conv2_ = self.conv2_(conv2_padding)
        batchnorm2_ = self.batchnorm2_(conv2_)
        leakyrelu2_ = self.leakyrelu2_(batchnorm2_)
        conv3_padding = self.conv3_padding(leakyrelu2_)
        conv3_ = self.conv3_(conv3_padding)
        batchnorm3_ = self.batchnorm3_(conv3_)
        leakyrelu3_ = self.leakyrelu3_(batchnorm3_)
        conv4_padding = self.conv4_padding(leakyrelu3_)
        conv4_ = self.conv4_(conv4_padding)
        batchnorm4_ = self.batchnorm4_(conv4_)
        leakyrelu4_ = self.leakyrelu4_(batchnorm4_)
        conv5_padding = self.conv5_padding(leakyrelu4_)
        conv5_ = self.conv5_(conv5_padding)
        sigmoid5_ = self.sigmoid5_(conv5_)
        dis_ = F.identity(sigmoid5_)
        return dis_
import mxnet as mx
import logging
import os
import numpy as np
import time
import shutil
from mxnet import gluon, autograd, nd
import CNNCreator_infoGAN
import CNNDataLoader_infoGAN
import CNNGanTrainer_infoGAN
from gan.CNNCreator_InfoDiscriminator import CNNCreator_InfoDiscriminator
from gan.CNNCreator_InfoQNetwork import CNNCreator_InfoQNetwork
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
handler = logging.FileHandler("train.log", "w", encoding=None, delay="true")
logger.addHandler(handler)
data_loader = CNNDataLoader_infoGAN.CNNDataLoader_infoGAN()
gen_creator = CNNCreator_infoGAN.CNNCreator_infoGAN()
dis_creator = CNNCreator_InfoDiscriminator()
qnet_creator = CNNCreator_InfoQNetwork()
infoGAN_trainer = CNNGanTrainer_infoGAN.CNNGanTrainer_infoGAN(
data_loader,
gen_creator,
dis_creator,
qnet_creator
)
infoGAN_trainer.train(
batch_size=64,
num_epoch=10,
load_checkpoint=False,
context='cpu',
normalize=False,
preprocessing=False,
optimizer='adam',
optimizer_params={
'beta1': 0.5,
'learning_rate': 2.0E-4 },
discriminator_optimizer= 'adam',
discriminator_optimizer_params= {
'beta1': 0.5,
'learning_rate': 2.0E-4},
noise_distribution = 'gaussian',
noise_distribution_params = {
'mean_value': 0,
'spread_value': 1
},
k_value=1,
noise_input="noise",
gen_loss_weight=0.5,
dis_loss_weight=0.5,
log_period=10,
print_images=True,
)
import mxnet as mx
import logging
import os
from CNNNet_InfoDiscriminator import Net_0