我尝试在模型的向后传递过程中打印每个中间梯度的值,使用寄存器向后钩子:
class func_NN(torch.nn.Module):
def __init__(self,) :
super().__init__()
self.a = torch.nn.Parameter(torch.rand(1))
self.b = torch.nn.Parameter(torch.rand(1))
def forward(self, inp):
mul_x = torch.cos(self.a.view(-1,1)*inp)
sum_x = mul_x - self.b
return sum_x
# hook function
def backward_hook(module, grad_input, grad_output):
print("module: ", module)
print("inp: ", grad_input)
print("out: ", grad_output)
# Training
# Generate labels
a = torch.Tensor([0.5])
b = torch.Tensor([0.8])
x = torch.linspace(-1, 1, 10)
y = a*x + (0.1**0.5)*torch.randn_like(x)*(0.001) + b
inp = torch.linspace(-1, 1, 10)
foo = func_NN()
handle_ = foo.register_full_backward_hook(backward_hook)
loss = torch.nn.MSELoss()
optim = torch.optim.Adam(foo.parameters(),lr=0.001)
t_l = []
for i in range(2):
optim.zero_grad()
l = loss(y, foo.forward(inp=inp))
t_l.append(l.detach())
l.backward()
optim.step()
handle_.remove()
字符串
但这并没有提供预期的结果。
我的目标是打印非叶节点的梯度,如sum_x
和mul_x
。请帮助。
1条答案
按热度按时间zsohkypk1#
Pytorch钩子被设计用来抓取参数的梯度。你不能用它们来抓取中间Tensor的梯度。
如果你想获得中间Tensor的梯度,你需要将它们保存到模型的状态中,并对它们应用
retain_grad
。字符串