diff --git a/keras.py b/keras.py index 69fc12f524ff78303d9fbec358c4c06289fbc035..5500168a33cf720342b6f66bbf9e20eb2ce23ea8 100644 --- a/keras.py +++ b/keras.py @@ -160,9 +160,11 @@ def Onehot(index, n, name=None): n is the number of different variables. """ def to_onehot(x): + #Concat zeros to eye for indices in x equal to n (larger than those encoded by one) + eye = tf.concat((tf.eye(n), tf.zeros((1, n))), axis=0) return tf.concat(( x[...,:index], - tf.gather(tf.eye(n), tf.cast(x[...,index], tf.int64)), + tf.gather(eye, tf.cast(x[...,index], tf.int64)), x[...,(index+1):], ), axis=-1) return tf.keras.layers.Lambda(to_onehot, name=name)