值相等时的NumPy argmax

e0uiprwp  于 2022-11-10  发布在  其他
关注(0)|答案(2)|浏览(204)

我得到了一个数字矩阵,我想要得到每一行中最大值的索引。例如。

[[1,2,3],[1,3,2],[3,2,1]]

会回来的

[0,1,2]

但是,当每行中有超过1个最大值时,numpy.argmax将只返回最小的索引。例如。

[[0,0,0],[0,0,0],[0,0,0]]

会回来的

[0,0,0]

我可以将默认(最小索引)更改为其他一些值吗?例如,当存在相等的最大值时,返回1None,以便上面的结果为

[1,1,1]
or
[None, None, None]

如果我能在TensorFlow中做到这一点,那就更好了。
谢谢!

8cdiaqws

8cdiaqws1#

您可以使用np.partition两个查找最大的两个值并检查它们是否相等,然后将其用作np.where中的掩码来设置默认值:

In [228]: a = np.array([[1, 2, 3, 2], [3, 1, 3, 2], [3, 5, 2, 1]])

In [229]: twomax = np.partition(a, -2)[:, -2:].T

In [230]: default = -1

In [231]: argmax = np.where(twomax[0] != twomax[1], np.argmax(a, -1), default)

In [232]: argmax
Out[232]: array([ 2, -1,  1])
30byixjq

30byixjq2#

“Default”的一个方便值是-1,因为argmax本身不会返回该值。None不适合整数数组。masked array也是一种选择,但我没有走那么远。下面是一个NumPy实现

def my_argmax(a):
    rows = np.where(a == a.max(axis=1)[:, None])[0]
    rows_multiple_max = rows[:-1][rows[:-1] == rows[1:]]
    my_argmax = a.argmax(axis=1)
    my_argmax[rows_multiple_max] = -1
    return my_argmax

使用示例:

import numpy as np
a = np.array([[0, 0, 0], [4, 5, 3], [3, 4, 4], [6, 2, 1]])
my_argmax(a)   #  array([-1,  1, -1,  0])

说明:where选择每行all个最大元素的索引。如果一行有多个最大值,行号将在rows数组中多次出现。由于该数组已经排序,因此通过比较连续的元素来检测这种重复。这标识了具有多个最大值的行,之后它们在NumPy的argmax方法的输出中被屏蔽。

相关问题