Commit 00648d07 authored by Niklas Uwe Langner's avatar Niklas Uwe Langner
Browse files

[edgeconv] fix knn bug and add documentation

parent b0d46e10
......@@ -2,14 +2,32 @@
import tensorflow as tf
import tensorflow.keras.layers as lay
from tensorflow import keras
from tensorflow.keras import backend as K
class EdgeConv(lay.Layer):
'''
Keras layer implementation of EdgeConv.
# Arguments
h_func: h-function applied on the points and it's k nearest neighbors.
next_neighbors: number k of nearest neighbors to consider
agg_func: Aggregation function applied after h. Must take argument "axis=2" to
aggregate over all neighbors.
# Input shape
List of two tensors [points, features] with shape:
`[(batch, P, C_p), (batch, P, C_f)]`.
or tensor with shape:
`(batch, P, C)`
if points (coordinates) and features are supposed to be the same.
# Output shape
Tensor with shape:
`(batch, P, C_h)`
with C_h being the output dimension of the h-function.
'''
def __init__(self, h_func, next_neighbors, **kwargs):
def __init__(self, h_func, next_neighbors, agg_func=tf.reduce_mean, **kwargs):
self.h_func = h_func
self.K = next_neighbors
self.agg_func = agg_func
super(EdgeConv, self).__init__(**kwargs)
def build(self, input_shape):
......@@ -18,8 +36,6 @@ class EdgeConv(lay.Layer):
p_shape, f_shape = input_shape
except ValueError:
f_shape = input_shape
print(f_shape)
print(f_shape.as_list()[-1])
x = lay.Input((f_shape.as_list()[-1] * 2,))
self.h_func = keras.models.Model(x, self.h_func(x))
super(EdgeConv, self).build(input_shape) # Be sure to call this at the end
......@@ -30,24 +46,17 @@ class EdgeConv(lay.Layer):
except TypeError:
points = features = x
# distance
# distance
D = batch_distance_matrix_general(points, points) # (N, P, P)
print(D)
_, indices = tf.nn.top_k(-D, k=self.K + 1) # (N, P, K+1)
indices = indices[:, :, 1:] # (N, P, K)
fts = features
knn_fts = knn(indices, fts) # (N, P, K, C)
print(knn_fts)
print(fts)
knn_fts_center = tf.tile(tf.expand_dims(fts, axis=2), (1, 1, self.K, 1)) # (N, P, K, C)
knn_fts = tf.concat([knn_fts_center, tf.subtract(knn_fts, knn_fts_center)], axis=-1) # (N, P, K, 2*C)
print("knn_fts", knn_fts)
print("h_func", self.h_func.get_output_shape_at(-1))
res = lay.TimeDistributed(lay.TimeDistributed(self.h_func))(knn_fts) # (N,P,K,C')
# aggregation
agg = tf.reduce_mean(res, axis=2) # (N, P, C')
agg = self.agg_func(res, axis=2) # (N, P, C')
return agg
def compute_output_shape(self, input_shape):
......@@ -68,7 +77,8 @@ def knn(topk_indices, features):
# topk_indices: (N, P, K)
# features: (N, P, C)
with tf.name_scope('knn'):
k = tf.shape(features)[-1]
k = tf.shape(topk_indices)[-1]
print(k)
num_points = tf.shape(features)[-2]
queries_shape = tf.shape(features)
batch_size = queries_shape[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment