我得到了一个数字矩阵,我想要得到每一行中最大值的索引。例如。
[[1,2,3],[1,3,2],[3,2,1]]
会回来的
[0,1,2]
但是,当每行中有超过1个最大值时,numpy.argmax
将只返回最小的索引。例如。
[[0,0,0],[0,0,0],[0,0,0]]
会回来的
[0,0,0]
我可以将默认(最小索引)更改为其他一些值吗?例如,当存在相等的最大值时,返回1
或None
,以便上面的结果为
[1,1,1]
or
[None, None, None]
如果我能在TensorFlow中做到这一点,那就更好了。
谢谢!
2条答案
按热度按时间8cdiaqws1#
您可以使用
np.partition
两个查找最大的两个值并检查它们是否相等,然后将其用作np.where
中的掩码来设置默认值:30byixjq2#
“Default”的一个方便值是-1,因为
argmax
本身不会返回该值。None
不适合整数数组。masked array也是一种选择,但我没有走那么远。下面是一个NumPy实现使用示例:
说明:
where
选择每行all个最大元素的索引。如果一行有多个最大值,行号将在rows
数组中多次出现。由于该数组已经排序,因此通过比较连续的元素来检测这种重复。这标识了具有多个最大值的行,之后它们在NumPy的argmax
方法的输出中被屏蔽。