假设我有一个Tensor:
input: ([[-0.5535, 0.0000],
[ 0.0000, 0.0000],
[-1.1370, -0.2736],
[-1.2300, 0.9185]])
Output:([[-0.5535, 0.0000],
[-1.1370, -0.2736],
[-1.2300, 0.9185]])
我只需要保留所有列中有非零元素的行,以及被删除行的索引。为了简单起见,我将矩阵限制为两列,但在我的例子中,列数和行数在每次迭代中都在变化。
我已经找到了满足矩阵中任何元素的条件的解决方案,或者每列可能有单独的条件要满足,但我不知道如何解决这个特殊的情况。
- 谢谢-谢谢
2条答案
按热度按时间06odsfpq1#
您可以这样做:
其思想是识别哪些值不为零,然后在列中相乘,更改为布尔值,并将其用作数组的索引。
mpgws1up2#
我会回答我自己的问题,因为我在pytorch找到了解决方案。
此函数将返回非x的行索引[torch.nonzero(torch.tensor(x),as_tuple=True)[0].unique()]
或
x[ Torch .非零( Torch .Tensor(x.总和(1)),as_元组=True)[0]]