Commit a9baa36b authored by Dennis Noll's avatar Dennis Noll

[edgeconf] knn: now using tf.gather

parent 5964b038
......@@ -146,11 +146,4 @@ def knn(topk_indices, features):
# topk_indices: (N, P, K)
# features: (N, P, C)
# return: (N, P, K, C)
with tf.name_scope('knn'):
k = tf.shape(topk_indices)[-1]
num_points = tf.shape(features)[-2]
queries_shape = tf.shape(features)
batch_size = queries_shape[0]
batch_indices = tf.tile(tf.reshape(tf.range(batch_size), (-1, 1, 1, 1)), (1, num_points, k, 1))
indices = tf.concat([batch_indices, tf.expand_dims(topk_indices, axis=3)], axis=3) # (N, P, K, 2)
return tf.gather_nd(features, indices)
return tf.gather(features, topk_indices, batch_dims=1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment