Commit 2dd34d0b authored by David Josef Schmidt

fix the GP loss term

parent 359dd7c4
Branch: fix-GP
@@ -44,10 +44,13 @@ def wasserstein_loss(y_true, y_pred):
 def gradient_penalty_loss(y_true, y_pred, averaged_batch, penalty_weight):
     """Calculates the gradient penalty (for details see arXiv:1704.00028v3).
     The 1-Lipschitz constraint for improved WGANs is enforced by adding a term to the loss which penalizes if the gradient norm in the critic is unequal to 1"""
-    gradients = K.gradients(K.sum(y_pred), averaged_batch)
-    gradient_l2_norm = K.sqrt(K.sum(K.square(gradients)))
+    gradients = K.gradients(y_pred, averaged_batch)
+    gradients_sqr = K.square(gradients)[0]
+    gradients_sqr_sum = K.sum(
+        K.sum(K.sum(gradients_sqr, axis=3), axis=2), axis=1)
+    gradient_l2_norm = K.sqrt(gradients_sqr_sum)
     gradient_penalty = penalty_weight * K.square(1 - gradient_l2_norm)
-    return gradient_penalty
+    return K.mean(gradient_penalty)


 class RandomWeightedAverage(_Merge):
@@ -96,7 +99,7 @@ def plot_loss(loss, log_dir=".", name=""):
     plt.legend(loc='upper right', prop={'size': 10})
     ax1.set_xlabel(r'Iterations')
     ax1.set_ylabel(r'Loss')
-    ax1.set_ylim(-1.5, 1.5)
+    ax1.set_ylim(np.min(loss), 2)
     fig.savefig(log_dir + '/%s_Loss.png' % name, dpi=120)
     plt.close('all')
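Why the fix matters: the old code reduced the squared gradients over the entire batch before taking the square root, so it penalized a single batch-wide norm for deviating from 1. The paper instead constrains the gradient norm of each interpolated sample, which is what the new axis-wise sums (over height, width and channels, leaving the batch axis) and the final K.mean implement. Below is a minimal NumPy sketch of the two computations; the toy shapes, the stand-in gradient tensor and the weight of 10 (the lambda suggested in the cited paper) are illustrative assumptions, not values read from this repository.

import numpy as np

rng = np.random.default_rng(0)
# Stand-in for d(critic output)/d(input) on a batch of 4 toy images (NHWC).
gradients = rng.normal(size=(4, 8, 8, 1))
penalty_weight = 10.0

# Old (buggy) version: one L2 norm over the whole batch, so the target
# "norm == 1" is applied to a quantity that grows with batch size.
old_norm = np.sqrt(np.sum(np.square(gradients)))
old_penalty = penalty_weight * np.square(1 - old_norm)

# Fixed version: one L2 norm per sample (sum over H, W, C only, i.e.
# axes 1-3), penalize each sample's deviation from 1, then average over
# the batch, mirroring the K.sum(..., axis=3/2/1) chain and K.mean.
per_sample_norm = np.sqrt(np.sum(np.square(gradients), axis=(1, 2, 3)))
new_penalty = np.mean(penalty_weight * np.square(1 - per_sample_norm))

print("batch-wide norm :", old_norm)
print("per-sample norms:", per_sample_norm)
print("old penalty     :", old_penalty, " new penalty:", new_penalty)

Two side notes on the new Keras lines: K.gradients returns a list of tensors, which the [0] indexing unwraps; and with the TensorFlow backend K.gradients already sums a non-scalar y_pred, so dropping the explicit K.sum leaves the gradients unchanged. The substantive fix is the per-sample norm plus the final K.mean.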
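For context, the averaged_batch that the penalty differentiates against comes from the RandomWeightedAverage layer whose declaration appears above; it presumably performs the random interpolation from the cited paper, which evaluates the penalty at x_hat = eps * x_real + (1 - eps) * x_fake with eps ~ U[0, 1]. A hedged NumPy sketch of that interpolation with toy shapes (this is not the repository's _Merge-based implementation):

import numpy as np

rng = np.random.default_rng(1)
real_batch = rng.normal(size=(4, 8, 8, 1))  # toy "real" images, NHWC
fake_batch = rng.normal(size=(4, 8, 8, 1))  # toy generator output

# One epsilon per sample, broadcast over height, width and channels,
# following Algorithm 1 of arXiv:1704.00028v3.
eps = rng.uniform(size=(4, 1, 1, 1))
averaged_batch = eps * real_batch + (1 - eps) * fake_batch
print(averaged_batch.shape)  # (4, 8, 8, 1)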