torch argsort根据dim=...返回行或列内的索引
我怎么才能得到一个二维索引...
[[ r1,r2,r3,...], [c1,c2,c3,.....]]
谢谢...这是我做的
#https://github.com/pytorch/pytorch/issues/35674
def unravel_indices(indices, shape):
coord = []
for dim in reversed(shape):
coord.append(torch.fmod(indices, dim))
indices = torch.div(indices, dim, rounding_mode='floor')
coord = torch.stack(coord[::-1], dim=-1)
return coord
torch.unravel_indices = unravel_indices
1条答案
按热度按时间6ojccjat1#
虽然
numpy
有unravel_index
来执行此操作,但没有内置的Torch,但您可以自己完成。fwiw pytorch已经有一个function request and PR(s) floating around好几年了。