Commit 86b9d59f authored by Julian Dierkes

Fixed a problem in the input generator regarding InfoGAN

parent dc316085
Pipeline #228265 passed with stages
@@ -59,7 +59,7 @@ import matplotlib as mpl
 from matplotlib import pyplot as plt

 def visualize(img_arr):
-    plt.imshow((img_arr.asnumpy().transpose(1, 2, 0) * 255).astype(np.uint8).reshape(64,64))
+    plt.imshow((img_arr.asnumpy().transpose(1, 2, 0) * 255).astype(np.uint8).reshape(28,28))
     plt.axis('off')

 def getDataIter(ctx, batch_size=64, Z=100):
@@ -84,9 +84,9 @@ def getDataIter(ctx, batch_size=64, Z=100):
     X = X[p]

     import cv2
-    X = np.asarray([cv2.resize(x, (64,64)) for x in X])
+    #X = np.asarray([cv2.resize(x, (64,64)) for x in X])
     X = X.astype(np.float32, copy=False)/(255.0/2) - 1.0
-    X = X.reshape((img_number, 1, 64, 64))
+    X = X.reshape((img_number, 1, 28, 28))

     X = np.tile(X, (1, 1, 1, 1))
     data = mx.nd.array(X)
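Note: with the 64x64 resize commented out, the MNIST images stay at their native 28x28, matching the reshape changes above. The scaling line maps uint8 pixels from [0, 255] to [-1, 1], the range a tanh generator output produces. A minimal standalone check of that arithmetic (NumPy only, values illustrative):

    import numpy as np

    # pixel values at the extremes and the midpoint
    x = np.array([0.0, 127.5, 255.0], dtype=np.float32)

    # same transform as above: [0, 255] -> [-1, 1]
    print(x / (255.0 / 2) - 1.0)   # [-1.  0.  1.]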
@@ -150,7 +150,7 @@ class ${tc.fileNameWithoutEnding}:
             if not (os.path.isdir(self._net_creator_qnet._model_dir_)):
                 raise
             q_net = self._net_creator_qnet.networks[0]
-            qnet_trainer = mx.gluon.Trainer(q_net.collect_params(), optimizer, optimizer_params)
+            qnet_trainer = mx.gluon.Trainer(q_net.collect_params(), 'adam', {'learning_rate': 0.0002, 'beta1': 0.5})

         g_input = self._data_loader._input_names_
         q_input = [name[:-1] for name in q_net.getOutputs()]
@@ -209,8 +209,8 @@ class ${tc.fileNameWithoutEnding}:
                 os.path.isdir(self._net_creator_dis._model_dir_)):
             raise

-        gen_trainer = mx.gluon.Trainer(gen_net.collect_params(), optimizer, optimizer_params)
-        dis_trainer = mx.gluon.Trainer(dis_net.collect_params(), optimizer, optimizer_params)
+        gen_trainer = mx.gluon.Trainer(gen_net.collect_params(), 'adam', {'learning_rate': 0.0002, 'beta1': 0.5})
+        dis_trainer = mx.gluon.Trainer(dis_net.collect_params(), 'adam', {'learning_rate': 0.0002, 'beta1': 0.5})

         if loss == 'sigmoid_binary_cross_entropy':
             loss_function = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
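All three trainers now hardcode Adam with learning rate 0.0002 and beta1 = 0.5, the settings recommended in the DCGAN paper and widely used for GAN training. If the configurations are meant to stay in sync, a shared dict is one way to express that (an illustrative sketch, not part of the template):

    import mxnet as mx

    # DCGAN-style Adam settings shared by generator, discriminator and Q-net
    gan_opt = {'learning_rate': 0.0002, 'beta1': 0.5}
    # gen_trainer = mx.gluon.Trainer(gen_net.collect_params(), 'adam', gan_opt)
    # dis_trainer = mx.gluon.Trainer(dis_net.collect_params(), 'adam', gan_opt)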
@@ -262,11 +262,11 @@ class ${tc.fileNameWithoutEnding}:
                     discriminated_fake_gen, features = discriminated_fake_gen
                     loss_resultG = loss_function(discriminated_fake_gen, real_labels)

                     if self.use_qnet:
-                        qnet_discriminated = list(q_net(features))
+                        qnet_discriminated = [q_net(features)]
                         for i, qnet_out in enumerate(qnet_discriminated):
-                            loss_resultG = loss_resultG + qnet_losses[i](qnet_out, exp_qnet_output[i])
-                    loss_resultG.backward()
+                            loss_resultE = loss_resultG + qnet_losses[i](qnet_out, exp_qnet_output[i])
+                    loss_resultE.backward()

                     gen_trainer.step(batch_size)
                     if self.use_qnet:
                         qnet_trainer.step(batch_size)
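This hunk accumulates the adversarial term and the InfoGAN code-reconstruction term into loss_resultE before a single backward pass that updates both the generator and the Q-network. A conceptual sketch of that update under MXNet Gluon; all names here (gen_net, dis_net, q_net, code, real_labels) are illustrative placeholders, not taken verbatim from the template:

    import mxnet as mx
    from mxnet import autograd

    def generator_step(gen_net, dis_net, q_net, gen_trainer, qnet_trainer,
                       loss_fn, q_loss_fn, noise, code, real_labels, batch_size):
        with autograd.record():
            fake = gen_net(noise, code)              # generate from noise + latent code
            validity, features = dis_net(fake)       # discriminator also exposes features
            loss_g = loss_fn(validity, real_labels)  # adversarial term: fool the discriminator
            loss_e = loss_g + q_loss_fn(q_net(features), code)  # add code-reconstruction term
        loss_e.backward()                            # one backward pass covers both networks
        gen_trainer.step(batch_size)
        qnet_trainer.step(batch_size)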
@@ -299,14 +299,18 @@ class ${tc.fileNameWithoutEnding}:
                     # ugly start
                     #if batch_i % 200 == 0:
                     #    fake_data[0].asnumpy()
-                    if batch_i % 750 == 0:
+                    if batch_i % 900 == 0:
                         #gen_net.save_parameters(self.parameter_path_gen() + '-' + str(num_epoch + begin_epoch).zfill(4) + '.params')
                         #gen_net.export(self.parameter_path_gen() + '_newest', epoch=0)
                         #dis_net.save_parameters(self.parameter_path_dis() + '-' + str(num_epoch + begin_epoch).zfill(4) + '.params')
                         #dis_net.export(self.parameter_path_dis() + '_newest', epoch=0)
-                        for i in range(10):
-                            plt.subplot(1,10,i+1)
-                            fake_img = fake_data[i]
+                        noise = mx.nd.random_normal(0, 1, shape=(10, 62), ctx=mx_context)
+                        label = nd.array(np.random.randint(10, size=10)).as_in_context(mx_context)
+                        c1 = nd.one_hot(nd.ones(shape=(10), ctx=mx_context), depth=10).as_in_context(mx_context)
+                        images = gen_net(noise, c1)
+                        for j in range(10):
+                            plt.subplot(1, 10, j+1)
+                            fake_img = images[j]
                             visualize(fake_img)
                         plt.show()
                     # ugly end
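The preview now draws fresh samples instead of reusing fake_data from the training step. Note that c1 one-hot encodes a constant vector of ones, so all ten preview images share the same categorical code (class 1), and the sampled label is computed but unused. If the intent were one image per code, a variant like the following would vary the hot index per row (an assumed alternative, not the committed code):

    import mxnet as mx
    from mxnet import nd

    mx_context = mx.cpu()  # assumption: CPU context for this sketch
    noise = nd.random_normal(0, 1, shape=(10, 62), ctx=mx_context)
    c1 = nd.one_hot(nd.arange(10, ctx=mx_context), depth=10)  # row i is hot at class i
    # images = gen_net(noise, c1)  # gen_net as constructed by the template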
@@ -23,9 +23,9 @@
                                              shape=(batch_size,)+domain[3],
                                              dtype=domain[0], ctx=mx_context,), dtype="float32")
         elif domain[0] == int:
-            generators[name] = lambda domain=domain, min=min, max=max: mx.nd.cast(mx.ndarray.random.randint(low=int(min),
-                                             high=int(max)+1, shape=(batch_size,)+domain[3],
-                                             ctx=mx_context), dtype="float32")
+            generators[name] = lambda domain=domain, min=min, max=max: mx.ndarray.one_hot(mx.ndarray.random.randint(low=0,
+                                             high=int(max-min)+1, shape=(batch_size,), dtype=int,
+                                             ctx=mx_context), depth=int(max-min)+1, on_value=1).reshape((batch_size,)+domain[3])

         if name[-1] in constraint_losses:
             loss_dict = constraint_losses[name[:-1]]
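The integer input generator now emits one-hot vectors instead of casted integers, which is the format an InfoGAN categorical code (and the dense-label Q-net loss below) expects. A minimal standalone sketch of the same encoding, with illustrative bounds:

    import mxnet as mx

    batch_size, min_val, max_val = 64, 0, 9   # e.g. a 10-way categorical code
    depth = int(max_val - min_val) + 1

    # sample class indices uniformly, then one-hot encode to (batch_size, depth)
    idx = mx.nd.random.randint(low=0, high=depth, shape=(batch_size,), dtype='int32')
    codes = mx.nd.one_hot(idx, depth=depth)   # float32 rows, exactly one 1 per row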
@@ -55,9 +55,7 @@
         if domain[0] == float:
             qnet_losses += [mx.gluon.loss.L2Loss()]
         elif domain[0] == int:
-            qnet_losses += [lambda pred, labels: mx.gluon.loss.SoftmaxCrossEntropyLoss()(pred, labels.reshape(batch_size))]
+            qnet_losses += [lambda pred, labels: mx.gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=False)(pred, labels)]

     for name in gen_inputs:
         if not name in qnet_outputs:
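Since the targets are now one-hot rows, SoftmaxCrossEntropyLoss must be told the labels are dense rather than sparse class indices, and the old reshape becomes unnecessary. A short comparison of the two label formats, assuming a 10-class code:

    import mxnet as mx

    pred = mx.nd.random.normal(shape=(4, 10))               # unnormalized Q-net logits
    onehot = mx.nd.one_hot(mx.nd.array([1, 3, 5, 7]), 10)   # dense targets, shape (4, 10)

    sparse_loss = mx.gluon.loss.SoftmaxCrossEntropyLoss()   # expects class indices
    dense_loss = mx.gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=False)

    # identical loss values, different label formats
    print(sparse_loss(pred, mx.nd.array([1, 3, 5, 7])))
    print(dense_loss(pred, onehot))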
@@ -77,6 +75,7 @@
         if not name in qnet_outputs:
             input_to_gen += [generators[name]()]
     for name in qnet_outputs:
-        expected_output_qnet += [generators[name]()]
-        input_to_gen += [generators[name]()]
+        value = generators[name]()
+        expected_output_qnet += [value]
+        input_to_gen += [value]
     return input_to_gen, expected_output_qnet
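This appears to be the core of the fix the commit message refers to: previously generators[name]() was called once for the generator input and again for the expected Q-net output, drawing two independent random samples, so the Q-network was trained against codes the generator never saw. Caching the sample in value keeps both lists consistent. A tiny illustration of the failure mode, with an illustrative generator:

    import mxnet as mx

    gen = lambda: mx.nd.random.randint(0, 10, shape=(4,), dtype='int32')

    # buggy: two calls, two different samples, so target does not match input
    a, b = gen(), gen()

    # fixed: sample once, reuse for both the generator input and the Q-net target
    value = gen()
    input_to_gen, expected_output_qnet = value, value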