From 224c28f635a8d4182f6c534367ecfdd09abce43d Mon Sep 17 00:00:00 2001 From: Dennis Noll <dennis.noll@rwth-aachen.de> Date: Fri, 26 Feb 2021 15:14:57 +0100 Subject: [PATCH] [dssutils] init dssutils functions which are helpful for ML trainings --- dssutils.py | 46 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 dssutils.py diff --git a/dssutils.py b/dssutils.py new file mode 100644 index 0000000..7432697 --- /dev/null +++ b/dssutils.py @@ -0,0 +1,46 @@ +import numpy as np + +def nan_to_zero(x): + return x.map(lambda y: np.nan_to_num(y, nan=0, posinf=0, neginf=0)) + +def abs(x): + return np.absolute(x) * np.sum(x) / np.sum(np.absolute(x)) + +def cut(x): + return np.where(x < 0, 0, x) + +def norm(x): + return x / np.mean(x) + +def mask(x, mask): + return x.map(lambda y: y[mask]) + +def splitting_function(identifier, n, test=0, valid=0): + assert test != valid + def func(x): + test_mask = (x[identifier] % n == test).flatten() + valid_mask = (x[identifier] % n == valid).flatten() + train_mask = ~(test_mask + valid_mask) + out = { + "test": mask(x, test_mask.flatten()), + "valid": mask(x, valid_mask.flatten()), + "train": mask(x, train_mask.flatten()), + } + return out + return func + + +def splitting(split, n_splits, identifier): + if split == -1: + return lambda x: { + "valid": x, + "train": x, + } + else: + n = n_splits + assert split < n + + test = split + valid = (test + 1) % n + func = splitting_function(identifier, n, test=test, valid=valid) + return func -- GitLab