From 2dad3d89f804289e20a7e2b898da9190eb2785df Mon Sep 17 00:00:00 2001
From: Dennis Noll <dennis.noll@rwth-aachen.de>
Date: Fri, 26 Jun 2020 15:46:54 +0200
Subject: [PATCH] [data] DSS: dims - fixed typo

---
 data.py  | 73 ++++++++++++++++++--------------------------------------
 keras.py | 19 +++++++++++++--
 2 files changed, 40 insertions(+), 52 deletions(-)

diff --git a/data.py b/data.py
index 26a81e0..eab779f 100644
--- a/data.py
+++ b/data.py
@@ -8,7 +8,7 @@ class SKDict(dict):
     @staticmethod
     def keyify(keyish):
         if not isinstance(keyish, (tuple, list, set, frozenset)):
-            keyish = keyish,
+            keyish = (keyish,)
         keyish = frozenset(keyish)
         assert not any(isinstance(key, set) for key in keyish)
         return keyish
@@ -19,7 +19,7 @@ class SKDict(dict):
 
     def update(self, *args, **kwargs):
         # assert 0 <= len(args) <= 1
-        args += kwargs,
+        args += (kwargs,)
         for arg in args:
             for k, v in arg.items():
                 self[k] = v
@@ -34,11 +34,7 @@ class SKDict(dict):
         key = self.keyify(key)
         if key in self:
             return super(SKDict, self).__getitem__(key)
-        ret = self.__class__({
-            k - key: v
-            for k, v in self.items()
-            if key <= k
-        })
+        ret = self.__class__({k - key: v for k, v in self.items() if key <= k})
         if not ret:
             raise KeyError(key)
         return ret
@@ -64,10 +60,7 @@ class SKDict(dict):
         assert all(isinstance(inst, cls) for inst in insts)
         keys = set()
         keys.update(*(inst.keys() for inst in insts))
-        return cls({
-            key: tuple(inst.get(key) for inst in insts)
-            for key in keys
-        })
+        return cls({key: tuple(inst.get(key) for inst in insts) for key in keys})
 
     def only(self, *keys):
         return self.__class__({key: self[key] for key in keys})
@@ -88,7 +81,7 @@ class SKDict(dict):
             assert len(keys) == 1  # bad depth
             return list(keys)[0]
         elif ads == {False}:
-            return (),
+            return ((),)
         else:
             raise RuntimeError("bad depth")
 
@@ -97,10 +90,7 @@ class SKDict(dict):
 
     @property
     def pretty(self):
-        return {
-            "/".join(sorted(map(str, k))): v
-            for k, v in self.items()
-        }
+        return {"/".join(sorted(map(str, k))): v for k, v in self.items()}
 
 
 class GetNextSlice(object):
@@ -113,7 +103,7 @@ class GetNextSlice(object):
         if self.curr is None:
             self.curr = self.next()
             self.pos = 0
-        sli = self.curr[self.pos:self.pos + num]
+        sli = self.curr[self.pos : self.pos + num]
         self.pos += num
         if len(sli) < num:
             del self.curr
@@ -143,7 +133,7 @@ class DSS(SKDict):
 
     @property
     def dims(self):
-        dimss = list(set(val.dims for val in self.values()))
+        dimss = list(set(val.ndim for val in self.values()))
         assert len(dimss) == 1
         return dimss[0]
 
@@ -152,45 +142,31 @@ class DSS(SKDict):
         shapes = list(set(val.shape for val in self.values()))
         if len(shapes) > 1:
             assert set(map(len, shapes)) == {self.dims}
-            return tuple(
-                s[0] if len(s) == 1 else None
-                for s in map(list, map(set, zip(*shapes)))
-            )
+            return tuple(s[0] if len(s) == 1 else None for s in map(list, map(set, zip(*shapes))))
         return shapes[0]
 
     def fuse(self, *keys, **kwargs):
         op = kwargs.pop("op", np.concatenate)
         assert not kwargs
