How to create a block-diagonal weight tensor without using torch.block_diag? (PyTorch)

Posted by nfg76nw0 on 2023-03-18 in Other

The code below does what I want in 2D, but it doesn't generalize to higher dimensions:

M = torch.block_diag(M1, M2, M3, M4)
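To make the 2D limitation concrete: torch.block_diag only accepts inputs with at most two dimensions, so it builds one matrix from 2D blocks but rejects a batch of blocks. A small check, with the block size m = 2 chosen arbitrarily:

```python
import torch

m = 2
M1, M2, M3, M4 = (torch.randn(m, m) for _ in range(4))

# 2D case: four m x m blocks -> one (4m) x (4m) block-diagonal matrix
M = torch.block_diag(M1, M2, M3, M4)
print(M.shape)  # torch.Size([8, 8])

# A batch of blocks with shape (b, m, m) is rejected,
# because torch.block_diag only takes tensors with <= 2 dimensions
batched = torch.randn(3, m, m)
try:
    torch.block_diag(batched, batched)
    rejected = False
except RuntimeError:
    rejected = True
print(rejected)  # True
```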

Here M1, M2, M3, M4 are m x m weight tensors (so n = 4*m below).
So I tried to create the same block-diagonal weight tensor without using torch.block_diag:

N = torch.zeros(n, n)
N[:m, :m] = M1
N[m:2*m, m:2*m] = M2
N[2*m:3*m, 2*m:3*m] = M3
N[3*m:, 3*m:] = M4
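The slice assignment above generalizes to any number of blocks and to extra leading (batch) dimensions by looping over the blocks and using `...` in the index. The helper below is a hypothetical sketch (the name `block_diag_nd` and the `(..., k, m, m)` input layout are my own assumptions, not from the thread); like the answer below, it should be called inside `forward` so the result is rebuilt on every step.

```python
import torch

def block_diag_nd(blocks: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: blocks of shape (..., k, m, m) -> (..., k*m, k*m).

    Leading dims are treated as batch dims, so each batch entry gets its own
    block-diagonal matrix (torch.block_diag itself only takes <= 2-D inputs).
    """
    *batch, k, m, m2 = blocks.shape
    assert m == m2, "blocks must be square"
    out = blocks.new_zeros((*batch, k * m, k * m))
    for i in range(k):
        # Same slicing as in the question; `...` keeps the leading batch dims intact
        out[..., i * m:(i + 1) * m, i * m:(i + 1) * m] = blocks[..., i, :, :]
    return out

blocks = torch.randn(3, 4, 2, 2)  # batch of 3, each with 4 blocks of size 2 x 2
out = block_diag_nd(blocks)
print(out.shape)  # torch.Size([3, 8, 8])
```

Because the output starts as a fresh zero tensor on each call, gradients flow back to the blocks through the in-place copies.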

The resulting tensor is equal to the one above, i.e. torch.equal(M, N) returns True, but at the loss.backward() step I get the following error:

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
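A likely cause (my reading; the thread does not spell it out): in the full code in the edit, self.M is built once in `__init__`, so the autograd graph linking it to M1..M4 is created once and reused by every training step. The first loss.backward() frees the tensors that graph saved for the backward pass, and the next step's backward needs them again. The minimal sketch below reproduces the same error, with `torch.exp` (which saves its output for backward) standing in for the slice assignments; `W`, `M` and `x` are illustrative names.

```python
import torch

W = torch.randn(4, 4, requires_grad=True)
M = W.exp()  # built once, like self.M in __init__; exp saves its output for backward
x = torch.randn(2, 4)

errors = []
for step in range(2):  # two "training steps" reusing the same M
    loss = (x @ M).sum()
    try:
        loss.backward()
    except RuntimeError as e:
        errors.append((step, str(e)))

print(errors[0][0])  # 1: only the second backward fails
```

Moving `M = W.exp()` inside the loop builds a fresh graph each step and the error goes away; the answer applies the same idea to BlockLinear.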

Edit: here is the full code:

import torch
import torch.nn as nn
import torch.nn.functional as F

class BlockLinear(nn.Module):
    def __init__(self, n):
        super().__init__()
        self.n = n
        m = self.m = n // 4  # side length of each of the four diagonal blocks

        self.M1 = nn.Parameter(torch.randn(m, m))
        self.M2 = nn.Parameter(torch.randn(m, m))
        self.M3 = nn.Parameter(torch.randn(m, m))
        self.M4 = nn.Parameter(torch.randn(m, m))
        
        # This works!
        self.M = torch.block_diag(self.M1, self.M2, self.M3, self.M4) 
        
        # This doesn't work!
        #self.M = torch.zeros(n, n)
        #self.M[:m, :m] = self.M1
        #self.M[m:2*m, m:2*m] = self.M2
        #self.M[2*m:3*m, 2*m:3*m] = self.M3
        #self.M[3*m:, 3*m:] = self.M4 
    
    def forward(self, x):
        x = torch.einsum('ij, bi -> bj', self.M, x)
        return x

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 16)
        self.fc2 = BlockLinear(16)
        self.fc3 = BlockLinear(16)
        self.fc4 = nn.Linear(16, 10)

    def forward(self, x):
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        return x

net = Net()
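Even the version marked "This works!" seems to have a catch: torch.block_diag copies the current values of M1..M4 into a new tensor, and in this code that copy happens only once, in `__init__`. When the optimizer later updates the parameters in place, self.M keeps the old values, so `forward` keeps using the initial weights. The sketch below uses a hypothetical two-block module `CachedBlockDiag` (not code from the thread) and simulates one in-place parameter update.

```python
import torch
import torch.nn as nn

class CachedBlockDiag(nn.Module):
    """Hypothetical two-block module that builds M once in __init__, as in the question."""

    def __init__(self, m: int):
        super().__init__()
        self.M1 = nn.Parameter(torch.randn(m, m))
        self.M2 = nn.Parameter(torch.randn(m, m))
        # block_diag copies the current values of M1/M2 into a new tensor, exactly once
        self.M = torch.block_diag(self.M1, self.M2)

layer = CachedBlockDiag(2)
before = layer.M.detach().clone()

# Simulate an optimizer step: optimizers such as SGD update parameters in place
with torch.no_grad():
    layer.M1.add_(1.0)

print(torch.equal(layer.M.detach(), before))                     # True: M did not change
print(torch.equal(layer.M[:2, :2].detach(), layer.M1.detach()))  # False: M is now stale
```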
Answer #1 (bmp9r5qi):

If you want only the non-zero elements to be parameters, you need to recompute M on every forward pass, like this:

import torch
import torch.nn as nn

class BlockLinear(nn.Module):
    def __init__(self, n):
        super().__init__()
        self.n = n
        m = n//4

        self.M1 = nn.Parameter(torch.randn(m, m))
        self.M2 = nn.Parameter(torch.randn(m, m))
        self.M3 = nn.Parameter(torch.randn(m, m))
        self.M4 = nn.Parameter(torch.randn(m, m))
    
    def diag(self):
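        # Fresh zero matrix on every call, on the same device/dtype as the blocks
        M = torch.zeros(self.n, self.n, device=self.M1.device, dtype=self.M1.dtype)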
        m = self.n//4
        M[:m, :m] = self.M1
        M[m:2*m, m:2*m] = self.M2
        M[2*m:3*m, 2*m:3*m] = self.M3
        M[3*m:, 3*m:] = self.M4 
        return M
    
    def forward(self, x):
        return x@self.diag()
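An alternative that avoids materializing the dense n x n matrix at all (not from the thread; the class name `StackedBlockLinear` and the `(k, m, m)` parameter layout are my own assumptions): keep the k blocks in a single parameter and apply them with torch.einsum. The sketch checks the result against `x @ torch.block_diag(...)`, checks that the question's `einsum('ij,bi->bj', M, x)` is the same product as the answer's `x @ M`, and runs a few optimizer steps to show that repeated backward passes work, because the graph is rebuilt on each call.

```python
import torch
import torch.nn as nn

class StackedBlockLinear(nn.Module):
    """Hypothetical variant: k square blocks stored as one (k, m, m) parameter.

    Computes x @ block_diag(W[0], ..., W[k-1]) without building the n x n matrix.
    """

    def __init__(self, n: int, k: int = 4):
        super().__init__()
        assert n % k == 0, "n must be divisible by the number of blocks"
        self.k, self.m = k, n // k
        self.W = nn.Parameter(torch.randn(k, self.m, self.m))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b = x.shape[0]
        xs = x.reshape(b, self.k, self.m)              # split features into k chunks
        ys = torch.einsum('bki,kij->bkj', xs, self.W)  # multiply chunk k by block k
        return ys.reshape(b, self.k * self.m)

layer = StackedBlockLinear(16)
x = torch.randn(5, 16)

# Compare with the dense block-diagonal product used in the question and the answer
dense = torch.block_diag(*layer.W)
print(torch.allclose(layer(x), x @ dense, atol=1e-5))
print(torch.allclose(torch.einsum('ij,bi->bj', dense, x), x @ dense, atol=1e-5))

# The graph is rebuilt on every call, so backward can run once per step
opt = torch.optim.SGD(layer.parameters(), lr=0.1)
for _ in range(3):
    opt.zero_grad()
    loss = layer(x).pow(2).mean()
    loss.backward()
    opt.step()
```

The trade-off is that the weights live in one stacked tensor instead of separate M1..M4 attributes, which also makes the number of blocks a constructor argument.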
