Skip to content
Snippets Groups Projects
Commit 5bf258f9 authored by Dennis Noll's avatar Dennis Noll
Browse files

[numpy] one_hot: fixed edgecase

parent 711b0e59
No related branches found
No related tags found
No related merge requests found
......@@ -3,8 +3,10 @@
import numpy as np
def one_hot(a):
o_h = np.zeros((a.size, a.max() + 1))
def one_hot(a, n=None):
if n is None:
n = a.max() + 1
o_h = np.zeros((a.size, n))
o_h[np.arange(a.size), a] = 1
return o_h
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment