pytorch批处理中的拉普拉斯位置编码

cqoc49vn  于 2022-11-09  发布在  其他
关注(0)|答案(1)|浏览(302)

我正在尝试用pytorch重新编码一个图模型的拉普拉斯位置编码。numpy中的有效编码可以在www.example.com找到https://docs.dgl.ai/en/0.9.x/_modules/dgl/transforms/functional.html#laplacian_pe。
我想我已经设法在pytorch中做了一个与numpy等效的编码,但是为了性能问题,我希望该函数能够处理批量数据。
也就是说,下面的函数将adj[N, N]degrees[N, N]topk形式的参数用作整数,其中N是网络中的节点数。

def _laplacian_positional_encoding_th(self, adj, degrees, topk):
    number_of_nodes = adj.shape[-1].
    #degrees = th.clip(degrees, 0, 1) # not multigraph
    assert topk < number_of_nodes

    # Laplacian
    D = th.diag(degrees**-0.5)
    B = D * adj * D
    L = th.eye(number_of_nodes).to(B.device) * B

    # Eigenvectors
    EigVal, EigVec = th.linalg.eig(L)
    idx = th.argsort(th.real(EigVal)) # increasing order
    EigVal, EigVec = th.real(EigVal[idx]), th.real(EigVec[:,idx])

    # Only select [1,topk+1] EigenVectors as L is symmetric (Spectral decomposition)
    out = EigVec[:,1:topk+1]
    return out

然而,当我试图以批处理形式执行同样有效的操作时,我无法对其进行编码。也就是说,我们的想法是参数可以以adj[B, N, N]degrees[B, N, N]和topk的形式出现,B是批处理中的数据数量。

4zcjmb1e

4zcjmb1e1#

怎么样:

def _laplacian_positional_encoding_th(self, adj, degrees, topk):
        number_of_nodes = adj.shape[-1]
        assert topk < number_of_nodes

        D = th.clip(degrees, 0, 1) # not multigraph
        B = D @ adj @ D
        L = th.eye(number_of_nodes).to(B.device)[None, ...] - B 

        # Eigenvectors
        EigVal, EigVec = th.linalg.eig(L)
        idx = th.argsort(th.real(EigVal)) # increasing order

        out = th.real(th.gather(EigVec, dim=-1, index=idx[..., None]))
        return out

有关创建一批对角矩阵的信息,请参见th.diag_embed;有关根据排序后的索引选择EigVec的右列的信息,请参见th.gather

**更新数据:**如果要提取topk向量:

_, topk = th.topk(EigVal.real, k=5)  # get the top 5
out = th.gather(EigVec.real, dim=-1, index=topk[:, None, :].expand(-1, EigVec.shape[1], -1))

相关问题