-        return self.zip(*(
-            self[self.keyify(key)] for key in keys
-        )).map(op)
+        return self.zip(*(self[self.keyify(key)] for key in keys)).map(op)
 
     def split(self, thresh, right=False, rng=np.random):
         if isinstance(thresh, int):
             thresh = np.linspace(0, 1, num=thresh + 1)[1:-1]
         if isinstance(thresh, float):
-            thresh = thresh,
+            thresh = (thresh,)
         thresh = np.array(thresh)
         assert np.all((0 < thresh) & (thresh < 1))
         idx = np.digitize(rng.uniform(size=self.blen), thresh, right=right)
-        return tuple(
-            self.map(itemgetter(idx == i))
-            for i in range(len(thresh) + 1)
-        )
+        return tuple(self.map(itemgetter(idx == i)) for i in range(len(thresh) + 1))
 
     def shuffle(self, rng=np.random):
         return self.map(itemgetter(rng.permutation(self.blen)))
 
     def gen_feed_dict(self, tensor2key, batch_size=1024):
         for sli in self.batch_slices(batch_size):
-            buf = {
-                key: self[key][sli]
-                for key in set(tensor2key.values())
-            }
-            yield {
-                tensor: buf[key]
-                for tensor, key in tensor2key.items()
-            }
+            buf = {key: self[key][sli] for key in set(tensor2key.values())}
+            yield {tensor: buf[key] for tensor, key in tensor2key.items()}
 
     def batch_slices(self, batch_size):
         for i in range(0, self.blen, batch_size):
@@ -214,9 +190,7 @@ class DSS(SKDict):
         getter = itemgetter(x, y, w)
         train, valid = self["train"], self["valid"]
         return dict(
-            zip(["x", "y", "sample_weight"], getter(train)),
-            validation_data=getter(valid),
-            **kwargs
+            zip(["x", "y", "sample_weight"], getter(train)), validation_data=getter(valid), **kwargs
         )
 
     def balanced(self, *keys, **kwargs):
@@ -230,18 +204,17 @@ class DSS(SKDict):
                 s = np.sum(s.values())
             sums[key] = s
         ref = kref(sums.values()) if callable(kref) else sums[kref]
-        return self.__class__({
-            k: self[k].map(lambda x: x * (ref / s))
-            for k, s in sums.items()
-        })
+        return self.__class__({k: self[k].map(lambda x: x * (ref / s)) for k, s in sums.items()})
 
     @classmethod
     def from_npy(cls, dir, sep="_", **kwargs):
-        return cls({
-            tuple(fn[:-4].split(sep)): np.load(path.join(dir, fn), **kwargs)
-            for fn in listdir(dir)
-            if fn.endswith(".npy")
-        })
+        return cls(
+            {
+                tuple(fn[:-4].split(sep)): np.load(path.join(dir, fn), **kwargs)
+                for fn in listdir(dir)
+                if fn.endswith(".npy")
+            }
+        )
 
     def to_npy(self, dir, sep="_", **kwargs):
         for key, value in self.items():
diff --git a/keras.py b/keras.py
index 5500168..d328c26 100644
--- a/keras.py
+++ b/keras.py
@@ -136,6 +136,21 @@ class KFeed(object):
         assert "auto_steps" not in kwargs
         return self.gensteps(*args, **kwargs)[0]
 
+    def getShapes(self, src):
+        return tuple(
+            tuple(src[k].shape for k in g)
+            for g in self.xyw
+        )
+
+    def mkInputs(self, src, **kwargs):
+        return tuple(
+            tf.keras.Input(shape=ref.shape[1:], dtype=ref.dtype, **kwargs)
+            for ref in (
+                src[x]
+                for x in self.xyw[0]
+            )
+        )
+
 
 def Normal(ref, indices=None, name=None, **kwargs):
     """
@@ -168,8 +183,8 @@ def Onehot(index, n, name=None):
                           x[...,(index+1):],
                          ), axis=-1)
     return tf.keras.layers.Lambda(to_onehot, name=name)
-    
-    
+
+
 class Moment(tf.keras.metrics.Mean):
 
     def __init__(self, order, label=False, **kwargs):
-- 
GitLab