# -*- coding: utf-8 -*-

import numpy as np


def set_thresh(values, thresh=1e-5):
    """Return a copy of ``values`` with every entry below ``thresh`` set to 0.

    Note that this zeroes *all* entries smaller than ``thresh``, including
    negative ones, not only small magnitudes. The input is left untouched.

    Args:
        values: array to threshold (must support ``.copy()`` and boolean
            mask assignment).
        thresh: entries strictly less than this become 0.

    Returns:
        Thresholded copy of ``values``.
    """
    clipped = values.copy()
    below = clipped < thresh
    clipped[below] = 0
    return clipped


def one_hot(a, n=None):
    """One-hot encode an array of non-negative integer class labels.

    The labels are flattened (C order) first, so inputs of any shape (or
    plain Python sequences) are accepted; row ``i`` of the result encodes
    the ``i``-th flattened label.

    Args:
        a: integer labels (array-like of any shape).
        n: number of classes (columns). Defaults to ``max(a) + 1``, or 0
            when ``a`` is empty.

    Returns:
        Float array of shape ``(a.size, n)`` with a single 1 per row.
    """
    flat = np.asarray(a).ravel()
    if n is None:
        # max() raises on an empty array, so fall back to zero classes.
        n = int(flat.max()) + 1 if flat.size else 0
    o_h = np.zeros((flat.size, n))
    o_h[np.arange(flat.size), flat] = 1
    return o_h


def array_to_hist(values, bins):
    """Package precomputed 2-D counts in ``np.histogram2d``'s return format.

    The bin edges come from running ``np.histogram2d`` on a single dummy
    sample at (0, 0) with the given ``bins``. With explicit edge arrays this
    simply echoes them back. With an integer ``bins`` the edges are derived
    from the dummy sample alone — presumably callers pass explicit edges;
    TODO confirm.

    Args:
        values: the counts to use as the histogram (not validated against
            the number of bins).
        bins: any ``bins`` argument accepted by ``np.histogram2d``.

    Returns:
        Tuple ``(values, xedges, yedges)``; ``values`` is returned as-is
        (same object, not copied).
    """
    _, xedges, yedges = np.histogram2d([0.0], [0.0], bins=bins)
    return (values, xedges, yedges)


def swap_rows(a, swp, roll=1, swp_to=None):
    """Permute entries along the last axis of ``a`` in place.

    The positions ``swp`` (indices into the last axis) receive the values
    currently at ``swp_to``. When ``swp_to`` is not given, the targets are
    ``swp`` rolled by ``roll``, i.e. the selected positions are rotated
    among themselves.

    Args:
        a: array to modify in place.
        swp: indices along the last axis to overwrite.
        roll: shift applied to ``swp`` when ``swp_to`` is None.
        swp_to: explicit source indices for each entry of ``swp``.

    Returns:
        ``a`` itself, after modification.
    """
    # Compare against None explicitly: truth-testing a NumPy index array
    # raises, and a falsy-but-valid index (e.g. 0) must not trigger rolling.
    if swp_to is None:
        swp_to = np.roll(swp, roll)
    # Fancy indexing on the right-hand side makes a copy, so overlapping
    # source/target positions are read before any are written.
    a[..., swp] = a[..., swp_to]
    return a


def create_swaps(swap_length, circle, length):
    """Build index pairs ``(i, i + swap_length)`` for a sequence of ``length``.

    Without ``circle`` the first index starts at 0. With ``circle`` it starts
    at ``-swap_length``, so the leading pairs use negative (wrap-around) first
    indices.

    Returns:
        Integer array of shape ``(k, 2)``; column 0 holds the first index and
        column 1 the partner ``swap_length`` positions later.
    """
    first = -swap_length if circle else 0
    lefts = np.arange(first, length - swap_length)
    rights = np.arange(first + swap_length, length)
    return np.stack((lefts, rights), axis=1)


def _axis_clip(ref, axis):
    assert -ref <= axis < ref
    axis = ref + axis if axis < 0 else axis
    return axis, ref - axis - 1


def xsel(source, saxis, indices, iaxis=-1):
    """Gather from ``source`` along ``saxis`` using ``indices``.

    ``iaxis`` is the axis of ``indices`` that runs along ``saxis``; the index
    tuple is built by ``xsel_mask``.
    """
    selector = xsel_mask(source, saxis, indices, iaxis)
    return source[selector]


