diff --git a/data.py b/data.py index 6044750490ab965826e7dfa2a7646858bbff4959..24f8222d6b9d9ea5c7f0b232853c7e17a1d6a508 100644 --- a/data.py +++ b/data.py @@ -67,6 +67,10 @@ class SKDict(dict): keys.update(*(inst.keys() for inst in insts)) return cls({key: tuple(inst.get(key) for inst in insts) for key in keys}) + @classmethod + def concatenate(cls, *insts): + return cls.zip(*insts).map(lambda x: np.concatenate(x)) + def only(self, *keys): return self.__class__({key: self[key] for key in keys})