下面是我想做的事情的简化版本:
import torch
import time
# Create dummy tensors and save them in my_list
my_list = [[]] * 100
for i in range(len(my_list)):
my_list[i] = torch.randint(0, 1000000000, (100000, 256))
concat_list = torch.tensor([])
# I want to concat two consecutive tensors in my_list
tic = time.time()
for i in range(0, len(my_list), 2):
concat_list = torch.cat((concat_list, my_list[i]))
concat_list = torch.cat((concat_list, my_list[i+1]))
# Do some work at CPU with concat_list
concat_list = torch.tensor([]) # Empty concat_list
print('time: ', time.time() - tic) # It takes 3.5 seconds in my environment
有没有什么方法可以让上面的Tensor拼接更快?
我尝试将my_list[i]
、my_list[i+1]
和concat_list
发送到GPU,并在设备中执行torch.cat
函数,但随后我必须将concat_list
发送回CPU,以完成上面所写的“一些工作”。由于频繁的GPU-CPU数据传输,这需要更多的时间。
我还测试了将Tensor转换为列表,以便与基本的Python列表进行连接,但这种方法比简单的torch.cat
方法慢得多。
我听说将DataLoader与自定义的collate_fn
一起使用可以启用连接,但我不知道如何实现它。
有没有更快的方法?
1条答案
按热度按时间a14dhokn1#
你的代码在我的电脑上大约需要11秒。下面的代码需要4.1秒: