-
Dennis Noll authored
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
data.py 7.57 KiB
import numpy as np
import tensorflow as tf
from operator import itemgetter
from os import listdir, path, remove
class SKDict(dict):
    """Dictionary keyed by *sets* of labels, with subset-based lookup.

    Every key is normalised to a ``frozenset`` (see :meth:`keyify`), so the
    order in which the parts of a key are given does not matter.  Looking up
    a key that is not stored exactly returns a new instance holding every
    entry whose key is a superset of the requested one, with the requested
    parts stripped from the keys.  Assigning a ``dict`` value flattens it:
    each inner key is merged into the outer key.
    """

    @staticmethod
    def keyify(keyish):
        """Return *keyish* normalised to a ``frozenset`` key.

        A tuple, list, set or frozenset is taken as a collection of key
        parts; any other object is treated as a single part.
        """
        if not isinstance(keyish, (tuple, list, set, frozenset)):
            keyish = (keyish,)
        keyish = frozenset(keyish)
        assert not any(isinstance(key, set) for key in keyish)
        return keyish

    def __init__(self, *args, **kwargs):
        """Build the mapping from the same arguments ``dict()`` accepts."""
        super().__init__()
        # Route everything through update() so keys get normalised.
        self.update(dict(*args, **kwargs))

    def update(self, *args, **kwargs):
        """Insert all items of the given mappings and keyword arguments.

        Unlike ``dict.update`` any number of positional mappings is accepted;
        each must provide ``.items()``.
        """
        args += (kwargs,)
        for arg in args:
            for k, v in arg.items():
                self[k] = v

    def get(self, key, *args, **kwargs):
        """``dict.get`` on the normalised key (exact matches only)."""
        return super().get(self.keyify(key), *args, **kwargs)

    def pop(self, key, *args, **kwargs):
        """``dict.pop`` on the normalised key (exact matches only)."""
        return super().pop(self.keyify(key), *args, **kwargs)

    def __getitem__(self, key):
        """Return the value for *key*, or the sub-mapping selected by it.

        An exactly stored key returns its value.  Otherwise a new instance is
        returned containing every entry whose key is a superset of *key*,
        re-keyed with the parts of *key* removed.

        Raises:
            KeyError: if no stored key is a superset of *key*.
        """
        key = self.keyify(key)
        if key in self:
            return super().__getitem__(key)
        ret = self.__class__({k - key: v for k, v in self.items() if key <= k})
        if not ret:
            raise KeyError(key)
        return ret

    def __setitem__(self, key, value):
        """Store *value* under *key*; ``dict`` values are flattened."""
        key = self.keyify(key)
        if isinstance(value, dict):
            # Merge each inner key into the outer one instead of nesting.
            for k, v in value.items():
                self[key | self.keyify(k)] = v
        else:
            super().__setitem__(key, value)

    def __delitem__(self, key):
        """Delete the entry for *key*, or every entry selected by it.

        Raises:
            KeyError: if no stored key is a superset of *key*.
        """
        key = self.keyify(key)
        if key in self:
            # Exact match: self[key] is the stored value, not a sub-mapping,
            # so it must be removed directly rather than iterated over.
            super().__delitem__(key)
            return
        # self[key] is a fresh instance, so deleting while iterating is safe.
        for k in self[key].keys():
            super().__delitem__(k | key)

    def copy(self):
        """Return a shallow copy of the same class."""
        return self.__class__(self)

    def map(self, func, groups=None):
        """Return a new instance mapping each group key to ``func(self[g])``.

        *groups* defaults to all stored keys.
        """
        if groups is None:
            groups = self.keys()
        return self.__class__({g: func(self[g]) for g in map(self.keyify, groups)})

    @classmethod
    def zip(cls, *insts):
        """Combine instances into one whose values are per-key tuples.

        The result has the union of all keys; an instance lacking a key
        contributes ``None`` at its position in the tuple.
        """
        assert all(isinstance(inst, cls) for inst in insts)
        keys = set()
        keys.update(*(inst.keys() for inst in insts))
        return cls({key: tuple(inst.get(key) for inst in insts) for key in keys})

    @classmethod
    def concatenate(cls, *insts):
        """Zip the instances and ``np.concatenate`` the values of each key."""
        return cls.zip(*insts).map(np.concatenate)

    def only(self, *keys):
        """Return a new instance restricted to the given keys."""
        return self.__class__({key: self[key] for key in keys})

    def dekey(self, *key):
        """Yield ``self[k]`` for every key ``k`` of the sub-mapping ``self[key]``."""
        for k in self[key].keys():
            yield self[k]

    def items(self, *key):
        """Yield ``(k, self[k])`` pairs.

        Without arguments ``k`` runs over the stored keys; with key parts
        given it runs over the keys of the sub-mapping ``self[key]``.
        """
        for k in (self[key] if key else self).keys():
            yield k, self[k]

    def mkeys(self, keys):
        """Return the key set shared by the sub-mappings selected by *keys*.

        If every ``self[key]`` is an instance of this class, they must all
        have the same keys, which are returned as a frozenset.  If none is,
        ``((),)`` is returned.

        Raises:
            RuntimeError: if the lookups mix sub-mappings and plain values.
        """
        data = [self[key] for key in keys]
        ads = set(isinstance(d, self.__class__) for d in data)
        if ads == {True}:
            keys = set(frozenset(d.keys()) for d in data)
            assert len(keys) == 1  # bad depth
            return list(keys)[0]
        elif ads == {False}:
            return ((),)
        else:
            raise RuntimeError("bad depth")

    def skeys(self):
        """Return an iterator over the keys, ordered by their sorted parts."""
        return map(frozenset, sorted(map(sorted, self.keys())))

    @property
    def pretty(self):
        """Plain dict with each key rendered as its sorted parts joined by "/"."""
        return {"/".join(sorted(map(str, k))): v for k, v in self.items()}
