Skip to content
Snippets Groups Projects
Commit c176df5d authored by Dennis Noll's avatar Dennis Noll
Browse files

[dssutils] black

parent 24fe587e
No related branches found
No related tags found
No related merge requests found
import numpy as np import numpy as np
def nan_to_zero(x): def nan_to_zero(x):
return x.map(lambda y: np.nan_to_num(y, nan=0, posinf=0, neginf=0)) return x.map(lambda y: np.nan_to_num(y, nan=0, posinf=0, neginf=0))
def abs(x): def abs(x):
return np.absolute(x) * np.sum(x) / np.sum(np.absolute(x)) return np.absolute(x) * np.sum(x) / np.sum(np.absolute(x))
def cut(x): def cut(x):
return np.where(x < 0, 0, x) return np.where(x < 0, 0, x)
def norm(x): def norm(x):
return x / np.mean(x) return x / np.mean(x)
def mask(x, mask): def mask(x, mask):
return x.map(lambda y: y[mask]) return x.map(lambda y: y[mask])
def splitting_function(identifier, n, test=0, valid=0): def splitting_function(identifier, n, test=0, valid=0):
assert test != valid assert test != valid
def func(x): def func(x):
test_mask = (x[identifier] % n == test).flatten() test_mask = (x[identifier] % n == test).flatten()
valid_mask = (x[identifier] % n == valid).flatten() valid_mask = (x[identifier] % n == valid).flatten()
...@@ -27,15 +34,16 @@ def splitting_function(identifier, n, test=0, valid=0): ...@@ -27,15 +34,16 @@ def splitting_function(identifier, n, test=0, valid=0):
"train": mask(x, train_mask.flatten()), "train": mask(x, train_mask.flatten()),
} }
return out return out
return func return func
def splitting(split, n_splits, identifier): def splitting(split, n_splits, identifier):
if split == -1: if split == -1:
return lambda x: { return lambda x: {
"valid": x, "valid": x,
"train": x, "train": x,
} }
else: else:
n = n_splits n = n_splits
assert split < n assert split < n
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment