import tempfile
from functools import partial
from operator import itemgetter
from os import listdir, path, remove

import numpy as np
import tensorflow as tf
from tqdm import tqdm
class SKDict(dict):
    """dict keyed by frozensets of tags, addressable by any subset of those tags"""

    @staticmethod
    def keyify(keyish):
        # wrap bare tags so scalars and tag collections are handled alike
        if not isinstance(keyish, (tuple, list, set, frozenset)):
            keyish = (keyish,)
        keyish = frozenset(keyish)
        assert not any(isinstance(key, set) for key in keyish)
        return keyish
    def __init__(self, *args, **kwargs):
        super(SKDict, self).__init__()
        self.update(dict(*args, **kwargs))

    def update(self, *args, **kwargs):
        # assert 0 <= len(args) <= 1
        for arg in args:
            for k, v in arg.items():
                self[k] = v

    def get(self, key, *args, **kwargs):
        return super(SKDict, self).get(self.keyify(key), *args, **kwargs)

    def pop(self, key, *args, **kwargs):
        return super(SKDict, self).pop(self.keyify(key), *args, **kwargs)

    def __getitem__(self, key):
        key = self.keyify(key)
        if key in self:
            return super(SKDict, self).__getitem__(key)
        ret = self.__class__({k - key: v for k, v in self.items() if key <= k})
        if not ret:
            raise KeyError(key)
        return ret
    def __setitem__(self, key, value):
        key = self.keyify(key)
        if isinstance(value, dict):
            for k, v in value.items():
                self[key | self.keyify(k)] = v
        else:
            super(SKDict, self).__setitem__(key, value)
    def __delitem__(self, key):
        key = self.keyify(key)
        if key in self:
            super(SKDict, self).__delitem__(key)
        else:
            # partial key: drop every item carrying all of its tags
            for k in self[key].keys():
                super(SKDict, self).__delitem__(k | key)

    def copy(self):
        return self.__class__(self)
    def map(self, func, groups=None, prog=False):
        # apply func to every group (or only to the given groups), optionally with a tqdm bar
        if groups is None:
            groups = self.keys()
        if isinstance(prog, str):
            prog = dict(desc=prog)
        if isinstance(prog, dict):
            prog = partial(tqdm, **prog)
        elif not prog:
            prog = lambda x: x
        return self.__class__({g: func(self[g]) for g in map(self.keyify, prog(groups))})
    @classmethod
    def zip(cls, *insts):
        assert all(isinstance(inst, cls) for inst in insts)
        keys = set()
        keys.update(*(inst.keys() for inst in insts))
        return cls({key: tuple(inst.get(key) for inst in insts) for key in keys})

    @classmethod
    def concatenate(cls, *insts):
        return cls.zip(*insts).map(np.concatenate)

    def only(self, *keys):
        return self.__class__({key: self[key] for key in keys})

    def dekey(self, *key):
        for k in self[key].keys():
            yield self[k]

    def items(self, *key):
        for k in (self[key] if key else self).keys():
            yield k, self[k]
    def mkeys(self, keys):
        data = [self[key] for key in keys]
        ads = set(isinstance(d, self.__class__) for d in data)
        if ads == {True}:
            keys = set(frozenset(d.keys()) for d in data)
            assert len(keys) == 1  # bad depth
            return list(keys)[0]
        elif ads == {False}:
            # plain values carry no further sub-keys
            return frozenset()
        else:
            raise RuntimeError("bad depth")

    def skeys(self):
        return map(frozenset, sorted(map(sorted, self.keys())))

    @property
    def pretty(self):
        return {"/".join(sorted(map(str, k))): v for k, v in self.items()}
class GetNextSlice(object):
    """hands out consecutive slices of an array, fetching a fresh array once it is used up"""

    curr = None

    def __init__(self, next):
        self.next = next

    def __call__(self, num):
        if self.curr is None:
            self.curr = self.next()
            self.pos = 0
        sli = self.curr[self.pos : self.pos + num]
        self.pos += num
        if len(sli) < num:
            # current array exhausted, fetch a new one on the next call
            del self.curr
        return sli
    def iter(self, batch_size, fill=True, times=np.inf):
        while times:
            acc = self(batch_size)
            while fill and len(acc) < batch_size:
                acc = np.concatenate((acc, self(batch_size - len(acc))), axis=0)
            times -= 1
            yield acc
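

# Illustrative sketch (not part of the original module): GetNextSlice keeps handing
# out index batches and calls the supplied factory again whenever the current index
# array is exhausted, so the stream of batches never runs dry.
def _get_next_slice_example():
    rng = np.random.RandomState(0)
    nxt = GetNextSlice(lambda: rng.permutation(10))
    batches = []
    for i, idx in enumerate(nxt.iter(4)):
        batches.append(idx)
        if i == 2:  # take three batches of size 4 from a pool of only 10 indices
            break
    assert all(len(b) == 4 for b in batches)
    return batches
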
class DSS(SKDict):
    """SKDict of numpy arrays that share a common batch (first) dimension"""

    @property
    def blen(self):
        lens = list(set(val.shape[0] for val in self.values()))
        assert len(lens) == 1
        return lens[0]

    @property
    def dtype(self):
        dtypes = list(set(val.dtype for val in self.values()))
        assert len(dtypes) == 1
        return dtypes[0]

    @property
    def dims(self):
        dimss = list(set(val.ndim for val in self.values()))
        assert len(dimss) == 1
        return dimss[0]

    @property
    def shape(self):
        shapes = list(set(val.shape for val in self.values()))
        if len(shapes) > 1:
            # shapes differ: keep common axes, mark varying ones with None
            assert set(map(len, shapes)) == {self.dims}
            return tuple(s[0] if len(s) == 1 else None for s in map(list, map(set, zip(*shapes))))
        return shapes[0]
    def fuse(self, *keys, **kwargs):
        op = kwargs.pop("op", np.concatenate)
        assert not kwargs
        return self.zip(*(self[self.keyify(key)] for key in keys)).map(op)

    def split(self, thresh, right=False, rng=np.random):
        # int: number of equally sized parts, float/array: explicit split points in (0, 1)
        if isinstance(thresh, int):
            thresh = np.linspace(0, 1, num=thresh + 1)[1:-1]
        if isinstance(thresh, float):
            thresh = np.array([thresh])
        assert np.all((0 < thresh) & (thresh < 1))
        idx = np.digitize(rng.uniform(size=self.blen), thresh, right=right)
        return tuple(self.map(itemgetter(idx == i)) for i in range(len(thresh) + 1))

    def shuffle(self, rng=np.random):
        return self.map(itemgetter(rng.permutation(self.blen)))
    def gen_feed_dict(self, tensor2key, batch_size=1024):
        for sli in self.batch_slices(batch_size):
            buf = {key: self[key][sli] for key in set(tensor2key.values())}
            yield {tensor: buf[key] for tensor, key in tensor2key.items()}

    def batch_slices(self, batch_size):
        for i in range(0, self.blen, batch_size):
            yield slice(i, min(i + batch_size, self.blen))

    def random_indices(self, batch_size, rng=np.random):
        assert self.blen  # blen needs to be greater than zero
        return GetNextSlice(lambda: rng.permutation(self.blen)).iter(batch_size)

    def batch_generator(self, batch_size, rng=np.random):
        for s in self.random_indices(batch_size, rng=rng):
            yield self.map(itemgetter(s))

    def dataset(self, batch_size, rng=np.random):
        # from_generator expects a callable, not an already created generator
        return tf.data.Dataset.from_generator(
            generator=partial(self.batch_generator, batch_size, rng=rng),
            output_types=self.map(lambda x: x.dtype),
            output_shapes=self.map(lambda x: (None,) + x.shape[1:]),
        )
    def kfeed(self, x, y, w, **kwargs):
        # keyword arguments in the style of keras' fit(): features, labels, sample weights,
        # and validation data, taken from the "train" and "valid" partitions
        getter = itemgetter(x, y, w)
        train, valid = self["train"], self["valid"]
        return dict(
            zip(["x", "y", "sample_weight"], getter(train)), validation_data=getter(valid), **kwargs
        )

    def balanced(self, *keys, **kwargs):
        # rescale the groups given by keys so that their sums match a common reference
        kref = kwargs.pop("kref", np.sum)
        iref = kwargs.pop("iref", np.sum)
        sums = {}
        for key in keys:
            s = self[key].map(np.sum)
            s = iref(list(s.values())) if callable(iref) else s[iref]
            if isinstance(s, dict):
                s = np.sum(list(s.values()))
            sums[key] = s
        ref = kref(list(sums.values())) if callable(kref) else sums[kref]
        return self.__class__({k: self[k].map(lambda x: x * (ref / s)) for k, s in sums.items()})
    @classmethod
    def from_npy(cls, dir, sep="_", **kwargs):
        # load every "<tag1>_<tag2>.npy" file in dir under the key (tag1, tag2)
        return cls(
            {
                tuple(fn[:-4].split(sep)): np.load(path.join(dir, fn), **kwargs)
                for fn in listdir(dir)
                if fn.endswith(".npy")
            }
        )

    def to_npy(self, dir, sep="_", clean=True, **kwargs):
        if clean:
            for fn in listdir(dir):
                if fn.endswith(".npy"):
                    remove(path.join(dir, fn))
        for key, value in self.items():
            np.save(path.join(dir, "%s.npy" % sep.join(sorted(key))), value, **kwargs)
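

# Illustrative end-to-end sketch (not part of the original module): the tags
# "sig"/"x"/"y" and the array sizes are invented for the example. It builds a tiny
# dataset, splits it into two partitions, and draws one shuffled batch.
def _dss_example():
    rng = np.random.RandomState(42)
    data = DSS()
    data["sig", "x"] = rng.normal(size=(100, 3)).astype(np.float32)
    data["sig", "y"] = np.ones((100, 1), dtype=np.float32)
    train, valid = data.split(0.8, rng=rng)  # roughly 80% / 20%
    batch = next(train.batch_generator(32, rng=rng))
    assert batch["sig", "x"].shape[0] == 32
    return train, valid, batch
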
class DSSDisk(DSS):
    """DSS variant that keeps its arrays on disk as memory maps instead of in RAM"""

    def __setitem__(self, key, value):
        if isinstance(value, np.ndarray) and not isinstance(value, np.memmap):
            # round-trip plain arrays through a temporary .npy file and reopen them as
            # writable memory maps (relies on POSIX semantics: the mapping stays valid
            # after the temporary file is removed)
            with tempfile.NamedTemporaryFile() as tmp_file:
                np.save(tmp_file, value)
                tmp_file.flush()
                value = np.load(tmp_file.name, mmap_mode="r+")
        return super().__setitem__(key, value)
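

# Illustrative sketch (not part of the original module): plain arrays assigned to a
# DSSDisk are transparently rewritten to temporary .npy files and reopened as
# writable memory maps, so large inputs do not have to stay resident in RAM.
def _dss_disk_example():
    data = DSSDisk()
    data["sig", "x"] = np.zeros((1000, 4), dtype=np.float32)
    assert isinstance(data["sig", "x"], np.memmap)
    return data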