我有一个2DTensor,每行都有一些非零元素,如下所示:
import torch tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0], [0, 0, 0, 1, 1, 0, 0]], dtype=torch.float)
我需要一个包含每行中第一个非零元素的索引的Tensor:
indices = tensor([2], [3])
在Pytorch中如何计算?
oxiaedzo1#
我简化了Iman的方法,如下所示:
idx = torch.arange(tmp.shape[1], 0, -1) tmp2= tmp * idx indices = torch.argmax(tmp2, 1, keepdim=True)
6jjcrrmo2#
我可以为我的问题找到一个巧妙的答案:
tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0], [0, 0, 0, 1, 1, 0, 0]], dtype=torch.float) idx = reversed(torch.Tensor(range(1,8))) print(idx) tmp2= torch.einsum("ab,b->ab", (tmp, idx)) print(tmp2) indices = torch.argmax(tmp2, 1, keepdim=True) print(indeces)
结果是:
tensor([7., 6., 5., 4., 3., 2., 1.]) tensor([[0., 0., 5., 0., 3., 0., 0.], [0., 0., 0., 4., 3., 0., 0.]]) tensor([[2], [3]])
zfycwa2u3#
假设所有非零值都相等,argmax返回第一个索引。
argmax
tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0], [0, 0, 0, 1, 1, 0, 0]]) indices = tmp.argmax(1)
tvokkenx4#
在@Seppo答案的基础上,我们可以通过简单地从原始Tensor创建掩码然后使用pytorch函数来消除“所有非零值都相等”的假设
# tmp = some tensor of whatever shape and values indices = torch.argmax((tmp != 0).to(dtype=torch.int), dim=-1)
然而,如果Tensor的一行全是零,那么返回的信息不是第一个非零元素的索引,我想问题的本质使这种情况不会发生。
4条答案
按热度按时间oxiaedzo1#
我简化了Iman的方法,如下所示:
6jjcrrmo2#
我可以为我的问题找到一个巧妙的答案:
结果是:
zfycwa2u3#
假设所有非零值都相等,
argmax
返回第一个索引。tvokkenx4#
在@Seppo答案的基础上,我们可以通过简单地从原始Tensor创建掩码然后使用pytorch函数来消除“所有非零值都相等”的假设
然而,如果Tensor的一行全是零,那么返回的信息不是第一个非零元素的索引,我想问题的本质使这种情况不会发生。