PyTorch函数中的下划线后缀是什么意思?

myzjeezk  于 2022-11-09  发布在  其他
关注(0)|答案(3)|浏览(211)

在PyTorch中,Tensor的许多方法有两种版本--一种有下划线后缀,另一种没有。如果我尝试它们,它们似乎做了同样的事情:

In [1]: import torch

In [2]: a = torch.tensor([2, 4, 6])

In [3]: a.add(10)
Out[3]: tensor([12, 14, 16])

In [4]: a.add_(10)
Out[4]: tensor([12, 14, 16])

有什么区别

  • torch.addtorch.add_
  • torch.subtorch.sub_
  • ...等等?
2w2cym1i

2w2cym1i1#

  • 您已经回答了自己的问题,下划线表示PyTorch中的就地操作。但是,我想简要地指出为什么就地操作会有问题:*
  • 首先,在PyTorch站点上,建议在大多数情况下不使用就地操作。除非在内存压力很大的情况下工作,否则在大多数情况下不使用就地操作会更有效率

https://pytorch.org/docs/stable/notes/autograd.html#in-place-operations-with-autograd

  • 其次,在使用就地操作时,计算梯度可能会有问题:

每个Tensor都有一个版本计数器,每当它在任何操作中被标记为脏的时候,它就会递增。当函数保存任何向后的Tensor时,它们所包含的Tensor的版本计数器也会被保存。一旦你访问self.saved_tensors,它就会被检查,如果它大于保存的值,就会引发错误。这确保了如果你正在使用就地函数,并且没有看到任何错误。你可以确保计算的梯度是正确的。* 与上面的源相同。*
下面是一个从你发布的答案中截取的经过略微修改的示例:

首先是就地版本:

import torch
a = torch.tensor([2, 4, 6], requires_grad=True, dtype=torch.float)
adding_tensor = torch.rand(3)
b = a.add_(adding_tensor)
c = torch.sum(b)
c.backward()
print(c.grad_fn)

导致此错误:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-27-c38b252ffe5f> in <module>
      2 a = torch.tensor([2, 4, 6], requires_grad=True, dtype=torch.float)
      3 adding_tensor = torch.rand(3)
----> 4 b = a.add_(adding_tensor)
      5 c = torch.sum(b)
      6 c.backward()

RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.

其次是非就地版本:

import torch
a = torch.tensor([2, 4, 6], requires_grad=True, dtype=torch.float)
adding_tensor = torch.rand(3)
b = a.add(adding_tensor)
c = torch.sum(b)
c.backward()
print(c.grad_fn)

工作正常-输出:

<SumBackward0 object at 0x7f06b27a1da0>

所以,作为一个外卖,我只是想指出,要小心使用在PyTorch的地方操作。

rlcwz9us

rlcwz9us2#

根据documentation,以下划线结尾的方法会改变Tensorin-place。这意味着执行该操作不会分配新的内存,通常increase performance,但会导致问题和更差的PyTorch性能。

In [2]: a = torch.tensor([2, 4, 6])

Tensor.add()

In [3]: b = a.add(10)

In [4]: a is b
Out[4]: False # b is a new tensor, new memory was allocated

Tensor.add_()

In [3]: b = a.add_(10)

In [4]: a is b
Out[4]: True # Same object, no new memory was allocated

请注意,运算符++=也是两种不同的实现。+通过使用.add()创建新Tensor,而+=通过使用.add_()修改Tensor

In [2]: a = torch.tensor([2, 4, 6])

In [3]: id(a)
Out[3]: 140250660654104

In [4]: a += 10

In [5]: id(a)
Out[5]: 140250660654104 # Still the same object, no memory allocation was required

In [6]: a = a + 10

In [7]: id(a)
Out[7]: 140250649668272 # New object was created
jv2fixgn

jv2fixgn3#

在PyTorch中,以下划线结尾,这是PyTorch中的一种约定,表示该方法不会返回新的Tensor,而是在适当的位置修改Tensor。例如,scatter_
https://yuyangyy.medium.com/understand-torch-scatter-b0fd6275331c

相关问题