"""Set-keyed dict containers (SKDict / DSS) for organizing numpy arrays."""
from functools import partial
import tempfile
import numpy as np
import tensorflow as tf
from operator import itemgetter
from os import listdir, path, remove
from tqdm.auto import tqdm


class SKDict(dict):
    """Dict keyed by frozensets, with partial-key (subset) lookup.

    Every key is normalized to a frozenset via :meth:`keyify`.  Looking up
    a strict subset of stored keys returns a new SKDict of all matching
    entries with that subset removed from their keys::

        d = SKDict({("a", "x"): 1, ("a", "y"): 2})
        d["a"]  # -> SKDict({frozenset({"x"}): 1, frozenset({"y"}): 2})

    Assigning a dict value flattens it into the key space.
    """

    @staticmethod
    def keyify(keyish):
        """Normalize *keyish* (scalar or tuple/list/set/frozenset) to a frozenset."""
        if not isinstance(keyish, (tuple, list, set, frozenset)):
            keyish = (keyish,)
        keyish = frozenset(keyish)
        # nested sets inside a key would break the subset semantics below
        assert not any(isinstance(key, set) for key in keyish)
        return keyish

    def __init__(self, *args, **kwargs):
        super(SKDict, self).__init__()
        self.update(dict(*args, **kwargs))

    def update(self, *args, **kwargs):
        # route every item through __setitem__ so keys get keyified
        args += (kwargs,)
        for arg in args:
            for k, v in arg.items():
                self[k] = v

    def get(self, key, *args, **kwargs):
        return super(SKDict, self).get(self.keyify(key), *args, **kwargs)

    def pop(self, key, *args, **kwargs):
        return super(SKDict, self).pop(self.keyify(key), *args, **kwargs)

    def __getitem__(self, key):
        key = self.keyify(key)
        if key in self:
            return super(SKDict, self).__getitem__(key)
        # partial-key lookup: collect all entries whose key is a superset,
        # stripping the queried subset from their keys
        ret = self.__class__({k - key: v for k, v in self.items() if key <= k})
        if not ret:
            raise KeyError(key)
        return ret

    def __setitem__(self, key, value):
        key = self.keyify(key)
        if isinstance(value, dict):
            # flatten nested dicts: {"a": {"x": 1}} is stored under {"a","x"}
            for k, v in value.items():
                self[key | self.keyify(k)] = v
        else:
            super(SKDict, self).__setitem__(key, value)

    def __delitem__(self, key):
        key = self.keyify(key)
        if key in self:
            # exact match: plain deletion.  (Previously this case crashed:
            # self[key] returned the leaf value, which has no .keys().)
            super(SKDict, self).__delitem__(key)
            return
        # partial key: delete every entry matched by the subset lookup;
        # self[key] builds a fresh dict, so iteration is safe while deleting
        for k in self[key].keys():
            super(SKDict, self).__delitem__(k | key)

    def copy(self):
        return self.__class__(self)

    def map(self, func, groups=None, prog=False):
        """Apply *func* to each group's value; returns a new instance.

        prog: False for no progress bar, a str/dict to configure tqdm,
        or any callable wrapping the group iterable.
        """
        if groups is None:
            groups = self.keys()
        if isinstance(prog, str):
            prog = dict(desc=prog)
        if isinstance(prog, dict):
            prog = partial(tqdm, **prog)
        elif not prog:
            prog = lambda x: x
        return self.__class__({g: func(self[g]) for g in map(self.keyify, prog(groups))})

    @classmethod
    def zip(cls, *insts):
        """Align several instances into one mapping key -> tuple of values.

        Missing keys appear as None (dict.get default).
        """
        assert all(isinstance(inst, cls) for inst in insts)
        keys = set()
        keys.update(*(inst.keys() for inst in insts))
        return cls({key: tuple(inst.get(key) for inst in insts) for key in keys})

    @classmethod
    def concatenate(cls, *insts):
        """Zip instances and concatenate the aligned arrays per key."""
        return cls.zip(*insts).map(np.concatenate)

    def only(self, *keys):
        """Restrict to the given (partial or exact) keys."""
        return self.__class__({key: self[key] for key in keys})

    def dekey(self, *key):
        """Yield the sub-dicts obtained by stripping *key* from matching entries."""
        for k in self[key].keys():
            yield self[k]

    def items(self, *key):
        """Yield (full_key, value) pairs, optionally filtered by a partial *key*.

        NOTE: overrides dict.items with an extended signature.
        """
        for k in (self[key] if key else self).keys():
            yield k, self[k]

    def mkeys(self, keys):
        """Common sub-keys of the groups under *keys*; asserts uniform depth."""
        data = [self[key] for key in keys]
        ads = set(isinstance(d, self.__class__) for d in data)
        if ads == {True}:
            keys = set(frozenset(d.keys()) for d in data)
            assert len(keys) == 1  # bad depth
            return list(keys)[0]
        elif ads == {False}:
            return ((),)
        else:
            raise RuntimeError("bad depth")

    def skeys(self):
        """Keys as frozensets in a deterministic (sorted) order."""
        return map(frozenset, sorted(map(sorted, self.keys())))

    @property
    def pretty(self):
        """Human-readable view: 'a/b/c' string keys instead of frozensets."""
        return {"/".join(sorted(map(str, k))): v for k, v in self.items()}


