给定一个2D numpy数组,我想用每行最大值的列索引构造一个数组。到目前为止,arr.argmax(1)
工作得很好。然而,对于我的特定情况,对于某些行,2个或更多列可能包含最大值。在这种情况下,我想随机选择一个列索引(不是第一个索引,因为它是.argmax(1)
的情况)。
例如,对于以下arr
:
arr = np.array([
[0, 1, 0],
[1, 1, 0],
[2, 1, 3],
[3, 2, 2]
])
字符串
有两种可能的结果:array([1, 0, 2, 0])
和array([1, 1, 2, 0])
,每一种都以1/2的概率被选择。
我有一段代码,它使用列表解析返回预期的输出:
idx = np.arange(arr.shape[1])
ans = [np.random.choice(idx[ix]) for ix in arr == arr.max(1, keepdims=True)]
型
但是我在寻找一个优化的numpy解决方案。换句话说,我如何用numpy方法替换列表解析,使代码适用于更大的数组?
2条答案
按热度按时间5jvtdoz21#
使用
scipy.stats.rankdata
和apply_along_axis
如下。字符串
它返回[1 0 2 0]或[1 1 2 0]。
rankdata
计算每行中每个值的秩,最大值为1。func
随机选择一个对应值为1的索引。最后,apply_along_axis
将func
应用于arr
的每一行。n3schb8v2#
在我离线的一些建议之后,事实证明,当我们将标记行最大值的布尔数组乘以相同形状的随机数组时,最大值的随机化是可能的。然后剩下的是一个简单的
argmax(1)
调用。字符串
一个timeit测试显示,对于形状为
(507_563, 12)
的数据,这段代码在我的机器上运行时间约为172 ms,而问题中的循环运行时间为11秒,因此这大约快了63倍。