pytorch 关于理解节点和边索引的说明

mwkjh3gx  于 2022-11-23  发布在  其他
关注(0)|答案(1)|浏览(144)

通过访问data对象的属性,我如何知道哪个节点特性属于哪个节点?如果我理解正确的话,data.x包含节点特性。通过运行下面的for循环,我可以访问特性,但我如何知道它属于节点0还是节点9?

from torch_geometric.data import Data
edge_index = torch.tensor([[0, 1, 1, 2, 1, 9],
                           [1, 0, 2, 1, 8, 1]
                           ], dtype=torch.long)
x = torch.tensor([[-5,7], [0,5], [0,9], [10,9]], dtype=torch.float)
​
data = Data(x=x, edge_index=edge_index)
​
for item in range(0, data.x.shape[0]):
    print(item, data.x[item], data.edge_index.t()[item])
u5rb5r59

u5rb5r591#

在代码中,通过定义x,Pytorch Geometric(根据x的形状)推断存在四个节点。这在文档中有详细说明:
在存在节点级属性(例如data.x)的情况下,自动推断数据对象中的节点的数目。
您还为节点指定了边,直到节点9。如果您尝试在此代码上运行任何模型,我怀疑它会产生错误,因为预期只有4个节点存在。这是因为它将尝试访问x的第9个元素,这将返回索引错误。
最佳做法是定义存在的节点数。如文档中所述:
我们建议通过'data.num_nodes = ...显式设置数据对象中的节点数。

相关问题