class GetNextSlice(object):
    """Serve fixed-size slices from a sequence, refilling when exhausted.

    `next` is a zero-argument callable producing a fresh sequence (e.g. a
    new random permutation of indices); slices are cut from it until it
    runs short, at which point the buffer is invalidated.
    """

    # Class-level fallback: __call__ deletes the *instance* attribute when
    # the buffer is exhausted, so the next lookup falls back to this None
    # and triggers a refill.
    curr = None

    def __init__(self, next):
        # next: zero-arg callable returning a sliceable sequence
        self.next = next

    def __call__(self, num):
        """Return the next slice of up to *num* items (shorter at the end)."""
        if self.curr is None:
            self.curr = self.next()
            self.pos = 0
        sli = self.curr[self.pos : self.pos + num]
        self.pos += num
        if len(sli) < num:
            # buffer exhausted: drop the instance attr to fall back to the
            # class-level None, forcing a refill on the next call
            del self.curr
        return sli

    def iter(self, batch_size, fill=True, times=np.inf):
        """Generate batches of exactly *batch_size* (when fill=True).

        When fill=True, short tails are topped up from the next refill and
        *times* counts down once per refill (i.e. per pass over the data).
        NOTE(review): with fill=False, *times* is never decremented, so the
        generator is endless regardless of the value passed — confirm this
        is intended before relying on finite `times` without fill.
        """
        while times:
            acc = self(batch_size)
            while fill and len(acc) < batch_size:
                acc = np.concatenate((acc, self(batch_size - len(acc))), axis=0)
                times -= 1
            yield acc


