Skip to content
Snippets Groups Projects
Commit 98fbbc88 authored by jan.middendorf@rwth-aachen.de's avatar jan.middendorf@rwth-aachen.de
Browse files

Model working, testing architectures

parent 1a205b2d
No related branches found
No related tags found
No related merge requests found
......@@ -160,9 +160,11 @@ def Onehot(index, n, name=None):
n is the number of different variables.
"""
def to_onehot(x):
#Concat zeros to eye for indices in x equal to n (larger than those encoded by one)
eye = tf.concat((tf.eye(n), tf.zeros((1, n))), axis=0)
return tf.concat((
x[...,:index],
tf.gather(tf.eye(n), tf.cast(x[...,index], tf.int64)),
tf.gather(eye, tf.cast(x[...,index], tf.int64)),
x[...,(index+1):],
), axis=-1)
return tf.keras.layers.Lambda(to_onehot, name=name)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment