diff --git a/keras.py b/keras.py index 5500168a33cf720342b6f66bbf9e20eb2ce23ea8..d328c2614023ae2ba4a394fe849045a4cf89a45d 100644 --- a/keras.py +++ b/keras.py @@ -136,6 +136,21 @@ class KFeed(object): assert "auto_steps" not in kwargs return self.gensteps(*args, **kwargs)[0] + def getShapes(self, src): + return tuple( + tuple(src[k].shape for k in g) + for g in self.xyw + ) + + def mkInputs(self, src, **kwargs): + return tuple( + tf.keras.Input(shape=ref.shape[1:], dtype=ref.dtype, **kwargs) + for ref in ( + src[x] + for x in self.xyw[0] + ) + ) + def Normal(ref, indices=None, name=None, **kwargs): """ @@ -168,8 +183,8 @@ def Onehot(index, n, name=None): x[...,(index+1):], ), axis=-1) return tf.keras.layers.Lambda(to_onehot, name=name) - - + + class Moment(tf.keras.metrics.Mean): def __init__(self, order, label=False, **kwargs):