From 98fbbc88779d2fee1b0168022fd2b5b815d79c20 Mon Sep 17 00:00:00 2001
From: "jan.middendorf@rwth-aachen.de" <jan.middendorf@rwth-aachen.de>
Date: Tue, 16 Jun 2020 19:50:24 +0200
Subject: [PATCH] Model working, testing architectures

---
 keras.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/keras.py b/keras.py
index 69fc12f..5500168 100644
--- a/keras.py
+++ b/keras.py
@@ -160,9 +160,11 @@ def Onehot(index, n, name=None):
     n is the number of different variables.
     """
     def to_onehot(x):
+        #Concat zeros to eye for indices in x equal to n (larger than those encoded by one)
+        eye = tf.concat((tf.eye(n), tf.zeros((1, n))), axis=0)
         return tf.concat((
                           x[...,:index],
-                          tf.gather(tf.eye(n), tf.cast(x[...,index], tf.int64)),
+                          tf.gather(eye, tf.cast(x[...,index], tf.int64)),
                           x[...,(index+1):],
                          ), axis=-1)
     return tf.keras.layers.Lambda(to_onehot, name=name)
-- 
GitLab