diff --git a/keras.py b/keras.py index c211feb25cb1ef26fe27280de5bd30e089d3cce4..a87629ad7e34d049d61d211135455538bcce3d4e 100644 --- a/keras.py +++ b/keras.py @@ -130,10 +130,17 @@ class KFeed(object): return self.gensteps(*args, **kwargs)[0] -def Normal(ref, **kwargs): - """ Normalizing layer according to ref """ +def Normal(ref, indices=None, **kwargs): + """ + Normalizing layer according to ref. + If given, only the variables corresponding to indices will be normalized. + """ mean = ref.mean(**kwargs) std = ref.std(**kwargs) + if indices: + replace = np.isin(np.arange(len(mean)), indices, invert=True) + mean[replace] = 0 + std[replace] = 1 mul = 1.0 / std add = -mean / std return tf.keras.layers.Lambda(lambda x: (x * mul) + add)