Pytorch:我怎样才能找到二维Tensor每一行中第一个非零元素的索引?

a5g8bdjr  于 2023-03-08  发布在  其他
关注(0)|答案(4)|浏览(350)

我有一个2DTensor,每行都有一些非零元素,如下所示:

import torch
tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0],
                    [0, 0, 0, 1, 1, 0, 0]], dtype=torch.float)

我需要一个包含每行中第一个非零元素的索引的Tensor:

indices = tensor([2],
                 [3])

在Pytorch中如何计算?

oxiaedzo

oxiaedzo1#

我简化了Iman的方法,如下所示:

idx = torch.arange(tmp.shape[1], 0, -1)
tmp2= tmp * idx
indices = torch.argmax(tmp2, 1, keepdim=True)
6jjcrrmo

6jjcrrmo2#

我可以为我的问题找到一个巧妙的答案:

tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0],
                     [0, 0, 0, 1, 1, 0, 0]], dtype=torch.float)
  idx = reversed(torch.Tensor(range(1,8)))
  print(idx)

  tmp2= torch.einsum("ab,b->ab", (tmp, idx))

  print(tmp2)

  indices = torch.argmax(tmp2, 1, keepdim=True)
  print(indeces)

结果是:

tensor([7., 6., 5., 4., 3., 2., 1.])
tensor([[0., 0., 5., 0., 3., 0., 0.],
       [0., 0., 0., 4., 3., 0., 0.]])
tensor([[2],
        [3]])
zfycwa2u

zfycwa2u3#

假设所有非零值都相等,argmax返回第一个索引。

tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0],
                    [0, 0, 0, 1, 1, 0, 0]])
indices = tmp.argmax(1)
tvokkenx

tvokkenx4#

在@Seppo答案的基础上,我们可以通过简单地从原始Tensor创建掩码然后使用pytorch函数来消除“所有非零值都相等”的假设

# tmp = some tensor of whatever shape and values
indices = torch.argmax((tmp != 0).to(dtype=torch.int), dim=-1)

然而,如果Tensor的一行全是零,那么返回的信息不是第一个非零元素的索引,我想问题的本质使这种情况不会发生。

相关问题