pytorch 如何加快www.example.com的运行torch.cat速度?

yk9xbfzb  于 2022-11-09  发布在  其他
关注(0)|答案(1)|浏览(140)

下面是我想做的事情的简化版本:

import torch
import time

# Create dummy tensors and save them in my_list

my_list = [[]] * 100
for i in range(len(my_list)):
    my_list[i] = torch.randint(0, 1000000000, (100000, 256))
concat_list = torch.tensor([])

# I want to concat two consecutive tensors in my_list

tic = time.time()
for i in range(0, len(my_list), 2):
    concat_list = torch.cat((concat_list, my_list[i]))
    concat_list = torch.cat((concat_list, my_list[i+1]))
    # Do some work at CPU with concat_list
    concat_list = torch.tensor([]) # Empty concat_list
print('time: ', time.time() - tic) # It takes 3.5 seconds in my environment

有没有什么方法可以让上面的Tensor拼接更快?
我尝试将my_list[i]my_list[i+1]concat_list发送到GPU,并在设备中执行torch.cat函数,但随后我必须将concat_list发送回CPU,以完成上面所写的“一些工作”。由于频繁的GPU-CPU数据传输,这需要更多的时间。
我还测试了将Tensor转换为列表,以便与基本的Python列表进行连接,但这种方法比简单的torch.cat方法慢得多。
我听说将DataLoader与自定义的collate_fn一起使用可以启用连接,但我不知道如何实现它。
有没有更快的方法?

a14dhokn

a14dhokn1#

你的代码在我的电脑上大约需要11秒。下面的代码需要4.1秒:


# Create dummy tensors and save them in my_list

my_list = [[]] * 100
for i in range(len(my_list)):
    my_list[i] = torch.randint(0, 1000000000, (100000, 256))
tic = time.time()
my_list = torch.stack(my_list)

# I want to concat two consecutive tensors in my_list

for i in range(0, len(my_list), 2):
    concat_list = my_list[i:i+2]
    # Do some work at CPU with concat_list
print('time: ', time.time() - tic) # It takes 3.5 seconds in my environment

相关问题