Commit d1ae3610 authored by Niclas Eich

Transformer toolkit

parent f9366fe0
7 new files added
comet.py 100644 → 100755
File mode changed from 100644 to 100755
data.py 100644 → 100755
File mode changed from 100644 to 100755
dssutils.py 100644 → 100755
File mode changed from 100644 to 100755
evil.py 100644 → 100755
File mode changed from 100644 to 100755
keras.py 100644 → 100755
@@ -503,9 +503,9 @@ class PlottingCallback:
         pin(locals())
         super().__init__(**kwargs)
 
-    def draw(self, figure, dict, key):
+    def draw(self, figure, dict, key, epoch=None, name=None):
         if self.path is not None:
-            plt.savefig(f"{self.path}/{key}.pdf")
+            plt.savefig(f"{self.path}/{key}{epoch}.pdf")
         image = figure_to_image(figure)
         dict.update({key: image})
@@ -525,10 +525,12 @@ class PlotMulticlass(PlottingCallback, TFSummaryCallback):
         plot_importance=False,
         plot_activations=False,
         tag="",
+        epochs=None,
         **kwargs,
     ):
         pin(locals())
         self.y = y[0]
+        self.epochs = epochs
         self.sample_weight_flat = sample_weight[0]
         self.comet_experiment = kwargs.get("comet_experiment", None)
         self.freq = kwargs.get("freq", None)
@@ -552,16 +554,19 @@ class PlotMulticlass(PlottingCallback, TFSummaryCallback):
     def on_train_begin(self, logs=None):
         if self.plot_inputs:
             self.make_input_plots()
+        self.make_eval_plots(epoch=0, name=str(0))
 
     def on_epoch_end(self, epoch, logs=None):
+        print("On Epoch end, epoch: ", epoch)
         if isinstance(self.freq, (float, int)):
             if epoch % int(self.freq) == 0 and epoch > 0:
+                self.make_eval_plots(epoch=epoch, name=str(epoch))
 
     def on_train_end(self, logs=None):
-        self.make_eval_plots()
+        self.make_eval_plots(epoch=self.epochs if self.epochs is not None else 999999999)
 
     def make_input_plots(self):
         print("Creating Summary Plots")
         inps = self.x
         if not isinstance(inps, (list, tuple)):
             inps = [inps]
@@ -613,11 +618,12 @@ class PlotMulticlass(PlottingCallback, TFSummaryCallback):
             plt.close("all")
             gc.collect()
 
-    @cached_property
+    #@cached_property
+    @property
     def prediction(self):
         return self.model.predict(self.x, batch_size=self.batch_size)
 
-    def make_eval_plots(self, epoch=0, name=""):
+    def make_eval_plots(self, epoch, name=""):
         imgs = {}
         fig = figure_roc_curve(
@@ -626,7 +632,7 @@ class PlotMulticlass(PlottingCallback, TFSummaryCallback):
             class_names=self.class_names,
             sample_weight=self.sample_weight_flat,
         )
-        self.draw(fig, imgs, "roc_curve{}".format(name), epoch=epoch, name="ROC")
+        self.draw(fig, imgs, "roc_curve".format(name), epoch=epoch, name="ROC")
         self.clear_figure(fig)
 
         fig = figure_confusion_matrix(
@@ -639,7 +645,7 @@ class PlotMulticlass(PlottingCallback, TFSummaryCallback):
         self.draw(
             fig,
             imgs,
-            "confusion_matrix_true{}".format(name),
+            "confusion_matrix_true".format(name),
            epoch=epoch,
            name="Confusion-Matrix-True",
         )
@@ -655,7 +661,7 @@ class PlotMulticlass(PlottingCallback, TFSummaryCallback):
         self.draw(
             fig,
             imgs,
-            "confusion_matrix_pred{}".format(name),
+            "confusion_matrix_pred".format(name),
             epoch=epoch,
             name="Confusion-Matrix-Pred",
         )
@@ -1376,6 +1382,79 @@ class DenseNetBlock(tf.keras.layers.Layer):
        return {"block_size": self.block_size, "sub_kwargs": self.sub_kwargs}
# TransformerBlock implements one encoder block from "Attention Is All You Need"
# (https://arxiv.org/pdf/1706.03762.pdf), following the VISPA transformer example.
class TransformerBlock(tf.keras.layers.Layer):
    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.0, activation="relu"):
        super().__init__()
        self.att = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim // num_heads)
        self.ffn = tf.keras.models.Sequential([
            tf.keras.layers.Dense(ff_dim, activation=activation),
            tf.keras.layers.Dense(embed_dim),
        ])
        # The choice of normalization can have a BIG influence on training;
        # BatchNormalization was tried here before switching to LayerNormalization:
        # self.batchnorm1 = tf.keras.layers.BatchNormalization()
        # self.batchnorm2 = tf.keras.layers.BatchNormalization()
        self.batchnorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.batchnorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)

    def call(self, input1, input2, training=None):
        # attention (query=input1, key/value=input2), then residual + norm
        attn_output = self.att(input1, input2)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.batchnorm1(input1 + attn_output)
        # position-wise feed-forward network, then residual + norm
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.batchnorm2(out1 + ffn_output)
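A minimal usage sketch (an editor's illustration, not part of the commit), assuming this module's tensorflow import as tf: passing the same tensor twice gives self-attention, two different tensors give cross-attention, and the output always keeps the query's shape.

# Sketch only: exercise a TransformerBlock on dummy data.
block = TransformerBlock(embed_dim=256, num_heads=8, ff_dim=256, rate=0.1)
tokens = tf.random.normal((8, 10, 256))   # (batch, steps, embed_dim)
context = tf.random.normal((8, 4, 256))   # a second, shorter sequence
self_att = block(tokens, tokens)          # query == key/value: self-attention
cross_att = block(tokens, context)        # query != key/value: cross-attention
assert self_att.shape == (8, 10, 256) and cross_att.shape == (8, 10, 256)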
class PositionalEncodingEmbedding(tf.keras.layers.Layer):
    def __init__(self, cls_token=False, **kwargs):
        super().__init__(**kwargs)
        self.cls_token = cls_token

    def build(self, input_shape):
        self.steps = input_shape[-2]
        # reserve one extra position for the prepended cls token, if used
        num_positions = self.steps + 1 if self.cls_token else self.steps
        self.position_embedding = tf.keras.layers.Embedding(input_dim=num_positions, output_dim=input_shape[-1])
        if self.cls_token:
            initial_value = tf.zeros((1, 1, input_shape[-1]))
            self.cls_t = tf.Variable(initial_value=initial_value, trainable=True, name="cls")

    def get_config(self):
        config = {"cls_token": self.cls_token}
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def call(self, inputs):
        if self.cls_token:
            # prepend one learnable cls token per batch element
            n = tf.shape(inputs)[0]
            cls_t = tf.tile(self.cls_t, (n, 1, 1))
            inputs = tf.concat([cls_t, inputs], axis=1)
        # add one learned embedding per (possibly cls-extended) sequence position
        positions = tf.range(start=0, limit=tf.shape(inputs)[1], delta=1)
        encoded = inputs + self.position_embedding(positions)
        return encoded
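A hedged sketch of the layer above (illustration only): with cls_token=True the sequence grows by one learnable token before the learned position embeddings are added.

# Sketch only: learned positional encoding with a prepended cls token.
pos_embed = PositionalEncodingEmbedding(cls_token=True)
seq = tf.random.normal((8, 10, 256))
encoded = pos_embed(seq)       # cls token prepended, positions 0..10 embedded
assert encoded.shape == (8, 11, 256)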
def get_angles(pos, i, d_model):
    angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model))
    return pos * angle_rates


def positional_encoding(position, d_model):
    angle_rads = get_angles(np.arange(position)[:, np.newaxis],
                            np.arange(d_model)[np.newaxis, :],
                            d_model)
    # apply sin to even indices in the array; 2i
    angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
    # apply cos to odd indices in the array; 2i+1
    angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])
    pos_encoding = angle_rads[np.newaxis, ...]
    return tf.cast(pos_encoding, dtype=tf.float32)
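For reference, this fixed table implements PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)) from the paper. A quick shape check (illustration only; note that TransformerNetwork below instantiates the learned PositionalEncodingEmbedding, not this sinusoidal variant):

# Sketch only: the sinusoidal table broadcasts over the batch dimension.
table = positional_encoding(position=10, d_model=256)
assert table.shape == (1, 10, 256)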
class LinearNetwork(tf.keras.layers.Layer):
    @property
    def name(self):
@@ -1406,6 +1485,113 @@ class LinearNetwork(tf.keras.layers.Layer):
        return {"layers": self.layers, "sub_kwargs": self.sub_kwargs}
class TransformerNetwork(tf.keras.layers.Layer):
    # Model defaults
    embedding_dim = 256
    num_heads = 8
    ff_dim = 256
    n_blocks = 2
    dropout_rate = 0.25
    l2 = 0.0
    activation = "relu"

    def __init__(self, sub_kwargs=None, **kwargs):
        super().__init__(name="TransformerNetwork")
        self.activation = kwargs.get("activation", "relu")
        self.dropout_rate = kwargs.get("dropout", 0.25)
        self.embedding_dim = kwargs.get("nodes", 256)
        self.l2 = kwargs.get("l2", 0.0)
        self.n_blocks = kwargs.get("layers", 2)
        self.num_heads = kwargs.get("block_size", 8)
        self.ff_dim = kwargs.get("ff_dim", 256)
        self.sub_kwargs = kwargs if sub_kwargs is None else sub_kwargs

    def build(self, input_shape):
        # identical three-layer MLP embedding for each input collection
        def make_embedding():
            return tf.keras.models.Sequential([
                tf.keras.layers.Dense(64, activation=self.activation),
                tf.keras.layers.Dense(128, activation="relu"),
                tf.keras.layers.Dense(self.embedding_dim, activation="relu"),
            ])

        self.embedding_features = make_embedding()
        self.embedding_jets = make_embedding()
        self.embedding_leptons = make_embedding()
        self.embedding_Z = make_embedding()
        self.embedding_H = make_embedding()
        self.embedding_ZH = make_embedding()

        # NOTE: instantiated but never applied in call()
        self.positional_embedding_features = PositionalEncodingEmbedding(cls_token=False)
        self.positional_embedding_particles = PositionalEncodingEmbedding(cls_token=False)

        # one pair of self-attention blocks (particles, features) plus one
        # shared cross-attention block per requested layer
        self.self_attentions = []
        self.cross_attentions = []
        for _ in range(self.n_blocks):
            self.self_attentions.append((
                TransformerBlock(embed_dim=self.embedding_dim, num_heads=self.num_heads,
                                 ff_dim=self.ff_dim, rate=self.dropout_rate),
                TransformerBlock(embed_dim=self.embedding_dim, num_heads=self.num_heads,
                                 ff_dim=self.ff_dim, rate=self.dropout_rate),
            ))
            self.cross_attentions.append(
                TransformerBlock(embed_dim=self.embedding_dim, num_heads=self.num_heads,
                                 ff_dim=self.ff_dim, rate=self.dropout_rate)
            )

        self.global_pooling = tf.keras.layers.GlobalAveragePooling1D()
        self.final_embed = tf.keras.layers.Dense(self.embedding_dim, activation="gelu")
        self.final_dense = tf.keras.layers.Dense(64, activation="gelu")

    def call(self, inputs, training=False):
        input_leptons, input_jets, input_Z, input_H, input_ZH, input_features = inputs

        # embed each object collection and concatenate into one particle sequence
        emb_leptons = self.embedding_leptons(input_leptons)
        emb_jets = self.embedding_jets(input_jets)
        emb_Z = self.embedding_Z(input_Z)
        emb_H = self.embedding_H(input_H)
        emb_ZH = self.embedding_ZH(input_ZH)
        emb_particles = tf.concat([emb_leptons, emb_jets, emb_Z, emb_H, emb_ZH], axis=1)
        emb_features = self.embedding_features(input_features)

        # self-attention within each sequence
        x1 = emb_particles
        x2 = emb_features
        for att1, att2 in self.self_attentions:
            x1 = att1(x1, x1)
            x2 = att2(x2, x2)

        # cross-attention between the two sequences, in both directions
        x = x1
        xx = x2
        for cross_attention in self.cross_attentions:
            x, xx = cross_attention(x, xx), cross_attention(xx, x)

        # combine via average pooling (alternative: take a cls token, x[:, 0])
        x = self.global_pooling(x)
        xx = self.global_pooling(xx)
        x = tf.concat([x, xx], axis=-1)

        # final prediction head
        x = self.final_embed(x)
        x = self.final_dense(x)
        return x

    def get_config(self):
        return {
            "block_size": self.num_heads,
            "activation": self.activation,
            "dropout": self.dropout_rate,
            "l2": self.l2,
            "layers": self.n_blocks,
            "sub_kwargs": self.sub_kwargs,
        }
class FullyConnected(LinearNetwork):
    """
    The FullyConnected object is an implementation of a fully connected DNN.
@@ -1457,6 +1643,36 @@ class DenseNet(LinearNetwork):
    substructure = DenseNetBlock
class Transformer(tf.keras.layers.Layer):
    """
    The Transformer object wraps TransformerNetwork as a Keras layer.

    Parameters
    ----------
    layers : int
        Number of self-/cross-attention transformer blocks.
    kwargs :
        Arguments forwarded to TransformerNetwork.
    """

    name = "Transformer"
    substructure = TransformerNetwork

    def __init__(self, **kwargs):
        super().__init__(name=self.name)
        self.kwargs = kwargs

    def build(self, input_shape):
        self.transformer = TransformerNetwork(**self.kwargs)

    def call(self, inputs, training=False):
        x = self.transformer(inputs, training=training)
        return x

    def get_config(self):
        return {"kwargs": self.kwargs}
class Xception1D(tf.keras.layers.Layer):
    """
    The Xception1D object is an implementation of a Xception Neural Network.
......
misc.py 100644 → 100755
File mode changed from 100644 to 100755
numpy.py 100644 → 100755
File mode changed from 100644 to 100755
plotting.py 100644 → 100755
File mode changed from 100644 to 100755
tf.py 100644 → 100755
File mode changed from 100644 to 100755