From 2dad3d89f804289e20a7e2b898da9190eb2785df Mon Sep 17 00:00:00 2001 From: Dennis Noll <dennis.noll@rwth-aachen.de> Date: Fri, 26 Jun 2020 15:46:54 +0200 Subject: [PATCH] [data] DSS: dims - fixed typo --- data.py | 73 ++++++++++++++++++-------------------------------------- keras.py | 19 +++++++++++++-- 2 files changed, 40 insertions(+), 52 deletions(-) diff --git a/data.py b/data.py index 26a81e0..eab779f 100644 --- a/data.py +++ b/data.py @@ -8,7 +8,7 @@ class SKDict(dict): @staticmethod def keyify(keyish): if not isinstance(keyish, (tuple, list, set, frozenset)): - keyish = keyish, + keyish = (keyish,) keyish = frozenset(keyish) assert not any(isinstance(key, set) for key in keyish) return keyish @@ -19,7 +19,7 @@ class SKDict(dict): def update(self, *args, **kwargs): # assert 0 <= len(args) <= 1 - args += kwargs, + args += (kwargs,) for arg in args: for k, v in arg.items(): self[k] = v @@ -34,11 +34,7 @@ class SKDict(dict): key = self.keyify(key) if key in self: return super(SKDict, self).__getitem__(key) - ret = self.__class__({ - k - key: v - for k, v in self.items() - if key <= k - }) + ret = self.__class__({k - key: v for k, v in self.items() if key <= k}) if not ret: raise KeyError(key) return ret @@ -64,10 +60,7 @@ class SKDict(dict): assert all(isinstance(inst, cls) for inst in insts) keys = set() keys.update(*(inst.keys() for inst in insts)) - return cls({ - key: tuple(inst.get(key) for inst in insts) - for key in keys - }) + return cls({key: tuple(inst.get(key) for inst in insts) for key in keys}) def only(self, *keys): return self.__class__({key: self[key] for key in keys}) @@ -88,7 +81,7 @@ class SKDict(dict): assert len(keys) == 1 # bad depth return list(keys)[0] elif ads == {False}: - return (), + return ((),) else: raise RuntimeError("bad depth") @@ -97,10 +90,7 @@ class SKDict(dict): @property def pretty(self): - return { - "/".join(sorted(map(str, k))): v - for k, v in self.items() - } + return {"/".join(sorted(map(str, k))): v for k, v in self.items()} class GetNextSlice(object): @@ -113,7 +103,7 @@ class GetNextSlice(object): if self.curr is None: self.curr = self.next() self.pos = 0 - sli = self.curr[self.pos:self.pos + num] + sli = self.curr[self.pos : self.pos + num] self.pos += num if len(sli) < num: del self.curr @@ -143,7 +133,7 @@ class DSS(SKDict): @property def dims(self): - dimss = list(set(val.dims for val in self.values())) + dimss = list(set(val.ndim for val in self.values())) assert len(dimss) == 1 return dimss[0] @@ -152,45 +142,31 @@ class DSS(SKDict): shapes = list(set(val.shape for val in self.values())) if len(shapes) > 1: assert set(map(len, shapes)) == {self.dims} - return tuple( - s[0] if len(s) == 1 else None - for s in map(list, map(set, zip(*shapes))) - ) + return tuple(s[0] if len(s) == 1 else None for s in map(list, map(set, zip(*shapes)))) return shapes[0] def fuse(self, *keys, **kwargs): op = kwargs.pop("op", np.concatenate) assert not kwargs - return self.zip(*( - self[self.keyify(key)] for key in keys - )).map(op) + return self.zip(*(self[self.keyify(key)] for key in keys)).map(op) def split(self, thresh, right=False, rng=np.random): if isinstance(thresh, int): thresh = np.linspace(0, 1, num=thresh + 1)[1:-1] if isinstance(thresh, float): - thresh = thresh, + thresh = (thresh,) thresh = np.array(thresh) assert np.all((0 < thresh) & (thresh < 1)) idx = np.digitize(rng.uniform(size=self.blen), thresh, right=right) - return tuple( - self.map(itemgetter(idx == i)) - for i in range(len(thresh) + 1) - ) + return tuple(self.map(itemgetter(idx == i)) for i in range(len(thresh) + 1)) def shuffle(self, rng=np.random): return self.map(itemgetter(rng.permutation(self.blen))) def gen_feed_dict(self, tensor2key, batch_size=1024): for sli in self.batch_slices(batch_size): - buf = { - key: self[key][sli] - for key in set(tensor2key.values()) - } - yield { - tensor: buf[key] - for tensor, key in tensor2key.items() - } + buf = {key: self[key][sli] for key in set(tensor2key.values())} + yield {tensor: buf[key] for tensor, key in tensor2key.items()} def batch_slices(self, batch_size): for i in range(0, self.blen, batch_size): @@ -214,9 +190,7 @@ class DSS(SKDict): getter = itemgetter(x, y, w) train, valid = self["train"], self["valid"] return dict( - zip(["x", "y", "sample_weight"], getter(train)), - validation_data=getter(valid), - **kwargs + zip(["x", "y", "sample_weight"], getter(train)), validation_data=getter(valid), **kwargs ) def balanced(self, *keys, **kwargs): @@ -230,18 +204,17 @@ class DSS(SKDict): s = np.sum(s.values()) sums[key] = s ref = kref(sums.values()) if callable(kref) else sums[kref] - return self.__class__({ - k: self[k].map(lambda x: x * (ref / s)) - for k, s in sums.items() - }) + return self.__class__({k: self[k].map(lambda x: x * (ref / s)) for k, s in sums.items()}) @classmethod def from_npy(cls, dir, sep="_", **kwargs): - return cls({ - tuple(fn[:-4].split(sep)): np.load(path.join(dir, fn), **kwargs) - for fn in listdir(dir) - if fn.endswith(".npy") - }) + return cls( + { + tuple(fn[:-4].split(sep)): np.load(path.join(dir, fn), **kwargs) + for fn in listdir(dir) + if fn.endswith(".npy") + } + ) def to_npy(self, dir, sep="_", **kwargs): for key, value in self.items(): diff --git a/keras.py b/keras.py index 5500168..d328c26 100644 --- a/keras.py +++ b/keras.py @@ -136,6 +136,21 @@ class KFeed(object): assert "auto_steps" not in kwargs return self.gensteps(*args, **kwargs)[0] + def getShapes(self, src): + return tuple( + tuple(src[k].shape for k in g) + for g in self.xyw + ) + + def mkInputs(self, src, **kwargs): + return tuple( + tf.keras.Input(shape=ref.shape[1:], dtype=ref.dtype, **kwargs) + for ref in ( + src[x] + for x in self.xyw[0] + ) + ) + def Normal(ref, indices=None, name=None, **kwargs): """ @@ -168,8 +183,8 @@ def Onehot(index, n, name=None): x[...,(index+1):], ), axis=-1) return tf.keras.layers.Lambda(to_onehot, name=name) - - + + class Moment(tf.keras.metrics.Mean): def __init__(self, order, label=False, **kwargs): -- GitLab