Commit 97795107 authored by Julian Dierkes's avatar Julian Dierkes

adding GAN features

parent 912f2fee
......@@ -64,6 +64,9 @@ public class CNNArch2Gluon extends CNNArchGenerator {
temp = controller.process("CNNSupervisedTrainer", Target.PYTHON);
fileContentMap.put(temp.getKey(), temp.getValue());
temp = controller.process("CNNGanTrainer", Target.PYTHON);
fileContentMap.put(temp.getKey(), temp.getValue());
temp = controller.process("execute", Target.CPP);
fileContentMap.put(temp.getKey().replace(".h", ""), temp.getValue());
......
......@@ -9,10 +9,12 @@ public class CNNArch2GluonLayerSupportChecker extends LayerSupportChecker {
public CNNArch2GluonLayerSupportChecker() {
supportedLayerList.add(AllPredefinedLayers.FULLY_CONNECTED_NAME);
supportedLayerList.add(AllPredefinedLayers.CONVOLUTION_NAME);
supportedLayerList.add(AllPredefinedLayers.TRANS_CONV_NAME);
supportedLayerList.add(AllPredefinedLayers.SOFTMAX_NAME);
supportedLayerList.add(AllPredefinedLayers.SIGMOID_NAME);
supportedLayerList.add(AllPredefinedLayers.TANH_NAME);
supportedLayerList.add(AllPredefinedLayers.RELU_NAME);
supportedLayerList.add(AllPredefinedLayers.LEAKY_RELU_NAME);
supportedLayerList.add(AllPredefinedLayers.DROPOUT_NAME);
supportedLayerList.add(AllPredefinedLayers.POOLING_NAME);
supportedLayerList.add(AllPredefinedLayers.GLOBAL_POOLING_NAME);
......
......@@ -32,6 +32,7 @@ import java.util.stream.Collectors;
public class CNNTrain2Gluon extends CNNTrainGenerator {
private static final String REINFORCEMENT_LEARNING_FRAMEWORK_MODULE = "reinforcement_learning";
private static final String GAN_LEARNING_FRAMEWORK_MODULE = "gan";
private final RewardFunctionSourceGenerator rewardFunctionSourceGenerator;
private String rootProjectModelsDir;
......@@ -62,7 +63,7 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
if (configuration.getLearningMethod().equals(LearningMethod.REINFORCEMENT)) {
throw new IllegalStateException("Cannot call generate of reinforcement configuration without specifying " +
"the trained architecture");
"the trained architecture");
}
generateFilesFromConfigurationSymbol(configuration);
......@@ -115,6 +116,46 @@ public class CNNTrain2Gluon extends CNNTrainGenerator {
if (configData.isSupervisedLearning()) {
String cnnTrainTemplateContent = templateConfiguration.processTemplate(ftlContext, "CNNTrainer.ftl");
fileContentMap.put("CNNTrainer_" + getInstanceName() + ".py", cnnTrainTemplateContent);
} else if(configData.isGan()) {
final String trainerName = "CNNTrainer_" + getInstanceName();
if(!configuration.getDiscriminatorNetwork().isPresent()) {
Log.error("No architecture model for discriminator available but is required for chosen " +
"GAN");
}
NNArchitectureSymbol genericArchitectureSymbol = configuration.getDiscriminatorNetwork().get();
ArchitectureSymbol architectureSymbol
= ((ArchitectureAdapter)genericArchitectureSymbol).getArchitectureSymbol();
CNNArch2Gluon gluonGenerator = new CNNArch2Gluon();
gluonGenerator.setGenerationTargetPath(
Paths.get(getGenerationTargetPath(), GAN_LEARNING_FRAMEWORK_MODULE).toString());
Map<String, String> architectureFileContentMap
= gluonGenerator.generateStringsAllowMultipleIO(architectureSymbol, true);
final String creatorName = architectureFileContentMap.keySet().iterator().next();
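// the creator file is named "CNNCreator_<instanceName>.py"; strip the prefix
// and the ".py" extension to recover the discriminator instance name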
final String discriminatorInstanceName = creatorName.substring(
creatorName.indexOf('_') + 1, creatorName.lastIndexOf(".py"));
fileContentMap.putAll(architectureFileContentMap.entrySet().stream().collect(Collectors.toMap(
entry -> GAN_LEARNING_FRAMEWORK_MODULE + "/" + entry.getKey(),
Map.Entry::getValue))
);
ftlContext.put("ganFrameworkModule", GAN_LEARNING_FRAMEWORK_MODULE);
ftlContext.put("discriminatorInstanceName", discriminatorInstanceName);
ftlContext.put("trainerName", trainerName);
final String initContent = "";
fileContentMap.put(GAN_LEARNING_FRAMEWORK_MODULE + "/__init__.py", initContent);
final String ganTrainerContent = templateConfiguration.processTemplate(ftlContext, "gan/Trainer.ftl");
fileContentMap.put(trainerName + ".py", ganTrainerContent);
//final String startTrainerScriptContent = templateConfiguration.processTemplate(ftlContext, "gan/StartTrainer.ftl");
//fileContentMap.put("start_training.sh", startTrainerScriptContent);
} else if (configData.isReinforcementLearning()) {
final String trainerName = "CNNTrainer_" + getInstanceName();
final RLAlgorithm rlAlgorithm = configData.getRlAlgorithm();
......
......@@ -28,6 +28,10 @@ public class GluonConfigurationData extends ConfigurationData {
&& retrieveConfigurationEntryValueByKey(LEARNING_METHOD).equals(LearningMethod.REINFORCEMENT);
}
public Boolean isGan() {
return configurationContainsKey(LEARNING_METHOD)
&& retrieveConfigurationEntryValueByKey(LEARNING_METHOD).equals(LearningMethod.GAN);
}
public Integer getNumEpisodes() {
return !configurationContainsKey(NUM_EPISODES)
? null : (Integer)retrieveConfigurationEntryValueByKey(NUM_EPISODES);
......@@ -161,6 +165,23 @@ public class GluonConfigurationData extends ConfigurationData {
return getMultiParamEntry(REPLAY_MEMORY, "method");
}
public Map<String, Object> getNoiseDistribution() {
return getMultiParamEntry(NOISE_DISTRIBUTION, "name");
}
public String getImgResizeWidth() {
if (!this.getConfiguration().getEntryMap().containsKey("img_resize_width")) {
return null;
}
return String.valueOf(getConfiguration().getEntry("img_resize_width").getValue());
}
public String getImgResizeHeight() {
if (!this.getConfiguration().getEntryMap().containsKey("img_resize_height")) {
return null;
}
return String.valueOf(getConfiguration().getEntry("img_resize_height").getValue());
}
public Map<String, Object> getStrategy() {
assert isReinforcementLearning(): "Strategy parameter is only available for reinforcement learning " +
"but was requested in a non-reinforcement-learning context";
......
......@@ -4,6 +4,8 @@ import h5py
import mxnet as mx
import logging
import sys
import numpy as np
import cv2
from mxnet import nd
class ${tc.fileNameWithoutEnding}:
......@@ -50,6 +52,48 @@ class ${tc.fileNameWithoutEnding}:
return train_iter, test_iter, data_mean, data_std
def load_data(self, batch_size, img_size):
train_h5, test_h5 = self.load_h5_files()
width = img_size[0]
height = img_size[1]
comb_data = {}
data_mean = {}
data_std = {}
for input_name in self._input_names_:
train_data = train_h5[input_name][:]
test_data = test_h5[input_name][:]
train_shape = train_data.shape
test_shape = test_data.shape
# images are stored as CHW while cv2 expects HWC: transpose before resizing,
# re-add the channel axis that cv2 drops for single-channel images, then
# transpose back to CHW
comb_data[input_name] = mx.nd.zeros((train_shape[0]+test_shape[0], train_shape[1], height, width))
for i, img in enumerate(train_data):
img = img.transpose(1, 2, 0)
comb_data[input_name][i] = cv2.resize(img, (width, height)).reshape((height, width, train_shape[1])).transpose(2, 0, 1)
for i, img in enumerate(test_data):
img = img.transpose(1, 2, 0)
comb_data[input_name][i+train_shape[0]] = cv2.resize(img, (width, height)).reshape((height, width, train_shape[1])).transpose(2, 0, 1)
data_mean[input_name + '_'] = nd.array(comb_data[input_name][:].mean(axis=0))
data_std[input_name + '_'] = nd.array(comb_data[input_name][:].asnumpy().std(axis=0) + 1e-5)
comb_label = {}
for output_name in self._output_names_:
train_labels = train_h5[output_name][:]
test_labels = test_h5[output_name][:]
comb_label[output_name] = np.append(train_labels, test_labels, axis=0)
train_iter = mx.io.NDArrayIter(data=comb_data,
label=comb_label,
batch_size=batch_size)
test_iter = None
return train_iter, test_iter, data_mean, data_std
def load_h5_files(self):
train_h5 = None
test_h5 = None
......@@ -58,6 +102,7 @@ class ${tc.fileNameWithoutEnding}:
if os.path.isfile(train_path):
train_h5 = h5py.File(train_path, 'r')
print(train_path)
for input_name in self._input_names_:
if not input_name in train_h5:
......
import mxnet as mx
import logging
import numpy as np
import time
import os
import shutil
from mxnet import gluon, autograd, nd
# ugly hardcoded
import matplotlib as mpl
from matplotlib import pyplot as plt
def visualize(img_arr):
# assumes a single-channel 64x64 image in CHW layout
plt.imshow((img_arr.asnumpy().transpose(1, 2, 0) * 255).astype(np.uint8).reshape(64, 64))
plt.axis('off')
def getDataIter(ctx, batch_size=64, Z=100):
img_number = 70
mnist_train = mx.gluon.data.vision.datasets.MNIST(train=True)
mnist_test = mx.gluon.data.vision.datasets.MNIST(train=False)
X = np.zeros((img_number, 28, 28))
for i in range(img_number // 2):
X[i] = mnist_train[i][0].asnumpy()[:,:,0]
for i in range(img_number // 2):
X[img_number // 2 + i] = mnist_test[i][0].asnumpy()[:,:,0]
#X = np.zeros((img_number, 28, 28))
#for i, (data, label) in enumerate(mnist_train):
# X[i] = mnist_train[i][0].asnumpy()[:,:,0]
#for i, (data, label) in enumerate(mnist_test):
# X[len(mnist_train)+i] = data.asnumpy()[:,:,0]
np.random.seed(1)
p = np.random.permutation(X.shape[0])
X = X[p]
import cv2
X = np.asarray([cv2.resize(x, (64,64)) for x in X])
X = X.astype(np.float32, copy=False)/(255.0/2) - 1.0
X = X.reshape((img_number, 1, 64, 64))
X = np.tile(X, (1, 3, 1, 1))
data = mx.nd.array(X)
for i in range(4):
plt.subplot(1,4,i+1)
visualize(data[i])
plt.show()
image_iter = mx.io.NDArrayIter(data, batch_size=batch_size)
return image_iter
# ugly hardcoded end
class ${tc.fileNameWithoutEnding}:
def __init__(self, data_loader, net_constructor_gen, net_constructor_dis):
self._data_loader = data_loader
self._net_creator_gen = net_constructor_gen
self._net_creator_dis = net_constructor_dis
def train(self, batch_size=64,
num_epoch=10,
eval_metric='acc',
loss='softmax_cross_entropy',
loss_params={},
optimizer='adam',
optimizer_params=(('learning_rate', 0.001),),
load_checkpoint=True,
context='gpu',
checkpoint_period=5,
normalize=True,
img_resize=(64,64),
noise_distribution='gaussian',
noise_distribution_params=(('mean_value', 0), ('spread_value', 1),)):
# the tuple defaults are converted to dicts so the parameters can be looked up by key below
optimizer_params = dict(optimizer_params)
noise_distribution_params = dict(noise_distribution_params)
if context == 'gpu':
mx_context = mx.gpu()
elif context == 'cpu':
mx_context = mx.cpu()
else:
logging.error("Context argument is '" + context + "'. Only 'cpu' and 'gpu are valid arguments'.")
#train_iter = getDataIter(mx_context, batch_size, 100)
train_iter, test_iter, data_mean, data_std = self._data_loader.load_data(batch_size, img_resize)
if 'weight_decay' in optimizer_params:
optimizer_params['wd'] = optimizer_params['weight_decay']
del optimizer_params['weight_decay']
if 'learning_rate_decay' in optimizer_params:
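# learning_rate_decay maps onto an MXNet FactorScheduler: the learning rate is
# multiplied by the decay factor every step_size updates, bounded below by
# learning_rate_minimum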
min_learning_rate = 1e-08
if 'learning_rate_minimum' in optimizer_params:
min_learning_rate = optimizer_params['learning_rate_minimum']
del optimizer_params['learning_rate_minimum']
optimizer_params['lr_scheduler'] = mx.lr_scheduler.FactorScheduler(
optimizer_params['step_size'],
factor=optimizer_params['learning_rate_decay'],
stop_factor_lr=min_learning_rate)
del optimizer_params['step_size']
del optimizer_params['learning_rate_decay']
if normalize:
self._net_creator_dis.construct(mx_context, data_mean=data_mean, data_std=data_std)
else:
self._net_creator_dis.construct(mx_context)
self._net_creator_gen.construct(mx_context)
begin_epoch = 0
if load_checkpoint:
begin_epoch = self._net_creator_dis.load(mx_context)
begin_epoch = self._net_creator_gen.load(mx_context)
else:
if os.path.isdir(self._net_creator_dis._model_dir_):
shutil.rmtree(self._net_creator_dis._model_dir_)
if os.path.isdir(self._net_creator_gen._model_dir_):
shutil.rmtree(self._net_creator_gen._model_dir_)
dis_net = self._net_creator_dis.networks[0]
gen_net = self._net_creator_gen.networks[0]
try:
os.makedirs(self._net_creator_gen._model_dir_)
os.makedirs(self._net_creator_dis._model_dir_)
except OSError:
if not (os.path.isdir(self._net_creator_gen._model_dir_) and
os.path.isdir(self._net_creator_dis._model_dir_)):
raise
gen_trainer = mx.gluon.Trainer(gen_net.collect_params(), optimizer, optimizer_params)
dis_trainer = mx.gluon.Trainer(dis_net.collect_params(), optimizer, optimizer_params)
if loss == 'sigmoid_binary_cross_entropy':
loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
activation_name = 'sigmoid'
<#list tc.architecture.streams as stream>
<#if stream.isTrainable()>
input_shape = <#list tc.getStreamInputDimensions(stream) as dimensions>${tc.join(dimensions, ",")}</#list>
</#if>
</#list>
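# input_shape holds the generator's per-sample input dimensions; its first
# entry is replaced by the batch size to obtain the shape of one noise batch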
shape_list = list(input_shape)
shape_list[0] = batch_size
input_shape = tuple(shape_list)
metric_dis = mx.metric.create(eval_metric)
metric_gen = mx.metric.create(eval_metric)
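# build a sampler for the generator's noise input; only the gaussian prior is
# generated here (mean_value is the mean, spread_value the standard deviation)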
if noise_distribution == "gaussian":
random_distributor = lambda : mx.ndarray.random.normal(noise_distribution_params["mean_value"],
noise_distribution_params["spread_value"],
shape=input_shape, ctx=mx_context)
speed_period = 5
tic = None
for epoch in range(begin_epoch, begin_epoch + num_epoch):
train_iter.reset()
for batch_i, batch in enumerate(train_iter):
real_data = batch.data[0].as_in_context(mx_context)
rbatch = random_distributor()
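# discriminator targets: 0 for generated (fake) samples, 1 for real samples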
fake_labels = mx.nd.zeros((batch_size), ctx=mx_context)
real_labels = mx.nd.ones((batch_size), ctx=mx_context)
with autograd.record():
fake_data = gen_net(rbatch)
# detach() returns a new array; without reassignment the generator would still
# receive gradients from the discriminator update
fake_data = fake_data.detach()
discriminated_fake_dis = dis_net(fake_data)
loss_resultF = loss_function(discriminated_fake_dis, fake_labels)
discriminated_real_dis = dis_net(real_data)
loss_resultR = loss_function(discriminated_real_dis, real_labels)
loss_resultD = loss_resultR + loss_resultF
loss_resultD.backward()
dis_trainer.step(batch_size)
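# generator update: re-run the generator and score its fakes against *real*
# labels, so the gradient pushes the generator towards fooling the discriminator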
with autograd.record():
fake_data = gen_net(rbatch)
discriminated_fake_gen = dis_net(fake_data)
loss_resultG = loss_function(discriminated_fake_gen, real_labels)
loss_resultG.backward()
gen_trainer.step(batch_size)
if tic is None:
tic = time.time()
else:
if batch_i % speed_period == 0:
metric_dis = mx.metric.create(eval_metric)
discriminated = mx.nd.Concat(discriminated_real_dis.reshape((-1,1)), discriminated_fake_dis.reshape((-1,1)), dim=0)
labels = mx.nd.Concat(real_labels.reshape((-1,1)), fake_labels.reshape((-1,1)), dim=0)
discriminated = mx.ndarray.Activation(discriminated, activation_name)
discriminated = mx.ndarray.floor(discriminated + 0.5)
metric_dis.update(preds=discriminated, labels=labels)
print("DisAcc: ", metric_dis.get()[1])
metric_gen = mx.metric.create(eval_metric)
discriminated = mx.ndarray.Activation(discriminated_fake_gen.reshape((-1,1)), activation_name)
discriminated = mx.ndarray.floor(discriminated + 0.5)
metric_gen.update(preds=discriminated, labels=real_labels.reshape((-1,1)))
print("GenAcc: ", metric_gen.get()[1])
try:
speed = speed_period * batch_size / (time.time() - tic)
except ZeroDivisionError:
speed = float("inf")
logging.info("Epoch[%d] Batch[%d] Speed: %.2f samples/sec" % (epoch, batch_i, speed))
tic = time.time()
# ugly start
#if batch_i % 20 == 0:
# fake_data[0].asnumpy()
if batch_i % 50 == 0:
#gen_net.save_parameters(self.parameter_path_gen() + '-' + str(num_epoch + begin_epoch).zfill(4) + '.params')
#gen_net.export(self.parameter_path_gen() + '_newest', epoch=0)
#dis_net.save_parameters(self.parameter_path_dis() + '-' + str(num_epoch + begin_epoch).zfill(4) + '.params')
#dis_net.export(self.parameter_path_dis() + '_newest', epoch=0)
for i in range(10):
plt.subplot(1,10,i+1)
fake_img = fake_data[i]
visualize(fake_img)
plt.show()
# ugly end
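# periodic checkpoints during training, plus a final save and an export of the
# newest symbol/params for downstream inference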
if (epoch - begin_epoch) % checkpoint_period == 0:
gen_net.save_parameters(self.parameter_path_gen() + '-' + str(epoch).zfill(4) + '.params')
dis_net.save_parameters(self.parameter_path_dis() + '-' + str(epoch).zfill(4) + '.params')
gen_net.save_parameters(self.parameter_path_gen() + '-' + str(num_epoch + begin_epoch).zfill(4) + '.params')
gen_net.export(self.parameter_path_gen() + '_newest', epoch=0)
dis_net.save_parameters(self.parameter_path_dis() + '-' + str(num_epoch + begin_epoch).zfill(4) + '.params')
dis_net.export(self.parameter_path_dis() + '_newest', epoch=0)
def parameter_path_gen(self):
return self._net_creator_gen._model_dir_ + self._net_creator_gen._model_prefix_ + '_' + str(0)
def parameter_path_dis(self):
return self._net_creator_dis._model_dir_ + self._net_creator_dis._model_prefix_ + '_' + str(0)
<#-- (c) https://github.com/MontiCore/monticore -->
<#assign input = element.inputs[0]>
<#if mode == "ARCHITECTURE_DEFINITION">
self.${element.name} = gluon.nn.LeakyReLU(${element.alpha})
<#elseif mode == "FORWARD_FUNCTION">
${element.name} = self.${element.name}(${input})
</#if>
<#-- (c) https://github.com/MontiCore/monticore -->
<#assign input = element.inputs[0]>
<#if mode == "ARCHITECTURE_DEFINITION">
<#if element.transPadding??>
self.${element.name}padding = (${tc.join(element.transPadding, ",")})
</#if>
self.${element.name} = gluon.nn.Conv2DTranspose(channels=${element.channels?c},
kernel_size=(${tc.join(element.kernel, ",")}),
strides=(${tc.join(element.stride, ",")}),
padding=self.${element.name}padding,
use_bias=${element.noBias?string("False", "True")})
<#include "OutputShape.ftl">
<#elseif mode == "FORWARD_FUNCTION">
${element.name} = self.${element.name}(${input})
</#if>
<#setting number_format="computer">
<#assign config = configurations[0]>
import mxnet as mx
import logging
import os
import numpy as np
import time
import shutil
from mxnet import gluon, autograd, nd
import CNNCreator_${config.instanceName}
import CNNDataLoader_${config.instanceName}
import CNNGanTrainer_${config.instanceName}
from ${ganFrameworkModule}.CNNCreator_${discriminatorInstanceName} import CNNCreator_${discriminatorInstanceName}
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
handler = logging.FileHandler("train.log", "w", encoding=None, delay=True)
logger.addHandler(handler)
data_loader = CNNDataLoader_${config.instanceName}.CNNDataLoader_${config.instanceName}()
gen_creator = CNNCreator_${config.instanceName}.CNNCreator_${config.instanceName}()
dis_creator = CNNCreator_${discriminatorInstanceName}()
${config.instanceName}_trainer = CNNGanTrainer_${config.instanceName}.CNNGanTrainer_${config.instanceName}(
data_loader,
gen_creator,
dis_creator
)
${config.instanceName}_trainer.train(
<#if (config.batchSize)??>
batch_size=${config.batchSize},
</#if>
<#if (config.numEpoch)??>
num_epoch=${config.numEpoch},
</#if>
<#if (config.loadCheckpoint)??>
load_checkpoint=${config.loadCheckpoint?string("True","False")},
</#if>
<#if (config.context)??>
context='${config.context}',
</#if>
<#if (config.normalize)??>
normalize=${config.normalize?string("True","False")},
</#if>
<#if (config.imgResizeWidth)??>
<#if (config.imgResizeHeight)??>
img_resize=(${config.imgResizeWidth}, ${config.imgResizeHeight}),
</#if>
</#if>
<#if (config.evalMetric)??>
eval_metric='${config.evalMetric}',
</#if>
<#if (config.configuration.loss)??>
loss='${config.lossName}',
<#if (config.lossParams)??>
loss_params={
<#list config.lossParams?keys as param>
'${param}': ${config.lossParams[param]}<#sep>,
</#list>
},
</#if>
</#if>
<#if (config.configuration.optimizer)??>
optimizer='${config.optimizerName}',
optimizer_params={
<#list config.optimizerParams?keys as param>
'${param}': ${config.optimizerParams[param]}<#sep>,
</#list>
},
</#if>
<#if (config.noiseDistribution)??>
noise_distribution='${config.noiseDistribution.name}',
noise_distribution_params={
<#if (config.noiseDistribution.mean_value)??>
'mean_value': ${config.noiseDistribution.mean_value},
</#if>
<#if (config.noiseDistribution.spread_value)??>
'spread_value': ${config.noiseDistribution.spread_value}
</#if>
})
</#if>
<#setting number_format="computer">
<#assign config = configurations[0]>
<#assign rlAgentType=config.rlAlgorithm?switch("dqn", "DqnAgent", "ddpg", "DdpgAgent", "td3", "TwinDelayedDdpgAgent")>
from ${ganFrameworkModule}.CNNCreator_${discriminatorInstanceName} import CNNCreator_${discriminatorInstanceName}
import CNNCreator_${config.instanceName}
import mxnet as mx
import logging
import numpy as np
import time
import os
import shutil
from mxnet import gluon, autograd, nd
if __name__ = "__main__":
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
handler = logging.FileHandler("train.log", "w", encoding=None, delay=True)
logger.addHandler(handler)
<#if (config.context)??>
context = mx.${config.context}()
<#else>
context = mx.cpu()
</#if>
generator_creator = CNNCreator_${config.instanceName}.CNNCreator_${config.instanceName}()
generator_creator.construct(context)
discriminator_creator = CNNCreator_${discriminatorInstanceName}()
discriminator_creator.construct(context)
<#if (config.batchSize)??>
batch_size=${config.batchSize},
</#if>
<#if (config.numEpoch)??>
num_epoch=${config.numEpoch},
</#if>
<#if (config.loadCheckpoint)??>
load_checkpoint=${config.loadCheckpoint?string("True","False")},
</#if>
<#if (config.normalize)??>
normalize=${config.normalize?string("True","False")},
</#if>
<#if (config.evalMetric)??>
eval_metric='${config.evalMetric}',
</#if>
<#if (config.configuration.loss)??>
loss='${config.lossName}',
<#if (config.lossParams)??>