Commit 856125ca authored by David Josef Schmidt

test

parent 18567f5b
@@ -44,10 +44,18 @@ def wasserstein_loss(y_true, y_pred):
 def gradient_penalty_loss(y_true, y_pred, averaged_batch, penalty_weight):
     """Calculates the gradient penalty (for details see arXiv:1704.00028v3).
     The 1-Lipschitz constraint for improved WGANs is enforced by adding a term to the loss which penalizes if the gradient norm in the critic is unequal to 1"""
-    gradients = K.gradients(K.sum(y_pred), averaged_batch)
-    gradient_l2_norm = K.sqrt(K.sum(K.square(gradients)))
+    gradients = K.gradients(y_pred, averaged_batch)
+    # compute the euclidean norm by squaring ...
+    gradients_sqr = K.square(gradients)[0]
+    # ... summing over the rows ...
+    gradients_sqr_sum = K.sum(gradients_sqr,
+                              axis=range(1, len(gradients_sqr.shape)))
+    # ... and sqrt
+    gradient_l2_norm = K.sqrt(gradients_sqr_sum)
     # compute lambda * (1 - ||grad||)^2
     gradient_penalty = penalty_weight * K.square(1 - gradient_l2_norm)
-    return gradient_penalty
+    # return the mean as loss
+    return K.mean(gradient_penalty)

 class RandomWeightedAverage(_Merge):
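
For context (not part of the commit): the new code computes a per-example gradient norm by summing the squared gradient components over all non-batch axes, penalizes its squared deviation from 1, and returns the batch mean, whereas the old version collapsed the whole batch into a single norm. Below is a minimal sketch of how such a loss is typically attached to the critic in the improved-WGAN Keras recipe, with the interpolated batch bound via functools.partial. The input shapes, the generator and critic models, the penalty weight of 10 and the optimizer are illustrative assumptions; RandomWeightedAverage and wasserstein_loss refer to the definitions in this file.

from functools import partial

from keras.layers import Input
from keras.models import Model

# Hypothetical inputs and sub-models for illustration only; the real shapes,
# generator and critic come from the rest of the training script.
real_samples = Input(shape=(64, 64, 1))
noise = Input(shape=(100,))
generated_samples = generator(noise)                 # assumed generator model
averaged_batch = RandomWeightedAverage()([real_samples, generated_samples])

critic_real = critic(real_samples)                   # assumed critic model
critic_fake = critic(generated_samples)
critic_avg = critic(averaged_batch)

# Bind the extra arguments so Keras sees the usual (y_true, y_pred) signature.
penalty_loss = partial(gradient_penalty_loss,
                       averaged_batch=averaged_batch,
                       penalty_weight=10)
penalty_loss.__name__ = 'gradient_penalty'           # Keras losses need a name

critic_model = Model(inputs=[real_samples, noise],
                     outputs=[critic_real, critic_fake, critic_avg])
critic_model.compile(optimizer='adam',
                     loss=[wasserstein_loss, wasserstein_loss, penalty_loss])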