From 1a205b2d00be6484b713385efc1463b459af15cb Mon Sep 17 00:00:00 2001
From: "jan.middendorf@rwth-aachen.de" <jan.middendorf@rwth-aachen.de>
Date: Wed, 10 Jun 2020 14:38:54 +0200
Subject: [PATCH] Seed test

---
 keras.py | 21 ++++++++++++++++++---
 1 file changed, 18 insertions(+), 3 deletions(-)

diff --git a/keras.py b/keras.py
index 098f69f..69fc12f 100644
--- a/keras.py
+++ b/keras.py
@@ -67,7 +67,7 @@ class KFeed(object):
 
     @property
     def all(self):
-        return set().union(*self.xyz)
+        return set().union(*self.xyw)
 
     def balance(self, src):
         return src
@@ -137,7 +137,7 @@ class KFeed(object):
         return self.gensteps(*args, **kwargs)[0]
 
 
-def Normal(ref, indices=None, **kwargs):
+def Normal(ref, indices=None, name=None, **kwargs):
     """
     Normalizing layer according to ref.
     If given, only the variables corresponding to indices will be normalized.
@@ -145,14 +145,29 @@ def Normal(ref, indices=None, **kwargs):
     mean = ref.mean(**kwargs)
     std = ref.std(**kwargs)
     if indices:
+        indices = np.array(indices)
         replace = np.isin(np.arange(len(mean)), indices, invert=True)
         mean[replace] = 0
         std[replace] = 1
     mul = 1.0 / std
     add = -mean / std
-    return tf.keras.layers.Lambda(lambda x: (x * mul) + add)
+    return tf.keras.layers.Lambda((lambda x: (x * mul) + add), name=name)
 
 
+def Onehot(index, n, name=None):
+    """
+    One hot encodes a variable referred to by index.
+    n is the number of different variables.
+    """
+    def to_onehot(x):
+        return tf.concat((
+                          x[...,:index],
+                          tf.gather(tf.eye(n), tf.cast(x[...,index], tf.int64)),
+                          x[...,(index+1):],
+                         ), axis=-1)
+    return tf.keras.layers.Lambda(to_onehot, name=name)
+    
+    
 class Moment(tf.keras.metrics.Mean):
 
     def __init__(self, order, label=False, **kwargs):
-- 
GitLab