Commit a77fadff authored by Niklas Uwe Langner

[edgeconv], [test] change code so that the h-function takes two tensors

parent 70c2f970
......@@ -4,7 +4,7 @@ General implementation of the EdgeConv-Block as described in [ParticleNet: Jet T
## Using the layer
In order to use the layer class found in [edgeconv.py](edgeconv.py), a h-function needs to be defined. This function should take a Keras tensor of length 2*C with C being the dimension of the features. The h-function might include other Keras layers. Setting the number of k nearest neighbors to be considered using the `next_neighbors` argument, the EdgeConv layer can be implemented as demonstrated in [test.py](test.py).
In order to use the layer class found in [edgeconv.py](edgeconv.py), an h-function needs to be defined. This function should take a list of two Keras tensors, each of length C, with C being the dimension of the features. The h-function may include other Keras layers. The number k of nearest neighbors to consider is set via the `next_neighbors` argument; the EdgeConv layer can then be used as demonstrated in [test.py](test.py).
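For illustration, a minimal sketch of such an h-function, modeled on the one in [test.py](test.py) (the `h_func` name and the Dense layer sizes are placeholders):

```python
import tensorflow.keras.layers as lay

def h_func(data):
    v_i, v_j = data                        # two Keras tensors of shape (C,)
    diff = lay.Subtract()([v_i, v_j])      # edge feature v_i - v_j
    x = lay.Concatenate(axis=-1)([v_i, diff])
    x = lay.Dense(30)(x)
    return lay.Dense(20)(x)                # output tensor of shape (C',)
```

The function is then passed as the first argument of the layer, e.g. `EdgeConv(h_func, next_neighbors=3)`.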
## Acknowledgement
This implementation borrows code from the ParticleNet tensorflow [implementation](https://github.com/hqucms/ParticleNet).
......@@ -4,11 +4,63 @@ import tensorflow.keras.layers as lay
from tensorflow import keras
class SplitLayer(lay.Layer):
""" Custom layer: split layer along specific axis.
eg. split (1,9) into 9 x (1,1)
Parameters
----------
n_splits : int
number of splits
split_axis : int
axis where to split tensor
**kwargs : dict
keyword arguments passed on to the Keras Layer base class
Attributes
----------
n_splits
split_axis
"""
def __init__(self, n_splits=12, split_axis=-1, **kwargs):
self.n_splits = n_splits
self.split_axis = split_axis
super(SplitLayer, self).__init__(**kwargs)
def get_config(self):
config = {'n_splits': self.n_splits,
'split_axis': self.split_axis}
base_config = super(SplitLayer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, x):
''' return list of split sub-tensors '''
sub_tensors = tf.split(x, self.n_splits, axis=self.split_axis)
return sub_tensors
def compute_output_shape(self, input_shape):
sub_tensor_shape = list(input_shape)
num_channels = sub_tensor_shape[-1]
sub_tensor_shape[-1] = int(num_channels / self.n_splits)
sub_tensor_shape = tuple(sub_tensor_shape)
list_of_output_shape = [sub_tensor_shape] * self.n_splits
return list_of_output_shape
def compute_mask(self, inputs, mask=None):
return self.n_splits * [None]
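# Usage sketch (illustration only, not part of the layer code): splitting a
# (batch, 6) tensor into three (batch, 2) tensors along the last axis:
#   inp = lay.Input((6,))
#   parts = SplitLayer(n_splits=3, split_axis=-1)(inp)  # list of 3 tensors, each (batch, 2)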
class EdgeConv(lay.Layer):
'''
Keras layer implementation of EdgeConv.
# Arguments
h_func: h-function applied on the points and it's k nearest neighbors.
h_func: h-function applied to each point and its k nearest neighbors. The function should take a list
of two tensors. The first tensor is the vector v_i of the central point, the second tensor is the vector
of one of its neighbors v_j.
:param list: [v_i, v_j] with v_i and v_j being Keras tensors with shape (C_f, ).
:return: Keras tensor of shape (C', ).
next_neighbors: number k of nearest neighbors to consider
agg_func: Aggregation function applied after h. Must take argument "axis=2" to
aggregate over all neighbors.
......@@ -36,8 +88,15 @@ class EdgeConv(lay.Layer):
p_shape, f_shape = input_shape
except ValueError:
f_shape = input_shape
x = lay.Input((f_shape.as_list()[-1] * 2,))
self.h_func = keras.models.Model(x, self.h_func(x))
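# Wrap the user-supplied h-function in a Model that takes a single flat (2*C,)
# tensor, splits it into the central point v_i (first C entries) and the
# neighbor v_j (last C entries), and feeds [v_i, v_j] to h_func.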
x = lay.Input((f_shape.as_list()[-1] * 2,))
x1, x2 = SplitLayer(n_splits=2, split_axis=-1)(x)  # 2 x (C,): v_i and v_j
y = self.h_func([x1, x2])
self.h_func = keras.models.Model(x, y)
super(EdgeConv, self).build(input_shape) # Be sure to call this at the end
def call(self, x):
......@@ -53,8 +112,9 @@ class EdgeConv(lay.Layer):
fts = features
knn_fts = knn(indices, fts) # (N, P, K, C)
knn_fts_center = tf.tile(tf.expand_dims(fts, axis=2), (1, 1, self.K, 1)) # (N, P, K, C)
knn_fts = tf.concat([knn_fts_center, tf.subtract(knn_fts, knn_fts_center)], axis=-1) # (N, P, K, 2*C)
res = lay.TimeDistributed(lay.TimeDistributed(self.h_func))(knn_fts) # (N,P,K,C')
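# Concatenate the central-point features with the raw neighbor features; forming
# the edge feature (e.g. v_i - v_j) is now left to the h-function (see test.py).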
knn_fts = tf.concat([knn_fts_center, knn_fts], axis=-1) # (N, P, K, 2*C)
res = lay.TimeDistributed(lay.TimeDistributed(self.h_func))(knn_fts) # (N, P, K, C')
# aggregation
agg = self.agg_func(res, axis=2) # (N, P, C')
return agg
......@@ -78,7 +138,6 @@ def knn(topk_indices, features):
# features: (N, P, C)
with tf.name_scope('knn'):
k = tf.shape(topk_indices)[-1]
print(k)
num_points = tf.shape(features)[-2]
queries_shape = tf.shape(features)
batch_size = queries_shape[0]
......
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers as lay
from edgeconv import EdgeConv
import numpy as np
def f(x):
def f(data):
d1, d2 = data
dif = lay.Subtract()([d1, d2])
x = lay.Concatenate(axis=-1)([d1, dif])
x = lay.Dense(30)(x)
x = lay.Dense(20)(x)
return x
points = lay.Input((10, 5))
feats = lay.Input((10, 5))
points = lay.Input((10, 6))
feats = lay.Input((10, 6))
a = EdgeConv(f, next_neighbors=3)([points, feats])
y = EdgeConv(f, next_neighbors=3)(a)
out = EdgeConv(f, next_neighbors=3)(y)
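# Note: after the first EdgeConv, each layer receives a single tensor, which it
# uses both as point coordinates and as features (see EdgeConv.build).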
......@@ -21,4 +25,4 @@ model.summary()
model.compile(loss="mse", optimizer=keras.optimizers.Adam())
model.fit([np.ones((16, 10, 5)), np.ones((16, 10, 5))], np.ones((16, 10, 20)), epochs=10)
model.fit([np.ones((300, 10, 6)), np.ones((300, 10, 6))], np.ones((300, 10, 20)), epochs=10)