from functools import partial
from operator import itemgetter
from os import listdir, path, remove
import tempfile

import numpy as np
import tensorflow as tf
from tqdm.auto import tqdm


class SKDict(dict):
    """Dict keyed by frozensets ("set-keys").

    Any key-ish value (scalar or iterable of scalars) is normalized to a
    frozenset via :meth:`keyify`.  An exact-key lookup returns the stored
    value; a partial-key lookup returns a sub-``SKDict`` of every entry whose
    key is a superset, with the queried part stripped off the keys.
    Dict values are flattened on insertion, so stored values are never dicts.
    """

    @staticmethod
    def keyify(keyish):
        """Normalize *keyish* into a hashable frozenset key."""
        if not isinstance(keyish, (tuple, list, set, frozenset)):
            keyish = (keyish,)
        keyish = frozenset(keyish)
        # a (mutable, unhashable) set nested inside a key is almost certainly
        # a caller mistake
        assert not any(isinstance(key, set) for key in keyish)
        return keyish

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.update(dict(*args, **kwargs))

    def update(self, *args, **kwargs):
        """Update from mappings and/or kwargs, routing every assignment
        through ``__setitem__`` so keys are keyified and dicts flattened."""
        args += (kwargs,)
        for arg in args:
            for k, v in arg.items():
                self[k] = v

    def get(self, key, *args, **kwargs):
        return super().get(self.keyify(key), *args, **kwargs)

    def pop(self, key, *args, **kwargs):
        return super().pop(self.keyify(key), *args, **kwargs)

    def __getitem__(self, key):
        """Exact-key lookup, else collect all superset-keyed entries.

        Raises:
            KeyError: if no stored key is a superset of *key*.
        """
        key = self.keyify(key)
        if key in self:
            return super().__getitem__(key)
        ret = self.__class__({k - key: v for k, v in self.items() if key <= k})
        if not ret:
            raise KeyError(key)
        return ret

    def __setitem__(self, key, value):
        # dict values are flattened: {'a': {'b': v}} is stored as {('a','b'): v}
        key = self.keyify(key)
        if isinstance(value, dict):
            for k, v in value.items():
                self[key | self.keyify(k)] = v
        else:
            super().__setitem__(key, value)

    def __delitem__(self, key):
        """Delete the exact entry, or every entry matched by a partial key."""
        key = self.keyify(key)
        if key in self:
            # FIX: on an exact key, self[key] returns the bare value (e.g. an
            # ndarray), which has no .keys() -- the original crashed here.
            super().__delitem__(key)
        else:
            # self[key] builds a *fresh* sub-dict, so deleting from self while
            # iterating its keys is safe
            for k in self[key].keys():
                super().__delitem__(k | key)

    def copy(self):
        return self.__class__(self)

    def map(self, func, groups=None, prog=False):
        """Return a new instance with ``func`` applied to the value at each
        key in *groups* (default: all keys).

        prog: progress display -- False for none, True for a bare tqdm bar,
        a str for ``tqdm(desc=prog)``, a dict for arbitrary tqdm kwargs, or
        any callable wrapping the iterable.
        """
        if groups is None:
            groups = self.keys()
        if prog is True:
            # FIX: True previously fell through every branch and was then
            # called as prog(groups), raising TypeError
            prog = {}
        if isinstance(prog, str):
            prog = dict(desc=prog)
        if isinstance(prog, dict):
            prog = partial(tqdm, **prog)
        elif not prog:
            prog = lambda x: x
        return self.__class__(
            {g: func(self[g]) for g in map(self.keyify, prog(groups))}
        )

    @classmethod
    def zip(cls, *insts):
        """Zip several instances into one mapping each key (union over all
        instances) to the tuple of per-instance values (None where absent)."""
        assert all(isinstance(inst, cls) for inst in insts)
        keys = set()
        keys.update(*(inst.keys() for inst in insts))
        return cls({key: tuple(inst.get(key) for inst in insts) for key in keys})

    @classmethod
    def concatenate(cls, *insts):
        """Key-wise np.concatenate of several instances."""
        return cls.zip(*insts).map(np.concatenate)

    def only(self, *keys):
        """Restrict to the given (partial or exact) keys."""
        return self.__class__({key: self[key] for key in keys})

    def dekey(self, *key):
        """Yield, for each sub-key under *key*, the lookup of that sub-key
        against the full dict (i.e. the complementary grouping)."""
        for k in self[key].keys():
            yield self[k]

    def items(self, *key):
        # NOTE: overrides dict.items with an optional partial-key filter;
        # with no argument it behaves like dict.items (keys + values)
        for k in (self[key] if key else self).keys():
            yield k, self[k]

    def mkeys(self, keys):
        """Common sub-key set of the groups at *keys*.

        Returns the shared frozenset of sub-keys if all lookups yield
        sub-dicts, ``((),)`` if all yield leaf values, and raises on a mix.
        """
        data = [self[key] for key in keys]
        ads = set(isinstance(d, self.__class__) for d in data)
        if ads == {True}:
            keys = set(frozenset(d.keys()) for d in data)
            assert len(keys) == 1  # bad depth
            return list(keys)[0]
        elif ads == {False}:
            return ((),)
        else:
            raise RuntimeError("bad depth")

    def skeys(self):
        """Keys as frozensets in deterministic (sorted) order."""
        return map(frozenset, sorted(map(sorted, self.keys())))

    @property
    def pretty(self):
        """Human-readable view: frozenset keys rendered as 'a/b/c' strings."""
        return {"/".join(sorted(map(str, k))): v for k, v in self.items()}


