numpy 随机选择非唯一最大值的argmax

acruukt9  于 11个月前  发布在  其他
关注(0)|答案(2)|浏览(86)

给定一个2D numpy数组,我想用每行最大值的列索引构造一个数组。到目前为止,arr.argmax(1)工作得很好。然而,对于我的特定情况,对于某些行,2个或更多列可能包含最大值。在这种情况下,我想随机选择一个列索引(不是第一个索引,因为它是.argmax(1)的情况)。
例如,对于以下arr

arr = np.array([
    [0, 1, 0],
    [1, 1, 0],
    [2, 1, 3],
    [3, 2, 2]
])

字符串
有两种可能的结果:array([1, 0, 2, 0])array([1, 1, 2, 0]),每一种都以1/2的概率被选择。
我有一段代码,它使用列表解析返回预期的输出:

idx = np.arange(arr.shape[1])
ans = [np.random.choice(idx[ix]) for ix in arr == arr.max(1, keepdims=True)]


但是我在寻找一个优化的numpy解决方案。换句话说,我如何用numpy方法替换列表解析,使代码适用于更大的数组?

5jvtdoz2

5jvtdoz21#

使用scipy.stats.rankdataapply_along_axis如下。

import numpy as np
from scipy.stats import rankdata
ranks = rankdata(-arr, axis = 1, method = "min")
func = lambda x: np.random.choice(np.where(x==1)[0])
idx = np.apply_along_axis(func, 1, ranks)

print(idx)

字符串
它返回[1 0 2 0]或[1 1 2 0]。
rankdata计算每行中每个值的秩,最大值为1。func随机选择一个对应值为1的索引。最后,apply_along_axisfunc应用于arr的每一行。

n3schb8v

n3schb8v2#

在我离线的一些建议之后,事实证明,当我们将标记行最大值的布尔数组乘以相同形状的随机数组时,最大值的随机化是可能的。然后剩下的是一个简单的argmax(1)调用。

# boolean array that flags maximum values of each row
mxs = arr == arr.max(1, keepdims=True)
# random array where non-maximum values are zero and maximum values are random values
random_arr = np.random.rand(*arr.shape) * mxs
# row-wise maximum of the auxiliary array
ans = random_arr.argmax(1)

字符串
一个timeit测试显示,对于形状为(507_563, 12)的数据,这段代码在我的机器上运行时间约为172 ms,而问题中的循环运行时间为11秒,因此这大约快了63倍。

相关问题