Pytorch -将钩子的输出添加到损失函数

wb1gzix0  于 2023-08-05  发布在  其他
关注(0)|答案(3)|浏览(116)

我正在加载一个模型(不是我写的/训练的),我已经使用register_forward_hook向模型的一些层添加了钩子。
我的钩子计算了层的输入的一些转换(这是前一层的输出)。
目标是将钩子计算的转换添加到损失函数中,以便在微调期间,模型将尝试学习转换的输出应最小化。
举例来说:

y1 = None

def hook(module, input):
   y1 = foo(input)

model.some_layer.register_forward_hook(hook)

loss = MSE(...) + L1(y1.detach())

字符串
这样做有意义吗?它能反向传播吗?

sirbozc5

sirbozc51#

根据这里的文档和这里的一个超级令人印象深刻的教程,我可以说,前向钩子只在forward()方法调用期间被调用。
这意味着按照@Ahmad的建议将global关键字添加到变量中应该可以工作。然而,主要的问题仍然是当你使用GPU并行训练时,y1的值会发生什么。我不确定你是否会得到你正在训练的批次的正确对应值。
你真的需要hooking吗?我的解决方案是在forward方法中保留一个train布尔参数-在训练时将True传递给这个参数,并在测试或推断时将其设置为False。在forward方法中,通过直接调用处理方法(如上所述的foo)处理从感兴趣的层中获得的Tensor,如果train参数设置为True,则返回这些Tensor。定义您选择的损失准则,并将这些“处理过的Tensor”与输出Tensor沿着作为参数传递给您的损失函数。

pbgvytdp

pbgvytdp2#

将关键字global添加到y1调用

y1 = None

def hook(module, input):
    global y1
    y1 = foo(input)

字符串
此外,每次在forward passbackward pass期间调用层时,都会执行forward hook。由于钩子是在forward pass期间调用的,因此您可以根据前一层的输出计算y1。但是,在backward pass期间,仍将调用钩子,但是y1变量可能不再具有正确的值,因为它是在forward pass期间计算的。在我看来,钩子函数不应该用来存储你在backward pass中需要的值。相反,您可以将中间值保存在模型内的单独属性或缓冲区中。

osh3o9ms

osh3o9ms3#

由于hook函数作用域内的局部变量y1,您尝试实现它的方式将无法按预期工作。这将不允许外部作用域访问钩子中分配给y1的值。
您可能需要考虑将y1作为模型的一个属性,或者使用一个可以在钩子函数内部和外部访问的列表/字典对象。Python的listdict对象是可变的,当传递给函数时,函数可以修改它们。
下面是一个如何使用属性执行此操作的示例:

class YourModel(nn.Module):
    def __init__(self, *args, **kwargs):
        super(YourModel, self).__init__(*args, **kwargs)
        self.y1 = None

    def forward(self, x):
        # define your forward pass
        pass

    def hook(self, module, input):
        self.y1 = foo(input)

model = YourModel()
model.some_layer.register_forward_hook(model.hook)

字符串
或者使用列表:

y1 = []

def hook(module, input):
   y1.append(foo(input))

model.some_layer.register_forward_hook(hook)


至于你问题的另一部分,是的,将钩子计算的变换添加到损失函数中是可行的,只要确保y1Tensor保持其梯度,反向传播就可以按预期工作。
在给定的示例中,您已经使用.detach()y1从计算图中分离出来。这将防止梯度通过它反向传播。因此,需要在损失函数中从y1中删除.detach()

loss = MSE(...) + L1(y1)


此外,在每次向前传递之前,请小心清除列表或重置属性,以防止存储不必要的计算历史,这将消耗内存并可能导致不正确的结果。

相关问题