Commit 5964b038 authored by JGlombitza's avatar JGlombitza

allow loading and saving of the model

parent a77fadff
# Keras EdgeConv-Layer
General implementation of the EdgeConv-Block as described in [ParticleNet: Jet Tagging via Particle Clouds](https://arxiv.org/abs/1902.08570).
General implementation of the EdgeConv-Block as described in [Dynamic Graph CNN for Learning on Point Clouds](https://arxiv.org/abs/1801.07829).
## Using the layer
In order to use the layer class found in [edgeconv.py](edgeconv.py), a h-function needs to be defined. This function should take a list of two Keras tensors of length C with C being the dimension of the features. The h-function might include other Keras layers. Setting the number of k nearest neighbors to be considered using the `next_neighbors` argument, the EdgeConv layer can be implemented as demonstrated in [test.py](test.py).
In order to use the layer class found in [edgeconv.py](edgeconv.py), a kernel-function needs to be defined. This function should take a list of two Keras tensors of length C with C being the dimension of the features. The kernel-function might must include Keras layers. Setting the number of k nearest neighbors to be considered using the `next_neighbors` argument, the EdgeConv layer can be implemented as demonstrated in [test.py](test.py).
## Acknowledgement
This implementation borrows code from the ParticleNet tensorflow [implementation](https://github.com/hqucms/ParticleNet).
......@@ -56,7 +56,7 @@ class EdgeConv(lay.Layer):
'''
Keras layer implementation of EdgeConv.
# Arguments
h_func: h-function applied on the points and it's k nearest neighbors. The function should take a list
kernel_func: h-function applied on the points and it's k nearest neighbors. The function should take a list
of two tensors. The first tensor is the vector v_i of the central point, the second tensor is the vector
of one of its neighbors v_j.
:param list: [v_i, v_j] with v_i and v_j being Keras tensors with shape (C_f, ).
......@@ -76,12 +76,21 @@ class EdgeConv(lay.Layer):
with C_h being the output dimension of the h-function.
'''
def __init__(self, h_func, next_neighbors, agg_func=tf.reduce_mean, **kwargs):
self.h_func = h_func
self.K = next_neighbors
def __init__(self, kernel_func, next_neighbors, agg_func=keras.backend.mean, **kwargs):
self.kernel_func = kernel_func
self.next_neighbors = next_neighbors
self.agg_func = agg_func
if type(agg_func) == str:
raise ValueError("No such agg_func '%s'. When loading the model specify the agg_func '%s' via custom_objects" % (agg_func, agg_func))
super(EdgeConv, self).__init__(**kwargs)
def get_config(self):
config = {'next_neighbors': self.next_neighbors,
'kernel_func': self.kernel_func,
'agg_func': self.agg_func}
base_config = super(EdgeConv, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def build(self, input_shape):
# Create a trainable weight variable for this layer.
try:
......@@ -89,13 +98,14 @@ class EdgeConv(lay.Layer):
except ValueError:
f_shape = input_shape
x = lay.Input((f_shape.as_list()[-1]*2,))
a = lay.Reshape((f_shape.as_list()[-1], 2))(x)
x1, x2 = SplitLayer(n_splits=2, split_axis=-1)(a) # (2, C)
x1 = lay.Reshape((f_shape.as_list()[-1],))(x1)
x2 = lay.Reshape((f_shape.as_list()[-1],))(x2)
y = self.h_func([x1, x2])
self.h_func = keras.models.Model(x, y)
if type(self.kernel_func) != keras.models.Model: # for not wrapping model around model when loading model
x = lay.Input((f_shape.as_list()[-1] * 2,))
a = lay.Reshape((f_shape.as_list()[-1], 2))(x)
x1, x2 = SplitLayer(n_splits=2, split_axis=-1)(a) # (2, C)
x1 = lay.Reshape((f_shape.as_list()[-1],))(x1)
x2 = lay.Reshape((f_shape.as_list()[-1],))(x2)
y = self.kernel_func([x1, x2])
self.kernel_func = keras.models.Model(x, y)
super(EdgeConv, self).build(input_shape) # Be sure to call this at the end
......@@ -104,27 +114,26 @@ class EdgeConv(lay.Layer):
points, features = x
except TypeError:
points = features = x
# distance
D = batch_distance_matrix_general(points, points) # (N, P, P)
_, indices = tf.nn.top_k(-D, k=self.K + 1) # (N, P, K+1)
indices = indices[:, :, 1:] # (N, P, K)
fts = features
knn_fts = knn(indices, fts) # (N, P, K, C)
knn_fts_center = tf.tile(tf.expand_dims(fts, axis=2), (1, 1, self.K, 1)) # (N, P, K, C)
_, indices = tf.nn.top_k(-D, k=self.next_neighbors + 1) # (N, P, K+1)
indices = indices[:, :, 1:] # (N, P, K) remove self connection
knn_fts = knn(indices, features) # (N, P, K, C)
knn_fts_center = tf.tile(tf.expand_dims(features, axis=2), (1, 1, self.next_neighbors, 1)) # (N, P, K, C)
knn_fts = tf.concat([knn_fts_center, knn_fts], axis=-1) # (N, P, K, 2*C)
res = lay.TimeDistributed(lay.TimeDistributed(self.h_func))(knn_fts) # (N, P, K, C')
res = lay.TimeDistributed(lay.TimeDistributed(self.kernel_func))(knn_fts) # (N, P, K, C')
# aggregation
agg = self.agg_func(res, axis=2) # (N, P, C')
return agg
def compute_output_shape(self, input_shape):
self.output_shape = self.h_func.get_output_shape_at(-1)
self.output_shape = self.kernel_func.get_output_shape_at(-1)
return self.output_shape
def batch_distance_matrix_general(A, B):
''' Calculate elements-wise distance between entries in two tensors '''
with tf.name_scope('dmat'):
r_A = tf.reduce_sum(A * A, axis=2, keepdims=True)
r_B = tf.reduce_sum(B * B, axis=2, keepdims=True)
......@@ -136,6 +145,7 @@ def batch_distance_matrix_general(A, B):
def knn(topk_indices, features):
# topk_indices: (N, P, K)
# features: (N, P, C)
# return: (N, P, K, C)
with tf.name_scope('knn'):
k = tf.shape(topk_indices)[-1]
num_points = tf.shape(features)[-2]
......
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers as lay
from edgeconv import EdgeConv
import edgeconv
import numpy as np
......@@ -16,9 +15,9 @@ def f(data):
points = lay.Input((10, 6))
feats = lay.Input((10, 6))
a = EdgeConv(f, next_neighbors=3)([points, feats])
y = EdgeConv(f, next_neighbors=3)(a)
out = EdgeConv(f, next_neighbors=3)(y)
a = edgeconv.EdgeConv(f, next_neighbors=3)([points, feats])
y = edgeconv.EdgeConv(f, next_neighbors=3)(a)
out = edgeconv.EdgeConv(f, next_neighbors=3)(y)
model = keras.models.Model([points, feats], out)
model.summary()
......@@ -26,3 +25,10 @@ model.summary()
model.compile(loss="mse", optimizer=keras.optimizers.Adam())
model.fit([np.ones((300, 10, 6)), np.ones((300, 10, 6))], np.ones((300, 10, 20)), epochs=10)
print("\n------------------------- loading and saving -------------------------\n")
model.save("my_model.h5")
m = keras.models.load_model("my_model.h5", {"EdgeConv": edgeconv.EdgeConv, "SplitLayer": edgeconv.SplitLayer, "mean": keras.backend.mean})
m.summary()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment