我有Tensor和掩码,还有第二个掩码。
现在我想改变tensor[mask][second_mask]的值,但它不起作用。
我认为这是因为tensor[mask]返回一个新的tensor,而不是原始tensor的视图,并且将值应用于tensor[mask][second_mask]不会改变原始tensor的值。我的演示如下:
import torch
x = torch.linspace(1,9,9).reshape((3,3))
mask = x>5
second_mask = 0 # In practice, it will be a booltensor
x[mask][second_mask] = 100
print(x)
# One way to solve it is use a temp tensor be like:
t = x[mask]
t[second_mask] = 100
x[mask] = t
# but it is sooooo long, hoping for any convenient method
字符串
2条答案
按热度按时间oknwwptz1#
您可以使用索引而不是掩码:
字符串
2ledvvac2#
选择操作返回的不是视图,这是真的。
因为你的第二个掩码只是索引
0
,而不是真正的掩码,你总是试图更新第一个正掩码值吗?如果是这样,对于这种情况,你可以这样做:字符串
当所有的max值都相同时,
argmax
返回第一个max值。不幸的是,它没有为bool实现,因此是to(torch.int)
。