Skip to content
Snippets Groups Projects
Commit 3da6d10c authored by Dennis Noll's avatar Dennis Noll
Browse files

[keras] layer: init Blackout layer - can be used to zero certain parts of tensors

parent 6671b720
No related branches found
No related tags found
No related merge requests found
......@@ -2,7 +2,7 @@
import itertools
from functools import cached_property
import gc
from collections import OrderedDict, defaultdict
from collections import OrderedDict, defaultdict, Iterable
import fnmatch
import math
import datetime
......@@ -177,6 +177,26 @@ class KFeed(object):
)
def Blackout(ref, slic, name=None):
"""Blackout (set to zero) values in tensor accoring to numpy like slices. Always done alogn zeroth axis.
Args:
ref (np.array): reference array for shape and dtype
slic (Iterable): iterable of numpy like slices
name (_type_, optional): _description_. Defaults to None.
Returns:
_type_: _description_
"""
shape = ref.shape[1:]
dtype = ref.dtype
mask = np.ones(shape, dtype=dtype)
slic = slic if isinstance(slic, Iterable) else (slic,)
for _slic in slic:
mask[_slic] = 0
return tf.keras.layers.Lambda((lambda x: x * mask), name=name)
def Normal(ref, const=None, ignore_zeros=False, name=None, **kwargs):
"""
Normalizing layer according to ref.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment