Commit 2addd001 authored by David Josef Schmidt

fix the GP loss term

parent 18567f5b
@@ -44,10 +44,13 @@ def wasserstein_loss(y_true, y_pred):
 def gradient_penalty_loss(y_true, y_pred, averaged_batch, penalty_weight):
     """Calculates the gradient penalty (for details see arXiv:1704.00028v3).
     The 1-Lipschitz constraint for improved WGANs is enforced by adding a term to the loss which penalizes the critic if its gradient norm is unequal to 1."""
-    gradients = K.gradients(K.sum(y_pred), averaged_batch)
-    gradient_l2_norm = K.sqrt(K.sum(K.square(gradients)))
+    gradients = K.gradients(y_pred, averaged_batch)
+    gradients_sqr = K.square(gradients)[0]
+    gradients_sqr_sum = K.sum(
+        K.sum(K.sum(gradients_sqr, axis=3), axis=2), axis=1)
+    gradient_l2_norm = K.sqrt(gradients_sqr_sum)
     gradient_penalty = penalty_weight * K.square(1 - gradient_l2_norm)
-    return gradient_penalty
+    return K.mean(gradient_penalty)
 
 class RandomWeightedAverage(_Merge):
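
For context: the removed lines collapsed the squared gradients of the whole batch into a single L2 norm, so the penalty pushed the combined norm of all samples towards 1. The added lines compute one norm per sample, summing over axes 1-3 (which assumes 4-D image tensors), and then average the per-sample penalties over the batch, as prescribed in arXiv:1704.00028v3. A minimal NumPy sketch of the two reductions, with random numbers as a hypothetical stand-in for the critic gradients:

import numpy as np

# Hypothetical stand-in for the critic gradients w.r.t. the averaged batch,
# assuming 4-D image tensors of shape (batch, height, width, channels).
grads = np.random.randn(8, 32, 32, 1)
penalty_weight = 10.0

# Old reduction: one scalar norm over the entire batch, so all samples
# share a single norm that is pushed towards 1.
old_norm = np.sqrt(np.sum(np.square(grads)))
old_penalty = penalty_weight * np.square(1 - old_norm)

# New reduction: an L2 norm per sample (sum over axes 1-3), then the
# mean penalty over the batch, as in the improved-WGAN objective.
sample_norms = np.sqrt(np.sum(np.square(grads), axis=(1, 2, 3)))
new_penalty = np.mean(penalty_weight * np.square(1 - sample_norms))

print(old_penalty, new_penalty)
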
@@ -96,7 +99,7 @@ def plot_loss(loss, log_dir=".", name=""):
     plt.legend(loc='upper right', prop={'size': 10})
     ax1.set_xlabel(r'Iterations')
     ax1.set_ylabel(r'Loss')
-    ax1.set_ylim(-1.5, 1.5)
+    ax1.set_ylim(np.min(loss), 2)
     fig.savefig(log_dir + '/%s_Loss.png' % name, dpi=120)
     plt.close('all')
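
The hard-coded y-range clipped any loss values outside [-1.5, 1.5], which WGAN critic losses can easily leave; the new lower bound follows the data. A short sketch of the adaptive limit in isolation, using a hypothetical loss array in place of the recorded training losses:

import numpy as np
import matplotlib.pyplot as plt

loss = 0.01 * np.cumsum(np.random.randn(1000))  # hypothetical loss curve
fig, ax1 = plt.subplots(1)
ax1.plot(loss, label='critic loss')
ax1.legend(loc='upper right', prop={'size': 10})
ax1.set_xlabel(r'Iterations')
ax1.set_ylabel(r'Loss')
ax1.set_ylim(np.min(loss), 2)  # lower limit tracks the data, upper limit stays fixed
fig.savefig('Loss.png', dpi=120)
plt.close('all')
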