class GetNextSlice(object):
    """Serve consecutive slices from a sequence that is re-fetched when used up.

    ``next`` is a zero-argument callable producing a sliceable sequence.  It
    is invoked lazily on the first call and again after the current sequence
    has been exhausted.
    """

    # Class-level fallback: deleting the instance attribute "resets" to None.
    curr = None

    def __init__(self, next):
        self.next = next

    def __call__(self, num):
        """Return the next (at most) *num* items of the current sequence.

        A result shorter than *num* marks the sequence as exhausted, so the
        following call fetches a new one.  If the sequence length is an exact
        multiple of *num*, the call after the last full slice returns an
        empty slice.
        """
        if self.curr is None:
            self.curr, self.pos = self.next(), 0
        start = self.pos
        chunk = self.curr[start : start + num]
        self.pos = start + num
        exhausted = len(chunk) < num
        if exhausted:
            # Drop the instance attribute; lookups fall back to the class None.
            del self.curr
        return chunk

    def iter(self, batch_size, fill=True, times=np.inf):
        """Generate batches of *batch_size* items, *times* times in total.

        With *fill* set, a short batch is topped up from the following
        sequence(s) until it holds *batch_size* items.
        """
        while times:
            batch = self(batch_size)
            if fill:
                while len(batch) < batch_size:
                    missing = batch_size - len(batch)
                    batch = np.concatenate((batch, self(missing)), axis=0)
            times -= 1
            yield batch
