import numpy as np
import tensorflow as tf
from operator import itemgetter


class SKDict(dict):
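    """Dict keyed by frozensets of hashable sub-keys.

    Keys are normalised via ``keyify``; looking up a partial key returns a
    new SKDict of all matching entries with the queried sub-keys stripped:

        d = SKDict({("train", "x"): 1, ("valid", "x"): 2})
        d["train", "x"]   # -> 1
        d["train"]        # -> SKDict({frozenset({"x"}): 1})
    """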
    @staticmethod
    def keyify(keyish):
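        """Normalise a key-ish value (scalar, tuple, list, set) into a frozenset of sub-keys."""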
        if not isinstance(keyish, (tuple, list, set, frozenset)):
            keyish = keyish,
        keyish = frozenset(keyish)
        assert not any(isinstance(key, set) for key in keyish)
        return keyish

    def __init__(self, *args, **kwargs):
        super(SKDict, self).__init__()
        self.update(dict(*args, **kwargs))

    def update(self, *args, **kwargs):
        # assert 0 <= len(args) <= 1
        args += kwargs,
        for arg in args:
            for k, v in arg.items():
                self[k] = v

    def get(self, key, *args, **kwargs):
        return super(SKDict, self).get(self.keyify(key), *args, **kwargs)

    def pop(self, key, *args, **kwargs):
        return super(SKDict, self).pop(self.keyify(key), *args, **kwargs)

    def __getitem__(self, key):
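        """Exact keys return their value; partial keys return a sub-SKDict of every matching entry with ``key`` removed from its key."""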
        key = self.keyify(key)
        if key in self:
            return super(SKDict, self).__getitem__(key)
        ret = self.__class__({
            k - key: v
            for k, v in self.items()
            if key <= k
        })
        if not ret:
            raise KeyError(key)
        return ret

    def __setitem__(self, key, value):
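        """Assigning a dict merges its items in, combining each of its keys with ``key``."""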
        key = self.keyify(key)
        if isinstance(value, dict):
            for k, v in value.items():
                self[key | self.keyify(k)] = v
        else:
            super(SKDict, self).__setitem__(key, value)

    def copy(self):
        return self.__class__(self)

    def map(self, func, groups=None):
        if groups is None:
            groups = self.keys()
        return self.__class__({g: func(self[g]) for g in map(self.keyify, groups)})

    @classmethod
    def zip(cls, *insts):
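        """Combine instances into one dict whose values are tuples with one slot per instance (None where a key is missing)."""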
        assert all(isinstance(inst, cls) for inst in insts)
        keys = set()
        keys.update(*(inst.keys() for inst in insts))
        return cls({
            key: tuple(inst.get(key) for inst in insts)
            for key in keys
        })

    def only(self, *keys):
        return self.__class__({key: self[key] for key in keys})

    def dekey(self, *key):
        for k in self[key].keys():
            yield self[k]

    def items(self, *key):
        for k in (self[key] if key else self).keys():
            yield k, self[k]

    def mkeys(self, keys):
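        """Return the sub-keys common to the values under ``keys``, or ``((),)`` for plain (leaf) values; mixed depths raise."""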
        data = [self[key] for key in keys]
        ads = set(isinstance(d, self.__class__) for d in data)
        if ads == {True}:
            keys = set(frozenset(d.keys()) for d in data)
            assert len(keys) == 1  # bad depth
            return list(keys)[0]
        elif ads == {False}:
            return (),
        else:
            raise RuntimeError("bad depth")

    def skeys(self):
        return map(frozenset, sorted(map(sorted, self.keys())))

    @property
    def pretty(self):
        return {
            "/".join(sorted(map(str, k))): v
            for k, v in self.items()
        }


class GetNextSlice(object):
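    """Serve fixed-size slices from consecutive sequences produced by the ``next`` callable (e.g. fresh index permutations)."""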
    curr = None

    def __init__(self, next):
        self.next = next

    def __call__(self, num):
        if self.curr is None:
            self.curr = self.next()
            self.pos = 0
        sli = self.curr[self.pos:self.pos + num]
        self.pos += num
        if len(sli) < num:
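            # exhausted: drop the instance attribute so the class-level
            # ``curr = None`` is seen again and a fresh sequence is fetched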
            del self.curr
        return sli

    def iter(self, batch_size, fill=True, times=np.inf):
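        """Yield batches of ``batch_size`` elements.

        With ``fill=True`` a short slice is topped up from the next
        sequence; ``times`` appears to bound how often a fresh sequence
        is started before iteration stops (infinite by default).
        """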
        while times:
            acc = self(batch_size)
            while fill and len(acc) < batch_size:
                acc = np.concatenate((acc, self(batch_size - len(acc))), axis=0)
                times -= 1
            yield acc


class DSS(SKDict):
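    """SKDict of numpy arrays sharing a common first (batch) dimension, with helpers for splitting, shuffling, batching and balancing."""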
    @property
    def blen(self):
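        # common length of the first (batch) axis of all stored arrays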
        lens = list(set(val.shape[0] for val in self.values()))
        assert len(lens) == 1
        return lens[0]

    def fuse(self, *keys, **kwargs):
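        """For each remaining sub-key, combine the values found under ``keys`` with ``op`` (np.concatenate by default)."""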
        op = kwargs.pop("op", np.concatenate)
        assert not kwargs
        return self.zip(*(
            self[self.keyify(key)] for key in keys
        )).map(op)

    def split(self, thresh, right=False, rng=np.random):
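        """Randomly split along the batch axis.

        ``thresh`` may be an int (number of roughly equal parts), a single
        float or a sequence of floats in (0, 1) acting as cumulative
        boundaries; returns a tuple with one more DSS than there are
        boundaries.
        """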
        if isinstance(thresh, int):
            thresh = np.linspace(0, 1, num=thresh + 1)[1:-1]
        if isinstance(thresh, float):
            thresh = thresh,
        thresh = np.array(thresh)
        assert np.all((0 < thresh) & (thresh < 1))
        idx = np.digitize(rng.uniform(size=self.blen), thresh, right=right)
        return tuple(
            self.map(itemgetter(idx == i))
            for i in range(len(thresh) + 1)
        )

    def shuffle(self, rng=np.random):
        return self.map(itemgetter(rng.permutation(self.blen)))

    def gen_feed_dict(self, tensor2key, batch_size=1024):
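        """Yield feed dicts mapping each tensor in ``tensor2key`` to consecutive batches of the array stored under its key."""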
        for sli in self.batch_slices(batch_size):
            buf = {
                key: self[key][sli]
                for key in set(tensor2key.values())
            }
            yield {
                tensor: buf[key]
                for tensor, key in tensor2key.items()
            }

    def batch_slices(self, batch_size):
        for i in range(0, self.blen, batch_size):
            yield slice(i, min(i + batch_size, self.blen))

    def random_indices(self, batch_size, rng=np.random):
        return GetNextSlice(lambda: rng.permutation(self.blen)).iter(batch_size)

    def batch_generator(self, batch_size, rng=np.random):
        for s in self.random_indices(batch_size, rng=rng):
            yield self.map(itemgetter(s))

    def dataset(self, batch_size, rng=np.random):
        return tf.data.Dataset.from_generator(
            # from_generator expects a callable that returns an iterator,
            # not a generator object, hence the lambda
            generator=lambda: self.batch_generator(batch_size, rng=rng),
            output_types=self.map(lambda x: x.dtype),
            output_shapes=self.map(lambda x: (None,) + x.shape[1:]),
        )

    def kfeed(self, x, y, w, **kwargs):
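        """Build keyword arguments for a Keras-style ``fit`` call: x, y and sample_weight from the "train" split, validation_data from "valid"."""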
        getter = itemgetter(x, y, w)
        train, valid = self["train"], self["valid"]
        return dict(
            zip(["x", "y", "sample_weight"], getter(train)),
            validation_data=getter(valid),
            **kwargs
        )

    def balanced(self, *keys, **kwargs):
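        """Rescale the entries under ``keys`` so their totals match a common reference.

        ``iref`` reduces each key's per-sub-key sums to a single number and
        ``kref`` derives the reference from the per-key totals; either may
        be a callable or a sub-key used for indexing.
        """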
        kref = kwargs.pop("kref", np.sum)
        iref = kwargs.pop("iref", np.sum)
        sums = {}
        for key in keys:
            s = self[key].map(np.sum)
            # wrap dict views in list() so np.sum reduces the values
            # correctly under Python 3
            s = iref(list(s.values())) if callable(iref) else s[iref]
            if isinstance(s, dict):
                s = np.sum(list(s.values()))
            sums[key] = s
        ref = kref(list(sums.values())) if callable(kref) else sums[kref]
        return self.__class__({
            k: self[k].map(lambda x: x * (ref / s))
            for k, s in sums.items()
        })