diff --git a/numpy.py b/numpy.py new file mode 100644 index 0000000000000000000000000000000000000000..11daaa81947cf291671202eeddffcdb618672470 --- /dev/null +++ b/numpy.py @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- + +import numpy as np + + +def one_hot(a): + o_h = np.zeros((a.size, a.max() + 1)) + o_h[np.arange(a.size), a] = 1 + return o_h + + +def array_to_hist(values, bins): + hist = np.histogram2d([0.0], [0.0], bins=bins) + hist = list(hist) + hist[0] = values + hist = tuple(hist) + return hist + + +def swap_rows(a, swp, roll=1, swp_to=None): + """Swaps rows in the last dimension of a. + The rows to swap are specified by the swp argument. + If swp_to is defined, the rows to swap the rows from swp + in can be directly specified, if not, the swapping is done + in a rolling manner. + Returns: + Array with swapped rows + """ + if not (swp_to): + a[..., swp] = a[..., np.roll(swp, roll)] + return a + else: + a[..., swp] = a[..., swp_to] + return a + + +def create_swaps(swap_length, circle, length): + start = -swap_length if circle else 0 + return np.column_stack( + (np.arange(start, length - swap_length), np.arange(start + swap_length, length)) + ) + + +def _axis_clip(ref, axis): + assert -ref <= axis < ref + axis = ref + axis if axis < 0 else axis + return axis, ref - axis - 1 + + +def xsel(source, saxis, indices, iaxis=-1): + return source[xsel_mask(source, saxis, indices, iaxis)] + + +def xsel_mask(source, saxis, indices, iaxis=-1): + saxis, safter = _axis_clip(source.ndim, saxis) + iaxis, iafter = _axis_clip(indices.ndim, iaxis) + assert iaxis <= saxis + assert iafter <= safter + indices = indices[((None,) * (saxis - iaxis)) + (Ellipsis,) + ((None,) * (safter - iafter))] + grid = np.ogrid[tuple(map(slice, source.shape))] + grid[saxis] = indices + return tuple(grid) + + +def template_to(inp, template): + return np.argsort( + xsel(np.argsort(inp, axis=-1), 1, np.argsort(np.argsort(template, axis=-1), axis=-1), 1), + axis=-1, + )