Skip to content
Snippets Groups Projects
Commit ab02c03d authored by Dennis Noll's avatar Dennis Noll
Browse files

[dssutils] init prepare/apply train valid test

used to split dss object in train valid and test according to event number
first map prepare:
  1. duplicates data ([012]->[01201])
  2. saves correct chunk starts and stops:
    split 0: test=[0], valid=[1], train=[2]
    split 1: test=[1], valid=[2], train=[0]
    split 2: test=[2], valid=[0], train=[1]
then map apply
parent c176df5d
No related branches found
No related tags found
No related merge requests found
import numpy as np import numpy as np
from .data import DSS
def nan_to_zero(x): def nan_to_zero(x):
return x.map(lambda y: np.nan_to_num(y, nan=0, posinf=0, neginf=0)) return x.map(lambda y: np.nan_to_num(y, nan=0, posinf=0, neginf=0))
...@@ -52,3 +54,44 @@ def splitting(split, n_splits, identifier): ...@@ -52,3 +54,44 @@ def splitting(split, n_splits, identifier):
valid = (test + 1) % n valid = (test + 1) % n
func = splitting_function(identifier, n, test=test, valid=valid) func = splitting_function(identifier, n, test=test, valid=valid)
return func return func
def get_prepare_train_valid_test(n_splits, identifier):
def prepare_train_valid_test(x):
masks = ((x[identifier] % n_splits == i).flatten() for i in range(n_splits))
chunks = [mask(x, _mask) for _mask in masks]
chunks = chunks + chunks[:-1]
lenghts = [chunk.blen for chunk in chunks]
stops = np.cumsum(lenghts).tolist()
starts = [0] + stops[:-1]
train, valid, test = [], [], []
for i in range(n_splits):
test.append((starts[i], stops[i]))
valid.append((starts[i + 1], stops[i + 1]))
train.append((starts[i + 2], stops[i + (n_splits - 1)]))
data = DSS.concatenate(*chunks)
data["test"] = test
data["valid"] = valid
data["train"] = train
return data
return prepare_train_valid_test
def get_apply_train_valid_test(split):
def apply_train_valid_test(x):
train_start, train_stop = x["train"][split][0], x["train"][split][1]
valid_start, valid_stop = x["valid"][split][0], x["valid"][split][1]
test_start, test_stop = x["test"][split][0], x["test"][split][1]
return {
"train": x.map(lambda y: y[train_start:train_stop]),
"valid": x.map(lambda y: y[valid_start:valid_stop]),
"test": x.map(lambda y: y[test_start:test_stop]),
}
return apply_train_valid_test
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment