如何在pytorch中获取掩码Tensor的每个通道的最小值?

6za6bjd0  于 2023-04-12  发布在  其他
关注(0)|答案(1)|浏览(200)

我有一个大小为torch.Size([b, 1, h, w])mask,由map > 0.5获取,还有一个大小为torch.Size([b, c, h, w])tensor,如何获取大小为torch.Size([b, c, 1])masked tensor的最小值?
我尝试了tensor[mask],但它返回了1-D结果。
先谢谢你了。

xpszyzbs

xpszyzbs1#

torch.where(mask,tensor,torch.inf).min(-1).values.min(-1).values.unsqueeze(-1)
这通过用无穷大替换屏蔽的元素来工作。
下面是一个简短的工作示例:

import torch

B = 2
C = 3
H = 4
W = 5
tensor = torch.randn(B,C,H,W)
mask = torch.rand(B,1,H,W) > 0.5
out = torch.where(mask,tensor,torch.inf).min(-1).values.min(-1).values.unsqueeze(-1)

相关问题