From 224c28f635a8d4182f6c534367ecfdd09abce43d Mon Sep 17 00:00:00 2001
From: Dennis Noll <dennis.noll@rwth-aachen.de>
Date: Fri, 26 Feb 2021 15:14:57 +0100
Subject: [PATCH] [dssutils] init dssutils functions which are helpful for ML
 trainings

---
 dssutils.py | 104 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 104 insertions(+)
 create mode 100644 dssutils.py

diff --git a/dssutils.py b/dssutils.py
new file mode 100644
index 0000000..7432697
--- /dev/null
+++ b/dssutils.py
@@ -0,0 +1,104 @@
+"""Utility functions which are helpful for ML trainings."""
+
+import numpy as np
+
+
+def nan_to_zero(x):
+    """Replace NaN, +inf and -inf by 0 in everything ``x.map`` visits.
+
+    ``x`` must provide a ``map(func)`` method; its type is not defined in
+    this module.
+    """
+    return x.map(lambda y: np.nan_to_num(y, nan=0, posinf=0, neginf=0))
+
+
+def abs(x):
+    """Return ``|x|`` rescaled so that its sum equals ``np.sum(x)``.
+
+    NOTE: the name shadows the builtin ``abs`` within this module; it is
+    kept because callers may already rely on it. There is no guard against
+    ``np.sum(np.absolute(x)) == 0`` (division by zero).
+    """
+    return np.absolute(x) * np.sum(x) / np.sum(np.absolute(x))
+
+
+def cut(x):
+    """Replace all negative entries of ``x`` by 0 and keep the others."""
+    return np.where(x < 0, 0, x)
+
+
+def norm(x):
+    """Divide ``x`` by its mean (no guard against a mean of 0)."""
+    return x / np.mean(x)
+
+
+def mask(x, mask):
+    """Index everything ``x.map`` visits with ``mask``.
+
+    ``x`` must provide a ``map(func)`` method. The parameter ``mask`` shadows
+    this function inside its body; the name is kept for keyword callers.
+    """
+    return x.map(lambda y: y[mask])
+
+
+def splitting_function(identifier, n, test=0, valid=0):
+    """Build a function splitting its input into test/valid/train parts.
+
+    An entry goes to "test" if ``x[identifier] % n == test``, to "valid" if
+    ``x[identifier] % n == valid`` and to "train" otherwise.
+
+    Args:
+        identifier: key of the splitting variable in ``x``.
+        n: modulus applied to the splitting variable.
+        test: remainder selecting the "test" part.
+        valid: remainder selecting the "valid" part; must differ from test.
+
+    Returns:
+        A function mapping ``x`` to a dict with the keys "test", "valid" and
+        "train", each built with ``mask``.
+
+    Raises:
+        AssertionError: if ``test == valid``.
+    """
+    assert test != valid, "test and valid must differ"
+
+    def func(x):
+        remainder = x[identifier] % n
+        test_mask = (remainder == test).flatten()
+        valid_mask = (remainder == valid).flatten()
+        # Logical OR of the boolean masks; whatever is left is "train".
+        train_mask = ~(test_mask | valid_mask)
+        return {
+            "test": mask(x, test_mask),
+            "valid": mask(x, valid_mask),
+            "train": mask(x, train_mask),
+        }
+
+    return func
+
+
+def splitting(split, n_splits, identifier):
+    """Return the splitting function for part ``split`` of ``n_splits``.
+
+    ``split == -1`` disables the splitting: the returned function hands its
+    input back as both "valid" and "train" (there is no "test" key).
+    Otherwise remainder ``split`` selects the "test" part and the cyclically
+    next remainder selects the "valid" part.
+
+    Raises:
+        AssertionError: if ``split`` is neither -1 nor in ``[0, n_splits)``,
+            or if the test and valid remainders coincide (``n_splits == 1``).
+    """
+    if split == -1:
+        return lambda x: {
+            "valid": x,
+            "train": x,
+        }
+
+    # For NumPy arrays ``%`` with a positive modulus is never negative, so a
+    # negative ``split`` would silently select an empty "test" part.
+    assert 0 <= split < n_splits, "split must be -1 or in [0, n_splits)"
+
+    test = split
+    valid = (test + 1) % n_splits
+    return splitting_function(identifier, n_splits, test=test, valid=valid)
-- 
GitLab