我有一个大小为torch.Size([b, 1, h, w])的mask,由map > 0.5获取,还有一个大小为torch.Size([b, c, h, w])的tensor,如何获取大小为torch.Size([b, c, 1])的masked tensor的最小值?我尝试了tensor[mask],但它返回了1-D结果。先谢谢你了。
torch.Size([b, 1, h, w])
mask
map > 0.5
torch.Size([b, c, h, w])
tensor
torch.Size([b, c, 1])
masked tensor
tensor[mask]
xpszyzbs1#
torch.where(mask,tensor,torch.inf).min(-1).values.min(-1).values.unsqueeze(-1)这通过用无穷大替换屏蔽的元素来工作。下面是一个简短的工作示例:
torch.where(mask,tensor,torch.inf).min(-1).values.min(-1).values.unsqueeze(-1)
import torch B = 2 C = 3 H = 4 W = 5 tensor = torch.randn(B,C,H,W) mask = torch.rand(B,1,H,W) > 0.5 out = torch.where(mask,tensor,torch.inf).min(-1).values.min(-1).values.unsqueeze(-1)
1条答案
按热度按时间xpszyzbs1#
torch.where(mask,tensor,torch.inf).min(-1).values.min(-1).values.unsqueeze(-1)
这通过用无穷大替换屏蔽的元素来工作。
下面是一个简短的工作示例: