diff --git a/dssutils.py b/dssutils.py new file mode 100644 index 0000000000000000000000000000000000000000..74326971ff4b799f522ef7b54cdc9c15c9feae86 --- /dev/null +++ b/dssutils.py @@ -0,0 +1,46 @@ +import numpy as np + +def nan_to_zero(x): + return x.map(lambda y: np.nan_to_num(y, nan=0, posinf=0, neginf=0)) + +def abs(x): + return np.absolute(x) * np.sum(x) / np.sum(np.absolute(x)) + +def cut(x): + return np.where(x < 0, 0, x) + +def norm(x): + return x / np.mean(x) + +def mask(x, mask): + return x.map(lambda y: y[mask]) + +def splitting_function(identifier, n, test=0, valid=0): + assert test != valid + def func(x): + test_mask = (x[identifier] % n == test).flatten() + valid_mask = (x[identifier] % n == valid).flatten() + train_mask = ~(test_mask + valid_mask) + out = { + "test": mask(x, test_mask.flatten()), + "valid": mask(x, valid_mask.flatten()), + "train": mask(x, train_mask.flatten()), + } + return out + return func + + +def splitting(split, n_splits, identifier): + if split == -1: + return lambda x: { + "valid": x, + "train": x, + } + else: + n = n_splits + assert split < n + + test = split + valid = (test + 1) % n + func = splitting_function(identifier, n, test=test, valid=valid) + return func