to_dense_batch(doc)将mini-batch转换为dense batch。如何将dense batch转换回mini-batch?是否有类似于“from_dense_batch”的方法,它接受dense_batch和mask,并给出mini-batched data?
to_dense_batch
dense_batch
mask
data
wydwbb8l1#
我找到了一个可行的解决方案,但我不确定它在哪里是最好的实现。以下是我的代码与测试示例。
from torch_geometric.datasets import TUDataset from torch_geometric.loader import DataLoader import torch from torch_geometric.utils import to_dense_batch def from_dense_batch(dense_bath, mask): # dense batch, B, N, F # mask, B, N B, N, F = dense_bath.size() flatten_dense_batch = dense_bath.view(-1, F) flatten_mask = mask.view(-1) data_x = flatten_dense_batch[flatten_mask, :] num_nodes = torch.sum(mask, dim=1) # B, like 3,4,3 pr_value = torch.cumsum(num_nodes, dim=0) # B, like 3,7,10 indicator_vector = torch.zeros(torch.sum(num_nodes, dim=0)) indicator_vector[pr_value[:-1]] = 1 # num_of_nodes, 0,0,0,1,0,0,0,1,0,0,1 data_batch = torch.cumsum(indicator_vector, dim=0) # num_of_nodes, 0,0,0,1,1,1,1,1,2,2,2 return data_x, data_batch dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True) loader = DataLoader(dataset, batch_size=32, shuffle=True) test_batch = next(iter(loader)) dense_data, mask = to_dense_batch(test_batch.x, test_batch.batch) output_data, output_batch = from_dense_batch(dense_data, mask) print((test_batch.x == output_data).all()) print((test_batch.batch == output_batch).all())
1条答案
按热度按时间wydwbb8l1#
我找到了一个可行的解决方案,但我不确定它在哪里是最好的实现。以下是我的代码与测试示例。