Numpy: get column and row index of minimum value of 2D array

For example,

x = np.array([[1,2,3],[3,2,5],[9,0,2]])
some_func(x) gives (2,1)

      

I know it can be done with a custom function:

def find_min_idx(x):
    k = x.argmin()      # flat index of the minimum
    ncol = x.shape[1]
    return k // ncol, k % ncol   # floor division, so the row index stays an integer on Python 3

      

However, I'm wondering if there is a built-in numpy function that makes this faster.

Thanks.

EDIT: thanks for the answers. I tested their speeds as follows:

%timeit np.unravel_index(x.argmin(), x.shape)
#100000 loops, best of 3: 4.67 µs per loop

%timeit np.where(x==x.min())
#100000 loops, best of 3: 12.7 µs per loop

%timeit find_min_idx(x) # this is using the custom function above
#100000 loops, best of 3: 2.44 µs per loop

      

The custom function seems to be faster than unravel_index() and where(). unravel_index() does much the same thing as the custom function, plus the overhead of checking extra arguments. where() can return multiple indices, but it is much slower for my purpose. Perhaps pure Python code is not that slow when it only does two simple arithmetic operations, and the custom function is about as fast as this can get.
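For anyone who wants to repeat the comparison outside IPython, here is a minimal sketch using the standard timeit module. The call counts (10000 calls, best of 3 repeats) are arbitrary assumptions, find_min_idx is written with floor division so it returns integers on Python 3, and the numbers will differ by machine and NumPy version.

```python
import timeit
import numpy as np

x = np.array([[1, 2, 3], [3, 2, 5], [9, 0, 2]])

def find_min_idx(x):
    k = x.argmin()
    ncol = x.shape[1]
    return k // ncol, k % ncol

# The three approaches compared in the question.
candidates = {
    "unravel_index": lambda: np.unravel_index(x.argmin(), x.shape),
    "where": lambda: np.where(x == x.min()),
    "find_min_idx": lambda: find_min_idx(x),
}

n_calls = 10000  # assumed call count, chosen only to keep the run short
timings = {}
for name, fn in candidates.items():
    # Best of 3 repeats, converted to microseconds per call.
    best = min(timeit.repeat(fn, number=n_calls, repeat=3))
    timings[name] = best / n_calls * 1e6
    print(f"{name}: {timings[name]:.2f} µs per call")
```

On an array this small the timings are dominated by call overhead, so the ranking may not carry over to large arrays.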



2 answers


You can use np.where:

In [9]: np.where(x == np.min(x))
Out[9]: (array([2]), array([1]))

      

As @senderle mentioned in a comment, to get the indices as an array, you can use np.argwhere:

In [21]: np.argwhere(x == np.min(x))
Out[21]: array([[2, 1]])
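The difference between the comparison-based approach and argmin only shows up when the minimum occurs more than once. A small sketch, using a hypothetical array y (not from the question) in which 0 appears twice:

```python
import numpy as np

# Hypothetical array in which the minimum value 0 occurs twice.
y = np.array([[1, 0, 3], [3, 2, 5], [9, 0, 2]])

# argmin reports only the first minimum in row-major order.
first = np.unravel_index(y.argmin(), y.shape)

# argwhere on the comparison mask returns one [row, col] pair per minimum.
every = np.argwhere(y == y.min())

print(first)
print(every)
```

Here first refers to position (0, 1) only, while every has two rows, one for (0, 1) and one for (2, 1).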

      

Updated:

As the OP's timings suggest, and now that it is clearer that argmin is what is wanted (no duplicate minima, etc.), one way that I think might improve slightly on the OP's original approach is to use divmod:

divmod(x.argmin(), x.shape[1])

      

Time it and you will find it is a little faster; not by much, but still an improvement.

%timeit find_min_idx(x)
1000000 loops, best of 3: 1.1 µs per loop

%timeit divmod(x.argmin(), x.shape[1])
1000000 loops, best of 3: 1.04 µs per loop
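To see why divmod works, note that argmin with no axis argument returns an index into the flattened, row-major array, so dividing by the number of columns gives the row and the remainder gives the column. A minimal sketch, assuming a 2D array; the second array z is hypothetical and is there only to show that the divisor is the column count even when the array is not square:

```python
import numpy as np

x = np.array([[1, 2, 3], [3, 2, 5], [9, 0, 2]])

k = x.argmin()                      # flat index into the row-major flattened array
row, col = divmod(k, x.shape[1])    # quotient is the row, remainder is the column

# Hypothetical non-square array: 3 rows, 4 columns, minimum at row 1, column 3.
z = np.array([[5, 4, 3, 2], [9, 8, 7, 1], [6, 6, 6, 6]])
z_row, z_col = divmod(z.argmin(), z.shape[1])

print(row, col)
print(z_row, z_col)
```

For arrays with more than two dimensions this shortcut no longer applies directly, and np.unravel_index is the more general tool.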

      

If you are really concerned about performance, you could take a look at Cython.



You can use np.unravel_index:

print(np.unravel_index(x.argmin(), x.shape))
(2, 1)
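One advantage of this approach is that it is not tied to two dimensions. A short sketch with a hypothetical 3D array (the shape and the planted minimum are assumptions made for illustration):

```python
import numpy as np

# Hypothetical 3D array of shape (2, 3, 4) with a single minimum planted at [1, 2, 0].
a = np.arange(24).reshape(2, 3, 4) + 10
a[1, 2, 0] = -1

# unravel_index converts the flat argmin into one index per axis.
idx = np.unravel_index(a.argmin(), a.shape)

print(idx)
print(a[idx])
```

The returned tuple can be used directly to index the array, as in a[idx].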

      
