Newer
Older
# -*- coding: utf-8 -*-
import numpy as np
def one_hot(a, n=None):
    """One-hot encode integer labels.

    Args:
        a: integer label(s); any array-like of any shape. It is flattened
            in C order, so row ``i`` of the result encodes ``a.ravel()[i]``.
        n: number of classes (columns). Defaults to ``max(a) + 1``, or 0
            when ``a`` is empty.

    Returns:
        float array of shape ``(a.size, n)`` with a single 1 per row at the
        column given by the label, 0 elsewhere.

    Raises:
        IndexError: if a label is ``>= n`` (negative labels index from the
            end, following NumPy indexing rules).
    """
    flat = np.asarray(a).ravel()
    if n is None:
        # max() of an empty array raises, so an empty input gets 0 classes.
        n = int(flat.max()) + 1 if flat.size else 0
    o_h = np.zeros((flat.size, n))
    if flat.size:
        # Flattening lets this work for labels of any shape, not just 1-D.
        o_h[np.arange(flat.size), flat] = 1
    return o_h
def array_to_hist(values, bins):
    """Wrap ``values`` in a ``np.histogram2d``-style ``(H, xedges, yedges)`` tuple.

    The bin edges are taken from histogramming a single dummy point at
    (0.0, 0.0) with the given ``bins``; the counts are then replaced by
    ``values`` itself (the same object, with no shape check).
    """
    _, xedges, yedges = np.histogram2d([0.0], [0.0], bins=bins)
    return (values, xedges, yedges)
def swap_rows(a, swp, roll=1, swp_to=None):
    """Overwrite entries of ``a`` along its last axis, in place.

    Positions ``swp`` of the last axis receive values from other positions
    of that axis:

    * if ``swp_to`` is given: ``a[..., swp] = a[..., swp_to]``;
    * otherwise the sources are ``np.roll(swp, roll)``, so with the default
      ``roll=1`` each listed position takes the value of the previous entry
      of ``swp`` (cyclically).

    This is a true swap only when the source indices are a permutation of
    ``swp``.

    Args:
        a: array, modified in place.
        swp: index or indices along the last axis to overwrite.
        roll: shift applied to ``swp`` when ``swp_to`` is not given.
        swp_to: explicit source indices. ``None``, or an empty sequence (kept
            for backward compatibility), selects the rolling mode.

    Returns:
        ``a`` itself, after modification.
    """
    # Test for None explicitly: plain truthiness treats index 0 as "not given"
    # and raises ValueError for multi-element numpy arrays.
    if swp_to is None or np.size(swp_to) == 0:
        a[..., swp] = a[..., np.roll(swp, roll)]
    else:
        a[..., swp] = a[..., swp_to]
    return a
def create_swaps(swap_length, circle, length):
    """Build a ``(k, 2)`` array of index pairs ``(i, i + swap_length)``.

    ``i`` runs from 0 up to ``length - swap_length - 1``. When ``circle`` is
    truthy, ``i`` starts at ``-swap_length`` instead, so the first pairs hold
    negative indices (which NumPy indexing counts from the end).
    """
    if circle:
        first = -swap_length
    else:
        first = 0
    left = np.arange(first, length - swap_length)
    right = np.arange(first + swap_length, length)
    return np.column_stack((left, right))
def _axis_clip(ref, axis):
assert -ref <= axis < ref
axis = ref + axis if axis < 0 else axis
return axis, ref - axis - 1
def xsel(source, saxis, indices, iaxis=-1):
    """Gather values from ``source`` along axis ``saxis`` using ``indices``.

    Axis ``iaxis`` of ``indices`` is aligned with axis ``saxis`` of
    ``source``. Every other axis of ``source`` is indexed by a broadcast
    range. Building the index tuple is delegated to ``xsel_mask``.
    """
    mask = xsel_mask(source, saxis, indices, iaxis)
    return source[mask]
def xsel_mask(source, saxis, indices, iaxis=-1):
    """Build an advanced-indexing tuple that gathers ``source`` along ``saxis``.

    Axis ``iaxis`` of ``indices`` is aligned with axis ``saxis`` of
    ``source``. The axes of ``indices`` before and after ``iaxis`` line up
    with the axes of ``source`` just before and after ``saxis``. Any missing
    leading or trailing axes are padded with length-1 dimensions so they
    broadcast. Every other axis of ``source`` is indexed by its own
    ``arange``, so only axis ``saxis`` is remapped.

    Returns:
        tuple with one index array per axis of ``source``, suitable for
        ``source[result]``.

    Raises:
        AssertionError: (when assertions are enabled) if an axis is out of
            range, or if ``indices`` has more axes before/after ``iaxis``
            than ``source`` has around ``saxis``.
    """
    saxis, safter = _axis_clip(source.ndim, saxis)
    iaxis, iafter = _axis_clip(indices.ndim, iaxis)
    assert iaxis <= saxis
    assert iafter <= safter
    # Insert new axes so that indices' iaxis lands on source's saxis.
    expand = (None,) * (saxis - iaxis) + (Ellipsis,) + (None,) * (safter - iafter)
    indices = indices[expand]
    # list(): some NumPy versions return a tuple from np.ogrid, which would
    # make the item assignment below raise TypeError.
    grid = list(np.ogrid[tuple(slice(dim) for dim in source.shape)])
    grid[saxis] = indices
    return tuple(grid)
def template_to(inp, template):
    """Compute where each element of ``inp`` must move to mimic ``template``'s order.

    Works along the last axis. For every position ``i``, ``result[..., i]`` is
    the destination position of ``inp[..., i]``. Placing the values at those
    positions yields an arrangement whose rank order matches ``template``.
    Equal values are ordered according to ``np.argsort``'s default.

    Args:
        inp: values to rearrange; any number of dimensions.
        template: array whose rank order along the last axis is imitated;
            its leading axes must broadcast against ``inp``'s.

    Returns:
        integer array of destination indices along the last axis.
    """
    # Indices of inp in ascending order, and the rank of each template entry.
    inp_order = np.argsort(inp, axis=-1)
    template_rank = np.argsort(np.argsort(template, axis=-1), axis=-1)
    # For each template position, the inp index holding the same rank.
    # Gather along the last axis (the axis that was sorted) instead of the
    # previously hard-coded axis 1, which was only correct for 2-D input.
    matched = np.take_along_axis(inp_order, template_rank, axis=-1)
    # Invert that permutation: source position -> destination position.
    return np.argsort(matched, axis=-1)