Printing intermediate gradient values during the backward pass in PyTorch using hooks

Posted by ldioqlga 12 months ago in Other

I am trying to print the value of every intermediate gradient during the model's backward pass by registering a backward hook:

import torch

class func_NN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Parameter(torch.rand(1))
        self.b = torch.nn.Parameter(torch.rand(1))

    def forward(self, inp):
        mul_x = torch.cos(self.a.view(-1,1)*inp)
        sum_x = mul_x - self.b
        return sum_x

# hook function
def backward_hook(module, grad_input, grad_output):
    print("module: ", module)
    print("inp: ", grad_input)
    print("out: ", grad_output) 

# Training
# Generate labels
a = torch.Tensor([0.5])
b = torch.Tensor([0.8])
x = torch.linspace(-1, 1, 10)
y = a*x + (0.1**0.5)*torch.randn_like(x)*(0.001) + b
inp = torch.linspace(-1, 1, 10)
foo = func_NN()
handle_ = foo.register_full_backward_hook(backward_hook)
loss = torch.nn.MSELoss()
optim = torch.optim.Adam(foo.parameters(),lr=0.001)

t_l = []
for i in range(2):
    optim.zero_grad()
    l = loss(y, foo.forward(inp=inp))
    t_l.append(l.detach())
    l.backward()
    optim.step()
handle_.remove()

But this does not give the expected result.
My goal is to print the gradients of non-leaf nodes such as sum_x and mul_x. Please help.

Answer 1, by zsohkypk

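PyTorch module hooks cannot give you this. A hook registered with register_full_backward_hook is called with grad_input and grad_output, the gradients with respect to the module's own inputs and outputs; it never sees tensors created inside forward, such as mul_x.
If you want the gradients of intermediate tensors, store references to them on the model and call retain_grad() on them: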
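A minimal sketch of what the module-level hook actually receives, reusing the question's func_NN (simplified to return mul_x - self.b directly); the captured list is a hypothetical name for illustration. It also shows a second reason the question's hook printed nothing: module hooks are run by Module.__call__, so calling foo.forward(...) directly bypasses them.

```python
import torch

class func_NN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Parameter(torch.rand(1))
        self.b = torch.nn.Parameter(torch.rand(1))

    def forward(self, inp):
        mul_x = torch.cos(self.a.view(-1, 1) * inp)
        return mul_x - self.b

captured = []  # hypothetical: records every (grad_input, grad_output) pair

def backward_hook(module, grad_input, grad_output):
    # grad_input: gradients w.r.t. the module's positional tensor inputs
    # grad_output: gradients w.r.t. the module's outputs
    captured.append((grad_input, grad_output))

foo = func_NN()
handle = foo.register_full_backward_hook(backward_hook)
inp = torch.linspace(-1, 1, 10)

# Calling forward() directly skips module hooks entirely.
foo.forward(inp).sum().backward()
print(len(captured))  # 0

# Calling the module itself runs the hook.
foo(inp).sum().backward()
grad_input, grad_output = captured[0]
print(grad_input)            # (None,) on recent releases: inp does not require grad
print(grad_output[0].shape)  # torch.Size([1, 10])
handle.remove()
```

Neither argument contains the gradient of mul_x, which is why a module hook alone cannot answer the question.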

import torch

class func_NN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Parameter(torch.rand(1))
        self.b = torch.nn.Parameter(torch.rand(1))

    def forward(self, inp):
        mul_x = torch.cos(self.a.view(-1, 1) * inp)
        sum_x = mul_x - self.b

        # Retain gradients for intermediate variables
        mul_x.retain_grad()
        sum_x.retain_grad()

        # Store references to the intermediate tensors
        self.mul_x = mul_x
        self.sum_x = sum_x

        return sum_x

# Generate noisy linear labels
a = torch.tensor([0.5])
b = torch.tensor([0.8])
x = torch.linspace(-1, 1, 10)
y = a*x + (0.1**0.5)*torch.randn_like(x)*(0.001) + b
inp = torch.linspace(-1, 1, 10)
foo = func_NN()
loss = torch.nn.MSELoss()

# Call the module itself rather than .forward(); MSELoss takes (input, target)
l = loss(foo(inp), y)
l.backward()

print(foo.mul_x.grad)
print(foo.sum_x.grad)
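
As a sanity check on what the two print calls should show: with MSELoss's default mean reduction over the N = 10 broadcast elements, and writing s = sum_x and m = mul_x,

```latex
L = \frac{1}{N}\sum_{i=1}^{N}(s_i - y_i)^2, \qquad
\frac{\partial L}{\partial s_i} = \frac{2}{N}(s_i - y_i), \qquad
s_i = m_i - b \;\Rightarrow\; \frac{\partial L}{\partial m_i} = \frac{\partial L}{\partial s_i}.
```

So foo.mul_x.grad and foo.sum_x.grad should print identical (1, 10) tensors.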

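Since the question asks specifically for hooks, a hedged alternative is Tensor.register_hook, which attaches a callback to an individual tensor and fires while backward() runs, so the gradients are printed as they are computed and nothing has to be stored on the module. The helper make_hook and the grads dict are hypothetical names for this sketch; the labels are simplified to noise-free values reshaped to (1, 10) to avoid MSELoss's size-mismatch warning.

```python
import torch

grads = {}  # hypothetical container: tensor name -> gradient seen during backward

def make_hook(name):
    # A tensor hook receives d(loss)/d(tensor) while backward() runs.
    # Returning None leaves the gradient unchanged.
    def hook(grad):
        print(f"{name} grad: {grad}")
        grads[name] = grad.detach().clone()
    return hook

class func_NN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Parameter(torch.rand(1))
        self.b = torch.nn.Parameter(torch.rand(1))

    def forward(self, inp):
        mul_x = torch.cos(self.a.view(-1, 1) * inp)
        sum_x = mul_x - self.b
        # register_hook raises on tensors that do not require grad
        # (e.g. under torch.no_grad()), so guard the registration.
        if mul_x.requires_grad:
            mul_x.register_hook(make_hook("mul_x"))
            sum_x.register_hook(make_hook("sum_x"))
        return sum_x

torch.manual_seed(0)
x = torch.linspace(-1, 1, 10)
y = (0.5 * x + 0.8).view(1, -1)  # match the model's (1, 10) output shape
foo = func_NN()
l = torch.nn.MSELoss()(foo(x), y)
l.backward()  # prints sum_x's gradient, then mul_x's
```

Hooks registered this way belong to the tensors of that particular forward pass, so each new forward call registers fresh ones and there is no handle to remove between iterations.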
