我正在加载一个模型(不是我写的/训练的),我已经使用register_forward_hook
向模型的一些层添加了钩子。
我的钩子计算了层的输入的一些转换(这是前一层的输出)。
目标是将钩子计算的转换添加到损失函数中,以便在微调期间,模型将尝试学习转换的输出应最小化。
举例来说:
y1 = None
def hook(module, input):
y1 = foo(input)
model.some_layer.register_forward_hook(hook)
loss = MSE(...) + L1(y1.detach())
字符串
这样做有意义吗?它能反向传播吗?
3条答案
按热度按时间sirbozc51#
根据这里的文档和这里的一个超级令人印象深刻的教程,我可以说,前向钩子只在
forward()
方法调用期间被调用。这意味着按照@Ahmad的建议将
global
关键字添加到变量中应该可以工作。然而,主要的问题仍然是当你使用GPU并行训练时,y1
的值会发生什么。我不确定你是否会得到你正在训练的批次的正确对应值。你真的需要hooking吗?我的解决方案是在
forward
方法中保留一个train
布尔参数-在训练时将True
传递给这个参数,并在测试或推断时将其设置为False
。在forward方法中,通过直接调用处理方法(如上所述的foo
)处理从感兴趣的层中获得的Tensor,如果train
参数设置为True
,则返回这些Tensor。定义您选择的损失准则,并将这些“处理过的Tensor”与输出Tensor沿着作为参数传递给您的损失函数。pbgvytdp2#
将关键字
global
添加到y1
调用字符串
此外,每次在
forward pass
和backward pass
期间调用层时,都会执行forward hook。由于钩子是在forward pass
期间调用的,因此您可以根据前一层的输出计算y1
。但是,在backward pass
期间,仍将调用钩子,但是y1
变量可能不再具有正确的值,因为它是在forward pass
期间计算的。在我看来,钩子函数不应该用来存储你在backward pass
中需要的值。相反,您可以将中间值保存在模型内的单独属性或缓冲区中。osh3o9ms3#
由于
hook
函数作用域内的局部变量y1
,您尝试实现它的方式将无法按预期工作。这将不允许外部作用域访问钩子中分配给y1
的值。您可能需要考虑将
y1
作为模型的一个属性,或者使用一个可以在钩子函数内部和外部访问的列表/字典对象。Python的list
或dict
对象是可变的,当传递给函数时,函数可以修改它们。下面是一个如何使用属性执行此操作的示例:
字符串
或者使用列表:
型
至于你问题的另一部分,是的,将钩子计算的变换添加到损失函数中是可行的,只要确保
y1
Tensor保持其梯度,反向传播就可以按预期工作。在给定的示例中,您已经使用
.detach()
将y1
从计算图中分离出来。这将防止梯度通过它反向传播。因此,需要在损失函数中从y1
中删除.detach()
:型
此外,在每次向前传递之前,请小心清除列表或重置属性,以防止存储不必要的计算历史,这将消耗内存并可能导致不正确的结果。