Commit cd64dbd3 authored by Julian Dierkes

adjusted GAN trainer

parent 86b9d59f
@@ -63,3 +63,29 @@ class ${tc.fileNameWithoutEnding}:
for i, network in self.networks.items():
network.export(self._model_dir_ + self._model_prefix_ + "_" + str(i), epoch=0)
def getInputs(self):
inputs = {}
<#list tc.architecture.streams as stream>
<#assign dimensions = (tc.getStreamInputs(stream, false))>
<#assign domains = (tc.getStreamInputDomains(stream))>
<#list tc.getStreamInputVariableNames(stream, false) as name>
input_dimensions = (${tc.join(dimensions[name], ",")})
input_domains = (${tc.join(domains[name], ",")})
inputs["${name}"] = input_domains + (input_dimensions,)
</#list>
</#list>
return inputs
def getOutputs(self):
outputs = {}
<#list tc.architecture.streams as stream>
<#assign dimensions = (tc.getStreamOutputs(stream, false))>
<#assign domains = (tc.getStreamOutputDomains(stream))>
<#list tc.getStreamOutputVariableNames(stream, false) as name>
output_dimensions = (${tc.join(dimensions[name], ",")})
output_domains = (${tc.join(domains[name], ",")})
outputs["${name}"] = output_domains + (output_dimensions,)
</#list>
</#list>
return outputs
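# For illustration (not part of the template): for a hypothetical stream with a
# single image input "data_" of dimensions (1,28,28) and domain (float, 0.0, 1.0),
# the template above would expand to roughly this generated method, so each
# entry packs the domain followed by the shape tuple:
#
# def getInputs(self):
#     inputs = {}
#     input_dimensions = (1,28,28)
#     input_domains = (float,0.0,1.0)
#     inputs["data_"] = input_domains + (input_dimensions,)
#     return inputs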
@@ -54,30 +54,32 @@ class SoftmaxCrossEntropyLossIgnoreIndices(gluon.loss.Loss):
loss = loss * mx.nd.logical_not(mx.nd.equal(mx.nd.argmax(pred, axis=1), mx.nd.ones_like(mx.nd.argmax(pred, axis=1))*i) * mx.nd.equal(mx.nd.argmax(pred, axis=1), label))
return loss.mean(axis=self._batch_axis, exclude=True)
"""
# ugly hardcoded
#import matplotlib as mpl
from matplotlib import pyplot as plt
def visualize(img_arr):
plt.imshow((img_arr.asnumpy().transpose(1, 2, 0) * 255).astype(np.uint8).reshape(28,28))
plt.axis('off')
"""
def getDataIter(ctx, batch_size=64, Z=100):
img_number = 500
mnist_train = mx.gluon.data.vision.datasets.MNIST(train=True)
mnist_test = mx.gluon.data.vision.datasets.MNIST(train=False)
X = np.zeros((img_number, 28, 28))
for i in range(img_number // 2):  # integer division keeps the index an int under Python 3
    X[i] = mnist_train[i][0].asnumpy()[:, :, 0]
for i in range(img_number // 2):
    X[img_number // 2 + i] = mnist_test[i][0].asnumpy()[:, :, 0]
#X = np.zeros((img_number, 28, 28))
#for i, (data, label) in enumerate(mnist_train):
#    X[i] = mnist_train[i][0].asnumpy()[:,:,0]
#for i, (data, label) in enumerate(mnist_test):
#    X[len(mnist_train)+i] = data.asnumpy()[:,:,0]
np.random.seed(1)
p = np.random.permutation(X.shape[0])
@@ -90,10 +92,12 @@ def getDataIter(ctx, batch_size=64, Z=100):
X = np.tile(X, (1, 1, 1, 1))
data = mx.nd.array(X)
"""
for i in range(4):
plt.subplot(1,4,i+1)
visualize(data[i])
plt.show()
"""
image_iter = mx.io.NDArrayIter(data, batch_size=batch_size)
return image_iter
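# A minimal usage sketch (assumed context and names, not part of the template):
#
# data_iter = getDataIter(mx.cpu(), batch_size=64)
# for batch in data_iter:
#     real_images = batch.data[0]  # one batch of MNIST images
# data_iter.reset()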
@@ -123,12 +127,17 @@ class ${tc.fileNameWithoutEnding}:
context='gpu',
checkpoint_period=5,
normalize=True,
img_resize=(64,64),
noise_distribution='gaussian',
noise_distribution_params=(('mean_value', 0),('spread_value', 1),),
discriminator_optimizer='adam',
discriminator_optimizer_params=(('learning_rate', 0.001),),
constraint_distributions={},
constraint_losses={},
preprocessing = False,
k_value = 1,
generator_loss = None,
conditional_input = None,
noise_input = None):
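# New arguments (a sketch of their semantics, as used below):
#   k_value: the generator is updated only on every k_value-th batch.
#   generator_loss: optional extra loss ("L1" or "L2") pulling generated data
#                   towards the conditional input.
#   conditional_input: name of the data column fed to both networks as condition.
#   noise_input: name of the generator input that is sampled rather than loaded.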
if context == 'gpu':
mx_context = mx.gpu()
@@ -137,34 +146,37 @@ class ${tc.fileNameWithoutEnding}:
else:
logging.error("Context argument is '" + context + "'. Only 'cpu' and 'gpu are valid arguments'.")
gen_input_names = list(self._net_creator_gen.getInputs().keys())
gen_input_names = [name[:-1] for name in gen_input_names]
dis_input_names = list(self._net_creator_dis.getInputs().keys())
dis_input_names = [name[:-1] for name in dis_input_names]
qnet_input_names = list(self._net_creator_qnet.getInputs().keys())
qnet_input_names = [name[:-1] for name in qnet_input_names]
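# Generated variable names carry a trailing underscore (e.g. "noise_");
# name[:-1] strips it so the names can be matched against data loader inputs.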
if preprocessing:
preproc_lib = "CNNPreprocessor_${tc.fileNameWithoutEnding?keep_after("CNNGanTrainer_")}_executor"
self._data_loader._output_names_ = []
if self.use_qnet:
dataloader_inputs = set(gen_input_names + dis_input_names).difference(qnet_input_names)
dataloader_inputs.discard(noise_input)
else:
dataloader_inputs = set(gen_input_names + dis_input_names)
dataloader_inputs.discard(noise_input)
self._data_loader._input_names_ = list(dataloader_inputs)
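# The data loader must supply every input that the generator and discriminator
# consume, except the sampled noise input and, when a Q-network is used, the
# latent codes the Q-network is trained to recover.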
# train_iter = getDataIter(mx_context, batch_size, 100)
if preprocessing:
train_iter, test_iter, data_mean, data_std, _, _ = self._data_loader.load_preprocessed_data(batch_size, preproc_lib)
else:
train_iter, test_iter, data_mean, data_std, _, _ = self._data_loader.load_data(batch_size)
traindata_to_index = {}
curIndex = 0
for data_tuple in train_iter.data:
traindata_to_index[data_tuple[0] + "_"] = curIndex
curIndex += 1
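# Illustration with hypothetical loader names: if train_iter.data holds the
# tuples ("data", ndarray) and ("label", ndarray), this builds
# traindata_to_index == {"data_": 0, "label_": 1}, so batch columns can later
# be looked up by variable name instead of position.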
if 'weight_decay' in optimizer_params:
optimizer_params['wd'] = optimizer_params['weight_decay']
@@ -181,6 +193,21 @@ class ${tc.fileNameWithoutEnding}:
del optimizer_params['step_size']
del optimizer_params['learning_rate_decay']
if 'weight_decay' in discriminator_optimizer_params:
discriminator_optimizer_params['wd'] = discriminator_optimizer_params['weight_decay']
del discriminator_optimizer_params['weight_decay']
if 'learning_rate_decay' in discriminator_optimizer_params:
min_learning_rate = 1e-08
if 'learning_rate_minimum' in discriminator_optimizer_params:
min_learning_rate = discriminator_optimizer_params['learning_rate_minimum']
del discriminator_optimizer_params['learning_rate_minimum']
discriminator_optimizer_params['lr_scheduler'] = mx.lr_scheduler.FactorScheduler(
discriminator_optimizer_params['step_size'],
factor=discriminator_optimizer_params['learning_rate_decay'],
stop_factor_lr=min_learning_rate)
del discriminator_optimizer_params['step_size']
del discriminator_optimizer_params['learning_rate_decay']
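# Sketch of the resulting schedule (assumed values): with step_size=1000,
# learning_rate_decay=0.9 and a base learning rate of 0.001, the effective
# rate after n updates is max(0.001 * 0.9 ** (n // 1000), min_learning_rate).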
if normalize:
self._net_creator_dis.construct(mx_context, data_mean=data_mean, data_std=data_std)
else:
@@ -188,6 +215,20 @@ class ${tc.fileNameWithoutEnding}:
self._net_creator_gen.construct(mx_context)
if self.use_qnet:
self._net_creator_qnet.construct(mx_context)
if load_checkpoint:
self._net_creator_qnet.load(mx_context)
else:
if os.path.isdir(self._net_creator_qnet._model_dir_):
shutil.rmtree(self._net_creator_qnet._model_dir_)
try:
os.makedirs(self._net_creator_qnet._model_dir_)
except OSError:
if not (os.path.isdir(self._net_creator_qnet._model_dir_)):
raise
q_net = self._net_creator_qnet.networks[0]
begin_epoch = 0
if load_checkpoint:
begin_epoch = self._net_creator_dis.load(mx_context)
@@ -209,8 +250,10 @@ class ${tc.fileNameWithoutEnding}:
os.path.isdir(self._net_creator_dis._model_dir_)):
raise
gen_trainer = mx.gluon.Trainer(gen_net.collect_params(), optimizer, optimizer_params)
dis_trainer = mx.gluon.Trainer(dis_net.collect_params(), discriminator_optimizer, discriminator_optimizer_params)
if self.use_qnet:
qnet_trainer = mx.gluon.Trainer(q_net.collect_params(), discriminator_optimizer, discriminator_optimizer_params)
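# Note: the Q-network reuses the discriminator's optimizer settings rather
# than taking its own.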
if loss == 'sigmoid_binary_cross_entropy':
loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
@@ -223,6 +266,14 @@ class ${tc.fileNameWithoutEnding}:
else:
logging.error("Invalid loss parameter.")
if generator_loss is not None:
if generator_loss == "L2":
generator_loss_func = mx.gluon.loss.L2Loss()
elif generator_loss == "L1":
generator_loss_func = mx.gluon.loss.L1Loss()
else:
logging.error("Invalid generator loss parameter")
metric_dis = mx.metric.create(eval_metric)
metric_gen = mx.metric.create(eval_metric)
<#include "gan/InputGenerator.ftl">
@@ -233,8 +284,10 @@ class ${tc.fileNameWithoutEnding}:
for epoch in range(begin_epoch, begin_epoch + num_epoch):
train_iter.reset()
for batch_i, batch in enumerate(train_iter):
real_data = batch.data[traindata_to_index[dis_input_names[0]]].as_in_context(mx_context)
dis_conditional_input = create_discriminator_input(batch)
gen_input, exp_qnet_output = create_generator_input(batch)
fake_labels = mx.nd.zeros((batch_size), ctx=mx_context)
real_labels = mx.nd.ones((batch_size), ctx=mx_context)
@@ -242,11 +295,11 @@ class ${tc.fileNameWithoutEnding}:
with autograd.record():
fake_data = gen_net(*gen_input)
fake_data = fake_data.detach()  # detach() returns a new array; reassign so no generator gradients flow in the discriminator update
discriminated_fake_dis = dis_net(fake_data, *dis_conditional_input)
if self.use_qnet:
discriminated_fake_dis, _ = discriminated_fake_dis
loss_resultF = loss_function(discriminated_fake_dis, fake_labels)
discriminated_real_dis = dis_net(real_data, *dis_conditional_input)
if self.use_qnet:
discriminated_real_dis, _ = discriminated_real_dis
loss_resultR = loss_function(discriminated_real_dis, real_labels)
@@ -255,21 +308,24 @@ class ${tc.fileNameWithoutEnding}:
loss_resultD.backward()
dis_trainer.step(batch_size)
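# The generator (and Q-network) is updated only on every k_value-th batch,
# while the discriminator above is updated on every batch; with the default
# k_value=1 both are updated equally often.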
if batch_i % k_value == 0:
with autograd.record():
fake_data = gen_net(*gen_input)
discriminated_fake_gen = dis_net(fake_data, *dis_conditional_input)
if self.use_qnet:
discriminated_fake_gen, features = discriminated_fake_gen
loss_resultG = loss_function(discriminated_fake_gen, real_labels)
if generator_loss is not None:
condition = batch.data[traindata_to_index[conditional_input + "_"]]
loss_resultG = loss_resultG + generator_loss_func(fake_data, condition)
if self.use_qnet:
qnet_discriminated = [q_net(features)]
for i, qnet_out in enumerate(qnet_discriminated):
loss_resultG = loss_resultG + qnet_losses[i](qnet_out, exp_qnet_output[i])
loss_resultG.backward()
gen_trainer.step(batch_size)
if self.use_qnet:
qnet_trainer.step(batch_size)
if tic is None:
tic = time.time()
@@ -296,24 +352,22 @@ class ${tc.fileNameWithoutEnding}:
tic = time.time()
"""
# ugly start
if batch_i % 200 == 0:
fake_data[0].asnumpy()
if batch_i < 500000:
#gen_net.save_parameters(self.parameter_path_gen() + '-' + str(num_epoch + begin_epoch).zfill(4) + '.params')
#gen_net.export(self.parameter_path_gen() + '_newest', epoch=0)
#dis_net.save_parameters(self.parameter_path_dis() + '-' + str(num_epoch + begin_epoch).zfill(4) + '.params')
#dis_net.export(self.parameter_path_dis() + '_newest', epoch=0)
for j in range(10):
plt.subplot(1, 10, j+1)
fake_img = fake_data[j]
visualize(fake_img)
plt.show()
# ugly end
"""
if (epoch - begin_epoch) % checkpoint_period == 0:
@@ -105,28 +105,3 @@ ${tc.include(networkInstruction.body, "FORWARD_FUNCTION")}
</#if>
</#list>
gen_inputs = self._net_creator_gen.getInputs()
dis_inputs = self._net_creator_dis.getInputs()
qnet_outputs = []
if self.use_qnet:
qnet_outputs = self._net_creator_qnet.getOutputs()
qnet_losses = []
generators = {}
if self.use_qnet:
@@ -68,14 +70,24 @@
shape=(batch_size,)+domain[3], dtype=domain[0],
ctx=mx_context), dtype="float32")
def create_generator_input(cur_batch):
expected_qnet_output = []
gen_input = []
for name in gen_inputs:
if name in traindata_to_index:
gen_input += [cur_batch.data[traindata_to_index[name]].as_in_context(mx_context)]
elif name in qnet_outputs:
value = generators[name]()
expected_qnet_output += [value]
gen_input += [value]
else:
gen_input += [generators[name]()]
return gen_input, expected_qnet_output
def create_discriminator_input(cur_batch):
conditional_input = []
for name in gen_inputs:
if name in traindata_to_index:
conditional_input += [cur_batch.data[traindata_to_index[name]].as_in_context(mx_context)]
return conditional_input
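# Summary of the two helpers: create_generator_input() assembles the
# generator's inputs per batch (conditional columns come from the batch,
# Q-network codes are sampled and remembered as the expected Q-network
# outputs, everything else, e.g. the noise vector, is sampled from its
# configured distribution), while create_discriminator_input() collects only
# the conditional columns that are passed to every dis_net(...) call above.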
@@ -82,6 +82,13 @@ if __name__ == "__main__":
</#list>
},
</#if>
<#if (config.configuration.criticOptimizer)??>
discriminator_optimizer='${config.criticOptimizerName}',
discriminator_optimizer_params={
<#list config.criticOptimizerParams?keys as param>
'${param}': ${config.criticOptimizerParams[param]}<#sep>,
</#list>},
</#if>
<#if (config.constraintDistributions)??>
<#assign map = (config.constraintDistributions)>
constraint_distributions = {
@@ -121,7 +128,19 @@ if __name__ == "__main__":
<#if (config.noiseDistribution.spread_value)??>
'spread_value': ${config.noiseDistribution.spread_value}
</#if>
},
<#if (config.KValue)??>
k_value=${config.KValue},
</#if>
<#if (config.generatorLoss)??>
generator_loss="${config.generatorLoss}",
</#if>
<#if (config.conditionalInput)??>
conditional_input="${config.conditionalInput}",
</#if>
<#if (config.noiseInput)??>
noise_input="${config.noiseInput}"
</#if>)
</#if>
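<#-- For illustration only, with a hypothetical configuration that sets
     KValue=2, generatorLoss="L1", conditionalInput="label" and
     noiseInput="noise", the block above emits:
         k_value=2,
         generator_loss="L1",
         conditional_input="label",
         noise_input="noise")
-->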