pytorch 筛选出每列中满足条件行

p1tboqfb  于 2022-11-09  发布在  其他
关注(0)|答案(2)|浏览(297)

假设我有一个Tensor:

input: ([[-0.5535,  0.0000],
        [ 0.0000,  0.0000],
        [-1.1370, -0.2736],
        [-1.2300,  0.9185]])

Output:([[-0.5535,  0.0000],
        [-1.1370, -0.2736],
        [-1.2300,  0.9185]])

我只需要保留所有列中有非零元素的行,以及被删除行的索引。为了简单起见,我将矩阵限制为两列,但在我的例子中,列数和行数在每次迭代中都在变化。
我已经找到了满足矩阵中任何元素的条件的解决方案,或者每列可能有单独的条件要满足,但我不知道如何解决这个特殊的情况。

  • 谢谢-谢谢
06odsfpq

06odsfpq1#

您可以这样做:

x=np.array([[1,2,3],[0,0,0],[4,5,6]])
mask = x!=0
index = mask.prod(axis=1).astype(bool)
x[index.astype(bool),:]

其思想是识别哪些值不为零,然后在列中相乘,更改为布尔值,并将其用作数组的索引。

mpgws1up

mpgws1up2#

我会回答我自己的问题,因为我在pytorch找到了解决方案。
此函数将返回非x的行索引[torch.nonzero(torch.tensor(x),as_tuple=True)[0].unique()]

x[ Torch .非零( Torch .Tensor(x.总和(1)),as_元组=True)[0]]

相关问题