class DSS(SKDict):
    """SKDict of numpy arrays sharing a common batch (first) dimension.

    Adds split/shuffle/batching helpers for feeding the arrays to
    TensorFlow/Keras, plus .npy directory (de)serialization.
    """

    @property
    def blen(self):
        """Common length of the first axis; asserts all arrays agree."""
        lens = list(set(val.shape[0] for val in self.values()))
        assert len(lens) == 1
        return lens[0]

    @property
    def dtype(self):
        """Common dtype of all arrays; asserts they agree."""
        dtypes = list(set(val.dtype for val in self.values()))
        assert len(dtypes) == 1
        return dtypes[0]

    @property
    def dims(self):
        """Common number of dimensions; asserts all arrays agree."""
        ndims = list(set(val.ndim for val in self.values()))
        assert len(ndims) == 1
        return ndims[0]

    @property
    def shape(self):
        """Common shape; axes whose sizes differ between arrays become None.

        Fix: previously fell through (returned None) when all arrays shared
        a single shape — now that shape is returned directly.
        """
        shapes = list(set(val.shape for val in self.values()))
        if len(shapes) == 1:
            return shapes[0]
        assert set(map(len, shapes)) == {self.dims}
        return tuple(s[0] if len(s) == 1 else None for s in map(list, map(set, zip(*shapes))))

    def fuse(self, *keys, op=np.concatenate):
        """Combine the sub-dicts under *keys* element-wise with *op*."""
        return self.zip(*(self[self.keyify(key)] for key in keys)).map(op)

    def split(self, thresh, right=False, rng=np.random):
        """Randomly partition rows into bins.

        thresh: int n -> n equal fractions; float or sequence -> cut points
        strictly inside (0, 1).  Returns a tuple with one instance per bin.
        """
        if isinstance(thresh, int):
            thresh = np.linspace(0, 1, num=thresh + 1)[1:-1]
        if isinstance(thresh, float):
            thresh = (thresh,)
        thresh = np.array(thresh)
        assert np.all((0 < thresh) & (thresh < 1))
        idx = np.digitize(rng.uniform(size=self.blen), thresh, right=right)
        return tuple(self.map(itemgetter(idx == i)) for i in range(len(thresh) + 1))

    def shuffle(self, rng=np.random):
        """Return a copy with rows permuted consistently across all arrays."""
        return self.map(itemgetter(rng.permutation(self.blen)))

    def gen_feed_dict(self, tensor2key, batch_size=1024):
        """Yield TF1-style feed-dicts mapping each tensor to an array batch."""
        for sli in self.batch_slices(batch_size):
            buf = {key: self[key][sli] for key in set(tensor2key.values())}
            yield {tensor: buf[key] for tensor, key in tensor2key.items()}

    def batch_slices(self, batch_size):
        """Yield consecutive slice objects covering [0, blen)."""
        for i in range(0, self.blen, batch_size):
            yield slice(i, min(i + batch_size, self.blen))

    def random_indices(self, batch_size, rng=np.random):
        """Endless iterator of random index batches (reshuffled per epoch)."""
        assert self.blen  # blen needs to be greater than zero
        return GetNextSlice(lambda: rng.permutation(self.blen)).iter(batch_size)

    def batch_generator(self, batch_size, rng=np.random):
        """Endless generator of randomly drawn batches."""
        for s in self.random_indices(batch_size, rng=rng):
            yield self.map(itemgetter(s))

    def dataset(self, batch_size, rng=np.random):
        """Wrap batch_generator in a tf.data.Dataset.

        Fix: from_generator requires a *callable* producing an iterator, not
        a generator instance, so the method is wrapped with partial.
        """
        return tf.data.Dataset.from_generator(
            generator=partial(self.batch_generator, batch_size, rng=rng),
            output_types=self.map(lambda x: x.dtype),
            output_shapes=self.map(lambda x: (None,) + x.shape[1:]),
        )

    def kfeed(self, x, y, w, **kwargs):
        """Build keras Model.fit kwargs from the 'train'/'valid' sub-dicts."""
        getter = itemgetter(x, y, w)
        train, valid = self["train"], self["valid"]
        return dict(
            zip(["x", "y", "sample_weight"], getter(train)), validation_data=getter(valid), **kwargs
        )

    def balanced(self, *keys, kref=np.sum, iref=np.sum):
        """Scale each sub-dict under *keys* so its total matches a reference.

        kref/iref: either a reducing callable applied across groups, or a
        key selecting the reference group/entry directly.
        """
        sums = {}
        for key in keys:
            s = self[key].map(np.sum)
            # list(...) because numpy reducers don't handle dict views
            s = iref(list(s.values())) if callable(iref) else s[iref]
            if isinstance(s, dict):
                s = np.sum(list(s.values()))
            sums[key] = s
        ref = kref(list(sums.values())) if callable(kref) else sums[kref]
        # bind s as a default to make the scale factor explicit per group
        return self.__class__({k: self[k].map(lambda x, s=s: x * (ref / s)) for k, s in sums.items()})

    @classmethod
    def from_npy(cls, dir, sep="_", **kwargs):
        """Load every '<parts joined by sep>.npy' file in *dir* as one entry."""
        return cls(
            {
                tuple(fn[:-4].split(sep)): np.load(path.join(dir, fn), **kwargs)
                for fn in listdir(dir)
                if fn.endswith(".npy")
            }
        )

    def to_npy(self, dir, sep="_", clean=True, **kwargs):
        """Save each array as '<sorted key joined by sep>.npy' under *dir*.

        Fix: the *clean* flag was previously ignored — pre-existing .npy
        files were always removed; now they are removed only when clean=True.
        """
        if clean:
            for fn in listdir(dir):
                if fn.endswith(".npy"):
                    remove(path.join(dir, fn))
        for key, value in self.items():
            np.save(path.join(dir, "%s.npy" % sep.join(sorted(key))), value, **kwargs)



class DSSDisk(DSS):
    """DSS that transparently stores assigned arrays as disk-backed memmaps."""

    def __setitem__(self, key, value):
        # Spill plain in-memory arrays to a temporary .npy file and re-open
        # them memory-mapped, keeping RAM usage low for large datasets.
        # Arrays that are already memmaps (and non-array values) pass through.
        if isinstance(value, np.ndarray) and not isinstance(value, np.memmap):
            with tempfile.NamedTemporaryFile() as tmp_file:
                np.save(tmp_file, value)
                tmp_file.flush()
                # NOTE(review): re-opening by name while the NamedTemporaryFile
                # is still open relies on POSIX semantics (fails on Windows,
                # where the file is held exclusively); the memmap's own file
                # handle keeps the data accessible after the tempfile is
                # unlinked on context exit — confirm POSIX-only deployment.
                value = np.load(tmp_file.name, mmap_mode="r+")
        return super().__setitem__(key, value)