class GetNextSlice(object):
    """Slices consecutive chunks out of arrays produced by a factory callable,
    fetching a fresh array whenever the current one is exhausted."""

    # class-level default; the instance attribute is deleted on exhaustion so
    # this None "shows through" again on the next call
    curr = None

    def __init__(self, next):
        # factory returning a fresh array (e.g. a new random permutation)
        self.next = next

    def __call__(self, num):
        """Return up to *num* consecutive elements of the current array."""
        if self.curr is None:
            self.curr = self.next()
            self.pos = 0
        sli = self.curr[self.pos : self.pos + num]
        self.pos += num
        if len(sli) < num:
            # exhausted: remove the instance attribute, falling back to the
            # class-level None so the next call refills
            del self.curr
        return sli

    def iter(self, batch_size, fill=True, times=np.inf):
        """Yield *times* batches of *batch_size*; with fill=True short batches
        are topped up from the next array (np.inf means forever)."""
        while times:
            acc = self(batch_size)
            while fill and len(acc) < batch_size:
                acc = np.concatenate((acc, self(batch_size - len(acc))), axis=0)
            times -= 1
            yield acc


class DSS(SKDict):
    """Data-Set-of-Sets: an SKDict whose leaf values are array-likes sharing a
    common batch dimension (axis 0)."""

    @property
    def blen(self):
        """Common batch length (axis-0 size) of all values; asserts agreement."""
        lens = list(set(val.shape[0] for val in self.values()))
        assert len(lens) == 1
        return lens[0]

    @property
    def dtype(self):
        """Common dtype of all values; asserts agreement."""
        dtypes = list(set(val.dtype for val in self.values()))
        assert len(dtypes) == 1
        return dtypes[0]

    @property
    def dims(self):
        """Common ndim of all values; asserts agreement."""
        dimss = list(set(val.ndim for val in self.values()))
        assert len(dimss) == 1
        return dimss[0]

    @property
    def shape(self):
        """Common shape; axes that differ across values become None."""
        shapes = list(set(val.shape for val in self.values()))
        if len(shapes) > 1:
            assert set(map(len, shapes)) == {self.dims}
            return tuple(
                s[0] if len(s) == 1 else None
                for s in map(list, map(set, zip(*shapes)))
            )
        return shapes[0]

    def fuse(self, *keys, **kwargs):
        """Combine the groups at *keys* element-wise with ``op``
        (default np.concatenate)."""
        op = kwargs.pop("op", np.concatenate)
        assert not kwargs
        return self.zip(*(self[self.keyify(key)] for key in keys)).map(op)

    def split(self, thresh, right=False, rng=np.random):
        """Randomly partition along the batch axis.

        thresh: int n -> n equal parts; float or sequence of floats in (0,1)
        -> cumulative split points.  Returns a tuple of new instances.
        """
        if isinstance(thresh, int):
            thresh = np.linspace(0, 1, num=thresh + 1)[1:-1]
        if isinstance(thresh, float):
            thresh = (thresh,)
        thresh = np.array(thresh)
        assert np.all((0 < thresh) & (thresh < 1))
        idx = np.digitize(rng.uniform(size=self.blen), thresh, right=right)
        return tuple(
            self.map(itemgetter(idx == i)) for i in range(len(thresh) + 1)
        )

    def shuffle(self, rng=np.random):
        """Return a copy with all values permuted identically along axis 0."""
        return self.map(itemgetter(rng.permutation(self.blen)))

    def gen_feed_dict(self, tensor2key, batch_size=1024):
        """Yield TF1-style feed dicts mapping each tensor to batched data for
        its key (each distinct key sliced only once per batch)."""
        for sli in self.batch_slices(batch_size):
            buf = {key: self[key][sli] for key in set(tensor2key.values())}
            yield {tensor: buf[key] for tensor, key in tensor2key.items()}

    def batch_slices(self, batch_size):
        """Yield consecutive slice objects covering [0, blen)."""
        for i in range(0, self.blen, batch_size):
            yield slice(i, min(i + batch_size, self.blen))

    def random_indices(self, batch_size, rng=np.random):
        """Endless iterator of random index batches (reshuffled each epoch)."""
        assert self.blen  # blen needs to be greater than zero
        return GetNextSlice(lambda: rng.permutation(self.blen)).iter(batch_size)

    def batch_generator(self, batch_size, rng=np.random):
        """Endless generator of randomly indexed sub-instances."""
        for s in self.random_indices(batch_size, rng=rng):
            yield self.map(itemgetter(s))

    def dataset(self, batch_size, rng=np.random):
        """Wrap :meth:`batch_generator` in a ``tf.data.Dataset``.

        FIX: ``from_generator`` requires a *callable* returning an iterator;
        the original passed the generator object itself, which TF rejects and
        which could not be re-iterated anyway.
        NOTE(review): output_types/output_shapes are SKDicts with frozenset
        keys -- verify tf.nest handles this structure for your TF version.
        """
        return tf.data.Dataset.from_generator(
            generator=lambda: self.batch_generator(batch_size, rng=rng),
            output_types=self.map(lambda x: x.dtype),
            output_shapes=self.map(lambda x: (None,) + x.shape[1:]),
        )

    def kfeed(self, x, y, w, **kwargs):
        """Build Keras ``fit`` kwargs (x, y, sample_weight, validation_data)
        from the 'train' and 'valid' groups of this instance."""
        getter = itemgetter(x, y, w)
        train, valid = self["train"], self["valid"]
        return dict(
            zip(["x", "y", "sample_weight"], getter(train)),
            validation_data=getter(valid),
            **kwargs
        )

    def balanced(self, *keys, **kwargs):
        """Return a copy where each group at *keys* is rescaled so its total
        (per ``iref``) matches a common reference total (per ``kref``).

        kref/iref: either a reduction callable (default np.sum) or a key to
        pick the reference entry.
        """
        kref = kwargs.pop("kref", np.sum)
        iref = kwargs.pop("iref", np.sum)
        assert not kwargs  # consistent with fuse(): reject typo'd kwargs
        sums = {}
        for key in keys:
            s = self[key].map(np.sum)
            # FIX: list() because np.sum cannot reduce a dict_values view
            # (it becomes a 0-d object array)
            s = iref(list(s.values())) if callable(iref) else s[iref]
            if isinstance(s, dict):
                s = np.sum(list(s.values()))
            sums[key] = s
        ref = kref(list(sums.values())) if callable(kref) else sums[kref]
        # the lambda's late-bound `s` is safe: .map evaluates it eagerly
        # inside each comprehension step
        return self.__class__(
            {k: self[k].map(lambda x: x * (ref / s)) for k, s in sums.items()}
        )

    @classmethod
    def from_npy(cls, dir, sep="_", **kwargs):
        """Load every '<parts joined by sep>.npy' file in *dir* into a new
        instance, splitting the stem into a multi-part key."""
        return cls(
            {
                tuple(fn[:-4].split(sep)): np.load(path.join(dir, fn), **kwargs)
                for fn in listdir(dir)
                if fn.endswith(".npy")
            }
        )

    def to_npy(self, dir, sep="_", clean=True, **kwargs):
        """Save each entry as '<dir>/<sorted key joined by sep>.npy'.

        clean: remove pre-existing .npy files first.  FIX: the original
        accepted this flag but ignored it, always deleting.
        """
        if clean:
            for fn in listdir(dir):
                if fn.endswith(".npy"):
                    remove(path.join(dir, fn))
        for key, value in self.items():
            np.save(path.join(dir, "%s.npy" % sep.join(sorted(key))), value, **kwargs)


class DSSDisk(DSS):
    """DSS that spills plain in-memory ndarrays to disk-backed memmaps on
    insertion, keeping RAM usage low for large datasets."""

    def __setitem__(self, key, value):
        if isinstance(value, np.ndarray) and not isinstance(value, np.memmap):
            # NOTE(review): relies on POSIX semantics -- the temp file is
            # unlinked when the context exits, but the open memmap keeps the
            # inode alive; this will not work on Windows.
            with tempfile.NamedTemporaryFile() as tmp_file:
                np.save(tmp_file, value)
                tmp_file.flush()
                value = np.load(tmp_file.name, mmap_mode="r+")
        return super().__setitem__(key, value)