From 1a205b2d00be6484b713385efc1463b459af15cb Mon Sep 17 00:00:00 2001 From: "jan.middendorf@rwth-aachen.de" <jan.middendorf@rwth-aachen.de> Date: Wed, 10 Jun 2020 14:38:54 +0200 Subject: [PATCH] Seed test --- keras.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/keras.py b/keras.py index 098f69f..69fc12f 100644 --- a/keras.py +++ b/keras.py @@ -67,7 +67,7 @@ class KFeed(object): @property def all(self): - return set().union(*self.xyz) + return set().union(*self.xyw) def balance(self, src): return src @@ -137,7 +137,7 @@ class KFeed(object): return self.gensteps(*args, **kwargs)[0] -def Normal(ref, indices=None, **kwargs): +def Normal(ref, indices=None, name=None, **kwargs): """ Normalizing layer according to ref. If given, only the variables corresponding to indices will be normalized. @@ -145,14 +145,29 @@ def Normal(ref, indices=None, **kwargs): mean = ref.mean(**kwargs) std = ref.std(**kwargs) if indices: + indices = np.array(indices) replace = np.isin(np.arange(len(mean)), indices, invert=True) mean[replace] = 0 std[replace] = 1 mul = 1.0 / std add = -mean / std - return tf.keras.layers.Lambda(lambda x: (x * mul) + add) + return tf.keras.layers.Lambda((lambda x: (x * mul) + add), name=name) +def Onehot(index, n, name=None): + """ + One hot encodes a variable referred to by index. + n is the number of different variables. + """ + def to_onehot(x): + return tf.concat(( + x[...,:index], + tf.gather(tf.eye(n), tf.cast(x[...,index], tf.int64)), + x[...,(index+1):], + ), axis=-1) + return tf.keras.layers.Lambda(to_onehot, name=name) + + class Moment(tf.keras.metrics.Mean): def __init__(self, order, label=False, **kwargs): -- GitLab