diff --git a/AirShower_WGAN/utils.py b/AirShower_WGAN/utils.py index 125dd31d9594a9eebf5acfd941386f26a5493a62..cefe9b017c951fe0c407e6264ef409d4c92a38bd 100644 --- a/AirShower_WGAN/utils.py +++ b/AirShower_WGAN/utils.py @@ -44,10 +44,13 @@ def wasserstein_loss(y_true, y_pred): def gradient_penalty_loss(y_true, y_pred, averaged_batch, penalty_weight): """Calculates the gradient penalty (for details see arXiv:1704.00028v3). The 1-Lipschitz constraint for improved WGANs is enforced by adding a term to the loss which penalizes if the gradient norm in the critic unequal to 1""" - gradients = K.gradients(K.sum(y_pred), averaged_batch) - gradient_l2_norm = K.sqrt(K.sum(K.square(gradients))) + gradients = K.gradients(y_pred, averaged_batch) + gradients_sqr = K.square(gradients)[0] + gradients_sqr_sum = K.sum( + K.sum(K.sum(gradients_sqr, axis=3), axis=2), axis=1) + gradient_l2_norm = K.sqrt(gradients_sqr_sum) gradient_penalty = penalty_weight * K.square(1 - gradient_l2_norm) - return gradient_penalty + return K.mean(gradient_penalty) class RandomWeightedAverage(_Merge): @@ -96,7 +99,7 @@ def plot_loss(loss, log_dir=".", name=""): plt.legend(loc='upper right', prop={'size': 10}) ax1.set_xlabel(r'Iterations') ax1.set_ylabel(r'Loss') - ax1.set_ylim(-1.5, 1.5) + ax1.set_ylim(np.min(loss), 2) fig.savefig(log_dir + '/%s_Loss.png' % name, dpi=120) plt.close('all')