PytorchTensor在掩蔽时的维数损失

v64noz0r  于 2023-10-20  发布在  其他
关注(0)|答案(2)|浏览(101)

首先,这基本上与Numpy array loss of dimension when masking相同,但对于PyTorchTensor而不是NumPy数组。这个问题的solution(s),使用等价的PyTorch函数torch.wheremasked tensors工作,但我发现Google搜索PyTorchTensor并没有很快找到答案。所以,我想一个等价的StackOverflow pytorch标记问题可能对其他人有用!)
我有一个2D PyTorchTensor(尽管它可能有更多的维度),我想应用一个等效形状的二进制掩码。然而,当我应用掩码时,输出只是一维的。如何在应用遮罩后保持与原始Tensor相同的尺寸?
例如用于

import torch

x = torch.tensor([[1.0, 2.0, 8.0], [-4.0, 0.0, 3.0]])
mask = x >=2.0
print(x[mask])
tensor([2., 8., 3.])

输出现在是1D而不是2D。

e5njpo68

e5njpo681#

使用torchwhere函数,我们将得到行,列tensors如下:

import torch

x = torch.tensor([[1.0, 2.0, 8.0], [-4.0, 0.0, 3.0]])
mask = x >=2.0
print(torch.where(mask))
print(x[torch.where(mask)])

其输出:

(tensor([0, 0, 1]), tensor([1, 2, 2]))
tensor([2., 8., 3.])

然而,将其插入x将仅输出mask的艾德化值,从而使我们得到1D tensor,因为我们正在摆脱tensor中的所有值,这些值的计算结果不为True(因此它不可能与原始tensor的形状相同,因为它的值较少,因此它被展平为一维)。
如果你想让x是它的原始形状,但 onlywheremask的艾德值是Truth y,那么我们可以在这些索引处用1填充一个0的tensor

import torch

x = torch.tensor([[1.0, 2.0, 8.0], [-4.0, 0.0, 3.0]])
mask = x >=2.0
masked_x = torch.zeros((x.shape))
masked_x[torch.where(mask)] = 1.0
print(masked_x)

输出:

tensor([[0., 1., 1.],
        [0., 0., 1.]])

所以现在masked_x是一个tensor的零,形状与x相同,但与1 wheremaskTruth y。
如果你想让masked_xx的值where组成,那么mask就是Truth y,那么:

masked_x = torch.zeros((x.shape))
masked_x[torch.where(mask)] = x[torch.where(mask)]
print(masked_x)

输出:

tensor([[0., 2., 8.],
        [0., 0., 3.]])

0的tensorwheremaskFalse y。
如果你想要别的东西,请澄清。

fnx2tebb

fnx2tebb2#

就像numpy question中建议的那样,torch.where可以实现这一点。
为了保持相同的维度,你需要一个填充值。在下面的示例中,我使用0,但您也可以使用torch.nan

torch.where(mask, x, 0)

返回

tensor([[0., 2., 8.],
        [0., 0., 3.]])

相关问题