为什么代码作者在Pytorch Forward中重用变量x

fnvucqvd  于 2023-03-08  发布在  其他
关注(0)|答案(1)|浏览(134)

Pytorch模型中典型的前向声明如下所示:

def forward(self, x):
        x = self.conv1(x)
        x = F.relu(F.max_pool2d(x, kernel_size = 2))
        x = self.drop1(x)
        return x

它似乎是普遍使用的标准。但是我已经能够通过创建新的变量来让代码工作。

def forward(self, x):
        a = self.conv1(x)
        b = F.relu(F.max_pool2d(a, kernel_size = 2))
        c = self.drop1(b)
        return c

我在任何地方都找不到实际的解释。有人能解释一下为什么重用的x版本是首选的吗?

gr8qqesn

gr8qqesn1#

一个仍在使用中的变量(例如a),只要仍然存在对它的引用,就不能被Python的垃圾收集机制释放。
使用a、B、c变量的版本可能会导致内存使用峰值更高,使用相同的名称x,在几行之后,就没有对self.conv1(x)输出的引用了,因此引用计数变为零,内存可以被释放。
此外,重复使用同一变量可以更轻松地快速重新排序操作或注解掉某些层。
请注意,峰值内存通常是两行代码,而不是一行。在第一个示例中,F.relu输出的内存将被分配并赋值给x,然后self.conv1输出的内存才能被释放...但与第二个示例不同的是,它可以在调用self.drop1之前被释放。

相关问题