Skip to content
Snippets Groups Projects
numpy.py 1.87 KiB
Newer Older
# -*- coding: utf-8 -*-

import numpy as np


def one_hot(a, n=None):
    """Encode integer class labels as one-hot rows.

    Args:
        a: Array-like of integer class labels. Any shape is accepted; it is
            flattened in C order, one output row per element.
        n: Number of classes (output columns). Defaults to ``a.max() + 1``,
            or 0 when ``a`` is empty.

    Returns:
        Float array of shape ``(a.size, n)`` whose row ``i`` holds a 1 in
        column ``a.flat[i]`` and zeros elsewhere.

    NOTE(review): negative labels are not rejected; NumPy indexing would
    wrap them to columns counted from the end.
    """
    # Flatten so the row index np.arange(a.size) pairs element-wise with the
    # labels; a multi-dimensional ``a`` would otherwise fail to broadcast.
    a = np.asarray(a).ravel()
    if n is None:
        # ``max`` of an empty array raises, so an empty input yields 0 classes.
        n = int(a.max()) + 1 if a.size else 0
    o_h = np.zeros((a.size, n))
    if a.size:
        o_h[np.arange(a.size), a] = 1
    return o_h


def array_to_hist(values, bins):
    """Wrap precomputed 2-D bin counts in ``np.histogram2d``'s return format.

    A throwaway histogram of the single point (0.0, 0.0) is computed only to
    obtain bin edges for ``bins``. Its counts are discarded and ``values`` is
    returned in their place, unchanged and uncopied.

    NOTE(review): with an integer ``bins`` the edges come from NumPy's
    default range for that single point (about [-0.5, 0.5]), not from
    ``values``. Confirm callers pass explicit edges when that matters.

    Returns:
        Tuple ``(values, xedges, yedges)``.
    """
    _, xedges, yedges = np.histogram2d([0.0], [0.0], bins=bins)
    return (values, xedges, yedges)


def swap_rows(a, swp, roll=1, swp_to=None):
    """Overwrite selected entries along the last axis of ``a``, in place.

    For every position ``k``, the slice ``a[..., swp[k]]`` is replaced by
    ``a[..., src[k]]``. Here ``src`` is ``swp_to`` when given; otherwise it
    is ``np.roll(swp, roll)``, a cyclic rotation of the selected entries.

    Args:
        a: Array modified in place; the selection acts on its last axis.
        swp: Integer indices (along the last axis) to overwrite.
        roll: Shift passed to ``np.roll`` when ``swp_to`` is None.
        swp_to: Optional source indices, one per entry of ``swp``.

    Returns:
        ``a`` itself, after modification.
    """
    # Compare against None explicitly. ``not swp_to`` raises for
    # multi-element ndarrays and wrongly treats valid falsy values
    # (``0``, ``np.array([0])``) as "not given".
    src = np.roll(swp, roll) if swp_to is None else swp_to
    # Advanced indexing on the right-hand side produces a copy, so source
    # entries are read before any target entry is overwritten.
    a[..., swp] = a[..., src]
    return a


def create_swaps(swap_length, circle, length):
    """Build index pairs ``(i, i + swap_length)`` as an ``(m, 2)`` array.

    ``i`` runs up to ``length - swap_length`` (exclusive). It starts at 0,
    or at ``-swap_length`` when ``circle`` is truthy. The negative starting
    values are presumably meant to wrap around when used as indices.
    """
    lo = 0
    if circle:
        lo = -swap_length
    left = np.arange(lo, length - swap_length)
    right = np.arange(lo + swap_length, length)
    return np.stack([left, right], axis=1)


def _axis_clip(ref, axis):
    """Normalize ``axis`` for an array with ``ref`` dimensions.

    Returns:
        ``(axis, n_after)``: the non-negative axis index and the number of
        dimensions that follow it.

    Raises:
        AssertionError: if ``axis`` lies outside ``[-ref, ref)``. This check
            is an ``assert``, so it is skipped under ``python -O``.
    """
    assert axis >= -ref and axis < ref
    if axis < 0:
        axis += ref
    return axis, ref - 1 - axis


def xsel(source, saxis, indices, iaxis=-1):
    """Gather entries of ``source`` along ``saxis`` at positions ``indices``.

    Axis ``iaxis`` of ``indices`` is aligned with axis ``saxis`` of
    ``source``. Every other axis of ``source`` is indexed by an open range
    that broadcasts against ``indices``.
    """
    mask = xsel_mask(source, saxis, indices, iaxis)
    return source[mask]


def xsel_mask(source, saxis, indices, iaxis=-1):
    """Build an advanced-index tuple that gathers ``source`` along one axis.

    ``source[xsel_mask(source, saxis, indices, iaxis)]`` takes ``source``
    along ``saxis`` at the positions named by ``indices``. It is similar in
    spirit to ``np.take_along_axis``, except that ``indices`` may have fewer
    dimensions than ``source``. Its axis ``iaxis`` is aligned with
    ``saxis``, and length-1 axes are inserted before/after it so it
    broadcasts against open ranges over the remaining axes of ``source``.

    Args:
        source: Array to be indexed (only ``ndim`` and ``shape`` are read).
        saxis: Axis of ``source`` to select along (negative allowed).
        indices: Integer index array (or array-like).
        iaxis: Axis of ``indices`` aligned with ``saxis`` (default: last).

    Returns:
        Tuple of arrays, one per axis of ``source``, usable as an index.

    Raises:
        AssertionError: if an axis is out of range, or if ``indices`` has
            more dimensions before/after ``iaxis`` than ``source`` has
            before/after ``saxis``.
    """
    indices = np.asarray(indices)
    saxis, safter = _axis_clip(source.ndim, saxis)
    iaxis, iafter = _axis_clip(indices.ndim, iaxis)
    assert iaxis <= saxis
    assert iafter <= safter
    # Pad with new length-1 axes so indices has source.ndim dimensions and
    # its iaxis lands on saxis.
    indices = indices[((None,) * (saxis - iaxis)) + (Ellipsis,) + ((None,) * (safter - iafter))]
    # One open (sparse) range per axis of source; they broadcast together.
    # list(): recent NumPy returns a tuple from np.ogrid, which does not
    # support item assignment.
    grid = list(np.ogrid[tuple(map(slice, source.shape))])
    grid[saxis] = indices
    return tuple(grid)


def template_to(inp, template):
    """Match each element of ``inp`` to the ``template`` position of equal rank.

    Along the last axis, let ``q`` be the permutation for which
    ``inp[..., q]`` is ordered like ``template``: the k-th smallest element
    of ``inp`` is placed where the k-th smallest element of ``template`` is.
    The result is the inverse of ``q``. That is, ``out[..., i]`` is the
    position in ``template`` whose rank equals the rank of ``inp[..., i]``.

    Example: ``template_to(np.array([30, 10, 20]), np.array([2, 3, 1]))``
    returns ``[1, 2, 0]``.

    Ties are broken by ``np.argsort``'s default sort kind, which is not
    guaranteed stable.

    NOTE(review): ``inp`` and ``template`` are expected to have the same
    length along the last axis; this is not checked.
    """
    # Indices of inp's elements in ascending order.
    order = np.argsort(inp, axis=-1)
    # Rank of every template element (argsort of argsort).
    ranks = np.argsort(np.argsort(template, axis=-1), axis=-1)
    # For each template slot, the index of the inp element with equal rank.
    # Select along the last axis, matching the argsort axis. The original
    # hard-coded axis 1 agreed with it only for 2-D inputs.
    matched = xsel(order, -1, ranks, -1)
    return np.argsort(matched, axis=-1)