numpy 最快的2darray索引

i7uaboj4  于 2023-04-12  发布在  其他
关注(0)|答案(1)|浏览(90)

假设我有一个数字数组,其中第一维包含从0到19999的20,000个数字,第二维包含20个数字,可以假设0到19999之间的任何值。换句话说,对于每个索引(如第一维所描述的),我在第二维中有20个选择。我称之为相似度矩阵。
我的目标是替换一些输入流中的数字例如,输入流中的一个数字可以是5。然后我在我的相似性矩阵中查找索引5,并且使用相似性矩阵对20个备选方案中的一个进行采样,以获得与其最相似的5个。问题是我经常执行这个操作,这是我程序的主要瓶颈。
我的第一次尝试是这样的:
(其中similarity_matrix是一个numpy数组)

num_samples = 20
random_numbers = [random.randint(0, similarity_matrix.shape[-1]-1) for x in range(num_samples)]
sampled_entities = similarity_matrix[stream, random_numbers]

太慢了,然后我试着

a = list(similarity_matrix)
interim = [a[i] for i in stream]
sampled_entities = [x[i] for x, i in zip(interim, random_numbers)]

这也太慢了。你有其他的建议吗?

lztngnrs

lztngnrs1#

要加快处理速度,您可以应用numpy.random.Generator.choice

rng = np.random.default_rng()
num_samples = 20
sampled_entities = similarity_matrix[stream, np.random.choice(similarity_matrix.shape[-1], num_samples)]

相关问题