diff --git a/data.py b/data.py index ed5b3fc6505c49c28e67383bd1bffe0af88c8915..51873d68c8cb161f874b90f3c5259e7adbd65373 100644 --- a/data.py +++ b/data.py @@ -1,6 +1,7 @@ import numpy as np import tensorflow as tf from operator import itemgetter +from os import listdir, path class SKDict(dict): @@ -210,3 +211,11 @@ class DSS(SKDict): k: self[k].map(lambda x: x * (ref / s)) for k, s in sums.items() }) + + @classmethod + def from_npy(cls, dir, sep="_", **kwargs): + return cls({ + tuple(fn[:-4].split(sep)): np.load(path.join(dir, fn), **kwargs) + for fn in listdir(dir) + if fn.endswith(".npy") + })