What is the inverse of "to_dense_batch" in PyTorch Geometric?

dy2hfwbg  posted on 2023-03-30  in Other

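to_dense_batch (doc) converts a mini-batch into a dense batch. How can a dense batch be converted back into a mini-batch? Is there something like a "from_dense_batch" method that takes dense_batch and mask and returns the mini-batched data?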
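
To make the two layouts concrete, here is a minimal pure-PyTorch sketch of the forward direction. It is only an illustration of the shapes involved, not the library's implementation; the toy tensors are made up, and it assumes the batch vector is sorted by graph index.

```python
import torch

# Mini-batch (sparse) layout: 5 nodes in total, 2 features each.
x = torch.arange(10, dtype=torch.float).view(5, 2)
# Graph index of every node: graph 0 has 2 nodes, graph 1 has 3 nodes.
# Assumption: batch is sorted, as it is for batches built by a PyG DataLoader.
batch = torch.tensor([0, 0, 1, 1, 1])

num_graphs = int(batch.max()) + 1
counts = torch.bincount(batch, minlength=num_graphs)  # nodes per graph
max_nodes = int(counts.max())

# Position of each node inside its own graph.
starts = torch.cumsum(counts, dim=0) - counts
pos = torch.arange(x.size(0)) - starts[batch]

# Dense layout: [B, N, F] zero-padded features plus a [B, N] boolean mask.
dense = x.new_zeros(num_graphs, max_nodes, x.size(1))
mask = torch.zeros(num_graphs, max_nodes, dtype=torch.bool)
dense[batch, pos] = x
mask[batch, pos] = True

print(dense.shape)  # B=2 graphs, N=3 padded slots, F=2 features
print(mask)
```

The inverse being asked for has to recover both x and batch from dense and mask.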

wydwbb8l  #1

I found a workable solution, but I'm not sure whether it is the best implementation. Below is my code along with a test example.

from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
import torch
from torch_geometric.utils import to_dense_batch

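def from_dense_batch(dense_batch, mask):
    # dense_batch: [B, N, F] zero-padded node features
    # mask: [B, N] boolean, True where a real node exists
    B, N, F = dense_batch.size()
    # reshape (rather than view) also handles non-contiguous inputs
    flatten_dense_batch = dense_batch.reshape(-1, F)
    flatten_mask = mask.reshape(-1)
    data_x = flatten_dense_batch[flatten_mask, :]
    num_nodes = torch.sum(mask, dim=1)  # [B], e.g. 3, 4, 3
    pr_value = torch.cumsum(num_nodes, dim=0)  # [B], e.g. 3, 7, 10
    # long dtype and the mask's device, so the result matches the original batch vector
    indicator_vector = torch.zeros(int(num_nodes.sum()), dtype=torch.long, device=mask.device)
    # mark the first node of every graph except the first; assumes no graph is empty
    indicator_vector[pr_value[:-1]] = 1  # [total_nodes], e.g. 0,0,0,1,0,0,0,1,0,0
    data_batch = torch.cumsum(indicator_vector, dim=0)  # [total_nodes], e.g. 0,0,0,1,1,1,1,2,2,2
    return data_x, data_batch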

dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

test_batch = next(iter(loader))
dense_data, mask = to_dense_batch(test_batch.x, test_batch.batch)
output_data, output_batch = from_dense_batch(dense_data, mask)
print((test_batch.x == output_data).all())
print((test_batch.batch == output_batch).all())
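
As a possible alternative, the same result can be sketched more compactly with boolean-mask indexing. This is not a PyTorch Geometric API: the name from_dense_batch_alt and the toy tensors below are hypothetical, chosen so the example runs without downloading a dataset.

```python
import torch

def from_dense_batch_alt(dense_batch, mask):
    # Hypothetical helper, not part of torch_geometric.
    # dense_batch: [B, N, F] padded features, mask: [B, N] boolean.
    B, N = mask.shape
    # Indexing the first two dims with the mask keeps only real nodes: [total_nodes, F]
    x = dense_batch[mask]
    # Grid holding the graph index at every (graph, slot) position, filtered the same way
    graph_ids = torch.arange(B, device=mask.device).unsqueeze(1).expand(B, N)
    batch = graph_ids[mask]
    return x, batch

# Toy dense batch: 3 graphs with 3, 4 and 3 nodes, padded to 4 slots, 2 features.
mask = torch.tensor([[True, True, True, False],
                     [True, True, True, True],
                     [True, True, True, False]])
dense = torch.zeros(3, 4, 2)
dense[mask] = torch.arange(20, dtype=torch.float).view(10, 2)

x, batch = from_dense_batch_alt(dense, mask)
print(x.shape)
print(batch)
```

Because the graph index is read off the mask directly instead of being rebuilt from cumulative node counts, this variant should also cope with a graph that has no nodes (an all-False mask row), and batch comes out as a long tensor on the same device as mask.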
