pytorch 如何使argsort返回二维索引

pbpqsu0x  于 2023-02-16  发布在  其他
关注(0)|答案(1)|浏览(151)

torch argsort根据dim=...返回行或列内的索引
我怎么才能得到一个二维索引...

[[ r1,r2,r3,...], [c1,c2,c3,.....]]

谢谢...这是我做的

#https://github.com/pytorch/pytorch/issues/35674
def unravel_indices(indices, shape):
    coord = []

    for dim in reversed(shape):
        coord.append(torch.fmod(indices, dim))
        indices = torch.div(indices, dim, rounding_mode='floor')

    coord = torch.stack(coord[::-1], dim=-1)

    return coord
torch.unravel_indices = unravel_indices
6ojccjat

6ojccjat1#

虽然numpyunravel_index来执行此操作,但没有内置的Torch,但您可以自己完成。

yy, xx = indices // width, indices % width

fwiw pytorch已经有一个function request and PR(s) floating around好几年了。

相关问题