Skip to content
Snippets Groups Projects
Commit 04bea6db authored by Dennis Noll's avatar Dennis Noll
Browse files

[keras] adds WhereEquals Layer

parent 086b9781
No related branches found
No related tags found
No related merge requests found
...@@ -932,6 +932,18 @@ class AUCOneVsAll(tf.keras.metrics.AUC): ...@@ -932,6 +932,18 @@ class AUCOneVsAll(tf.keras.metrics.AUC):
) )
# Layer Definitions
class WhereEquals(tf.keras.layers.Layer):
def __init__(self, value=0):
super(WhereEquals, self).__init__()
self.value = value
def call(self, inp):
return tf.where(inp[:, 0] == self.value)
class DenseLayer(tf.keras.layers.Layer): class DenseLayer(tf.keras.layers.Layer):
""" """
The DenseLayer object is an extended implementation of the tf.keras.layers.Dense. The DenseLayer object is an extended implementation of the tf.keras.layers.Dense.
...@@ -1297,6 +1309,8 @@ def grouped_cross_entropy_t( ...@@ -1297,6 +1309,8 @@ def grouped_cross_entropy_t(
# Custom Loss Functions # Custom Loss Functions
class GroupedXEnt(tf.keras.losses.Loss): class GroupedXEnt(tf.keras.losses.Loss):
def __init__( def __init__(
self, self,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment