From 98fbbc88779d2fee1b0168022fd2b5b815d79c20 Mon Sep 17 00:00:00 2001 From: "jan.middendorf@rwth-aachen.de" <jan.middendorf@rwth-aachen.de> Date: Tue, 16 Jun 2020 19:50:24 +0200 Subject: [PATCH] Model working, testing architectures --- keras.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/keras.py b/keras.py index 69fc12f..5500168 100644 --- a/keras.py +++ b/keras.py @@ -160,9 +160,11 @@ def Onehot(index, n, name=None): n is the number of different variables. """ def to_onehot(x): + #Concat zeros to eye for indices in x equal to n (larger than those encoded by one) + eye = tf.concat((tf.eye(n), tf.zeros((1, n))), axis=0) return tf.concat(( x[...,:index], - tf.gather(tf.eye(n), tf.cast(x[...,index], tf.int64)), + tf.gather(eye, tf.cast(x[...,index], tf.int64)), x[...,(index+1):], ), axis=-1) return tf.keras.layers.Lambda(to_onehot, name=name) -- GitLab