diff --git a/keras.py b/keras.py index fced3c9fd7ce84e3ea75aca46c37efde65bfd913..e4b333cd1d28e2d504c0807390335ab8a6ffb503 100644 --- a/keras.py +++ b/keras.py @@ -164,13 +164,17 @@ class KFeed(object): ) -def Normal(ref, indices=None, name=None, **kwargs): +def Normal(ref, indices=None, ignore_zeros=False, name=None, **kwargs): """ Normalizing layer according to ref. If given, only the variables corresponding to indices will be normalized. """ - mean = ref.mean(**kwargs) - std = ref.std(**kwargs) + if ignore_zeros: + mean = np.nanmean(np.where(ref==0, np.ones_like(ref)*np.nan, ref), **kwargs) + std = np.nanstd(np.where(ref==0, np.ones_like(ref)*np.nan, ref), **kwargs) + else: + mean = ref.mean(**kwargs) + std = ref.std(**kwargs) if indices: indices = np.array(indices) replace = np.isin(np.arange(len(mean)), indices, invert=True)