def xsel_mask(source, saxis, indices, iaxis=-1):
    """Build a fancy-index tuple that gathers ``source`` along ``saxis``.

    Every axis of ``source`` other than ``saxis`` is indexed with an open
    ``np.ogrid`` range. ``saxis`` is indexed with ``indices``. Axis ``iaxis``
    of ``indices`` is aligned with ``saxis``. The remaining axes of
    ``indices`` are padded with singleton axes so they line up with, and
    broadcast against, the neighbouring axes of ``source``.

    Args:
        source: array to index.
        saxis: axis of ``source`` to select along (negative allowed).
        indices: integer array-like of positions along ``saxis``.
        iaxis: axis of ``indices`` corresponding to ``saxis``.

    Returns:
        Tuple of index arrays usable as ``source[result]``.

    Raises:
        AssertionError: if an axis is out of range, or ``indices`` has more
            axes before/after ``iaxis`` than ``source`` has around ``saxis``.
    """
    indices = np.asarray(indices)
    saxis, safter = _axis_clip(source.ndim, saxis)
    iaxis, iafter = _axis_clip(indices.ndim, iaxis)
    assert iaxis <= saxis
    assert iafter <= safter
    # Insert singleton axes so that indices' iaxis lands at position saxis.
    indices = indices[((None,) * (saxis - iaxis)) + (Ellipsis,) + ((None,) * (safter - iafter))]
    # Newer NumPy returns a tuple from np.ogrid (older versions a list);
    # copy into a list so the entry for saxis can be replaced.
    grid = list(np.ogrid[tuple(map(slice, source.shape))])
    grid[saxis] = indices
    return tuple(grid)


def template_to(inp, template):
    """Match the within-row ranks of ``inp`` to those of ``template``.

    Works along the last axis. The result ``q`` satisfies
    ``rank(inp[..., i]) == rank(template[..., q[..., i]])``. Equivalently,
    ``np.argsort(q)`` reorders ``inp`` so that its rank pattern matches
    ``template``'s. The last axes of ``inp`` and ``template`` are assumed to
    have the same length.

    Returns:
        Integer array with the shape of ``inp`` (broadcast against
        ``template``).
    """
    order = np.argsort(inp, axis=-1)
    ranks = np.argsort(np.argsort(template, axis=-1), axis=-1)
    # Select along the same (last) axis the argsorts use. The former
    # hard-coded axis 1 only coincided with it for 2-D input: it failed for
    # 1-D input and mixed axes for 3-D and higher.
    return np.argsort(xsel(order, -1, ranks, -1), axis=-1)


def intersect2d(A, B, *args, **kwargs):
    """Return the rows common to two 2-D arrays.

    Each row is reinterpreted as a single structured record so that
    ``np.intersect1d`` compares whole rows. Extra positional/keyword
    arguments (``assume_unique``, ``return_indices``) are forwarded to
    ``np.intersect1d``.

    Args:
        A: 2-D array of shape ``(nrows, ncols)``.
        B: 2-D array with ``ncols`` columns. Assumed to have ``A``'s dtype,
            since its bytes are reinterpreted with ``A``'s row layout.

    Returns:
        2-D array of shape ``(k, ncols)`` holding the unique common rows,
        sorted row-lexicographically. With ``return_indices=True``, a tuple
        ``(rows, indices_in_A, indices_in_B)``.
    """
    # .view() with a wider structured dtype needs a contiguous last axis.
    A = np.ascontiguousarray(A)
    B = np.ascontiguousarray(B)
    _, ncols = A.shape
    dtype = {"names": ["f{}".format(i) for i in range(ncols)], "formats": ncols * [A.dtype]}
    result = np.intersect1d(A.view(dtype), B.view(dtype), *args, **kwargs)

    def to_rows(records):
        # Turn the 1-D structured records back into plain (k, ncols) rows.
        return records.view(A.dtype).reshape(-1, ncols)

    if isinstance(result, tuple):
        # return_indices=True yields (values, indices_in_A, indices_in_B).
        return (to_rows(result[0]),) + tuple(result[1:])
    return to_rows(result)


def sigmaclip_mask(arr, low=4.0, high=4.0):
    """Iteratively sigma-clip ``arr`` and flag the rejected elements.

    Each pass computes the mean and standard deviation of the elements still
    kept (over all axes). It then drops those outside
    ``[mean - low*std, mean + high*std]``. Passes repeat until no further
    element is dropped. If ``arr`` contains NaN, the first pass's statistics
    are NaN, every comparison is False, and all elements end up flagged.

    Args:
        arr: array-like of numbers.
        low: lower rejection threshold, in standard deviations.
        high: upper rejection threshold, in standard deviations.

    Returns:
        Boolean array with ``arr``'s shape, True where an element was clipped.
    """
    data = np.array(arr)
    keep = np.ones_like(data, dtype=bool)
    while True:
        kept = data[keep]
        spread = kept.std()
        center = kept.mean()
        n_before = np.sum(keep)
        lower_ok = data >= center - spread * low
        upper_ok = data <= center + spread * high
        keep = keep & lower_ok & upper_ok
        if n_before - np.sum(keep) == 0:
            break
    return ~keep