class DSS(SKDict):
    """:class:`SKDict` of array-like values that share a leading batch axis.

    Values are expected to expose ``shape``, ``dtype`` and ``ndim`` and to
    support NumPy-style indexing (index arrays, boolean masks, slices).
    """

    @property
    def blen(self):
        """Common length of the first axis of all values."""
        lens = list(set(val.shape[0] for val in self.values()))
        assert len(lens) == 1
        return lens[0]

    @property
    def dtype(self):
        """Common dtype of all values."""
        dtypes = list(set(val.dtype for val in self.values()))
        assert len(dtypes) == 1
        return dtypes[0]

    @property
    def dims(self):
        """Common number of dimensions of all values."""
        dimss = list(set(val.ndim for val in self.values()))
        assert len(dimss) == 1
        return dimss[0]

    @property
    def shape(self):
        """Common shape of all values; axes whose size differs are ``None``."""
        shapes = list(set(val.shape for val in self.values()))
        if len(shapes) > 1:
            assert set(map(len, shapes)) == {self.dims}
            # Per axis: keep the size if all values agree, otherwise None.
            return tuple(s[0] if len(s) == 1 else None for s in map(list, map(set, zip(*shapes))))
        return shapes[0]

    def fuse(self, *keys, **kwargs):
        """Zip the sub-mappings selected by *keys* and combine them with ``op``.

        Keyword Args:
            op: callable applied to each tuple of values
                (default ``np.concatenate``).
        """
        op = kwargs.pop("op", np.concatenate)
        assert not kwargs
        return self.zip(*(self[self.keyify(key)] for key in keys)).map(op)

    def split(self, thresh, right=False, rng=np.random):
        """Randomly partition the rows into ``len(thresh) + 1`` instances.

        Args:
            thresh: an ``int`` (number of equally likely parts), a single
                ``float`` or a sequence of thresholds, all strictly inside
                (0, 1).
            right: forwarded to ``np.digitize``.
            rng: object providing ``uniform(size=...)``.

        Returns:
            Tuple of instances, one per interval between thresholds.
        """
        if isinstance(thresh, int):
            thresh = np.linspace(0, 1, num=thresh + 1)[1:-1]
        if isinstance(thresh, float):
            thresh = (thresh,)
        thresh = np.array(thresh)
        assert np.all((0 < thresh) & (thresh < 1))
        # One uniform draw per row decides which part the row falls into.
        idx = np.digitize(rng.uniform(size=self.blen), thresh, right=right)
        return tuple(self.map(itemgetter(idx == i)) for i in range(len(thresh) + 1))

    def shuffle(self, rng=np.random):
        """Return a copy with all values permuted by one shared permutation."""
        return self.map(itemgetter(rng.permutation(self.blen)))

    def gen_feed_dict(self, tensor2key, batch_size=1024):
        """Yield ``{tensor: batch}`` dicts over consecutive batches.

        *tensor2key* maps each tensor to the key whose value feeds it.
        """
        for sli in self.batch_slices(batch_size):
            # Slice each distinct key once even if several tensors share it.
            buf = {key: self[key][sli] for key in set(tensor2key.values())}
            yield {tensor: buf[key] for tensor, key in tensor2key.items()}

    def batch_slices(self, batch_size):
        """Yield consecutive ``slice`` objects covering ``range(self.blen)``."""
        for i in range(0, self.blen, batch_size):
            yield slice(i, min(i + batch_size, self.blen))

    def random_indices(self, batch_size, rng=np.random):
        """Return an endless generator of random index batches."""
        return GetNextSlice(lambda: rng.permutation(self.blen)).iter(batch_size)

    def batch_generator(self, batch_size, rng=np.random):
        """Yield instances holding randomly selected batches of rows."""
        for s in self.random_indices(batch_size, rng=rng):
            yield self.map(itemgetter(s))

    def dataset(self, batch_size, rng=np.random):
        """Wrap :meth:`batch_generator` in a ``tf.data.Dataset``."""
        return tf.data.Dataset.from_generator(
            # from_generator requires a callable returning an iterable, not
            # an already-created generator object.
            generator=lambda: self.batch_generator(batch_size, rng=rng),
            output_types=self.map(lambda x: x.dtype),
            output_shapes=self.map(lambda x: (None,) + x.shape[1:]),
        )

    def kfeed(self, x, y, w, **kwargs):
        """Build keyword arguments from the "train" and "valid" groups.

        The result has ``x``, ``y`` and ``sample_weight`` taken from the
        "train" group, ``validation_data`` as the matching tuple from the
        "valid" group, plus any extra *kwargs*.
        NOTE(review): presumably meant for a Keras-style ``fit`` call; confirm.
        """
        getter = itemgetter(x, y, w)
        train, valid = self["train"], self["valid"]
        return dict(
            zip(["x", "y", "sample_weight"], getter(train)), validation_data=getter(valid), **kwargs
        )

    def balanced(self, *keys, **kwargs):
        """Rescale the groups selected by *keys* to a common reference sum.

        Keyword Args:
            iref: how the per-entry sums inside one group are reduced to a
                single number: a callable receiving the list of sums, or a
                key selecting entries (default ``np.sum``).
            kref: how the reference is derived from the per-group numbers: a
                callable receiving the list of them, or one of *keys*
                (default ``np.sum``).

        Returns:
            New instance containing only the selected groups, each value
            multiplied by ``reference / group_sum``.
        """
        kref = kwargs.pop("kref", np.sum)
        iref = kwargs.pop("iref", np.sum)
        sums = {}
        for key in keys:
            s = self[key].map(np.sum)
            # Materialise the values view: NumPy reductions do not iterate
            # dict views and would hand back the view object itself.
            s = iref(list(s.values())) if callable(iref) else s[iref]
            if isinstance(s, dict):
                s = np.sum(list(s.values()))
            sums[key] = s
        ref = kref(list(sums.values())) if callable(kref) else sums[kref]
        scaled = {}
        for k, s in sums.items():
            factor = ref / s
            scaled[k] = self[k].map(lambda x, factor=factor: x * factor)
        return self.__class__(scaled)

    @classmethod
    def from_npy(cls, dir, sep="_", **kwargs):
        """Load every ``*.npy`` file in *dir*.

        The file name without extension, split at *sep*, forms the key.
        Extra *kwargs* are forwarded to ``np.load``.
        """
        return cls(
            {
                tuple(fn[:-4].split(sep)): np.load(path.join(dir, fn), **kwargs)
                for fn in listdir(dir)
                if fn.endswith(".npy")
            }
        )

    def to_npy(self, dir, sep="_", clean=True, **kwargs):
        """Save every value to ``<sorted key parts joined by sep>.npy`` in *dir*.

        Args:
            dir: existing target directory.
            sep: separator used to join the key parts.
            clean: when true (default), delete all existing ``*.npy`` files
                in *dir* before saving.
            **kwargs: forwarded to ``np.save``.
        """
        if clean:
            for fn in listdir(dir):
                if fn.endswith(".npy"):
                    remove(path.join(dir, fn))
        for key, value in self.items():
            np.save(path.join(dir, "%s.npy" % sep.join(sorted(key))), value, **kwargs)