Is PyTorch RNN more efficient with `batch_first=False`?

flmtquvp · posted 2023-04-21

In machine translation, we always need to slice off the first time step (the SOS token) from both the annotations and the predictions.
With `batch_first=False`, slicing off the first time step keeps the tensor contiguous.

import torch
batch_size = 128
seq_len = 12
embedding = 50

# Making a dummy output that is `batch_first=False`
batch_not_first = torch.randn((seq_len,batch_size,embedding))
batch_not_first = batch_not_first[1:].view(-1, embedding) # slicing out the first time step

However, if we use `batch_first=True`, the tensor is no longer contiguous after slicing. We have to make it contiguous before we can apply operations such as `view`.

batch_first = torch.randn((batch_size,seq_len,embedding))
batch_first[:,1:].view(-1, embedding) # slicing out the first time step

output>>>
"""
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-8-a9bd590a1679> in <module>
----> 1 batch_first[:,1:].view(-1, embedding) # slicing out the first time step

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
"""

Does this mean that `batch_first=False` is better, at least in the context of machine translation, since it spares us the `contiguous()` step? Are there any cases where `batch_first=True` works better?

bvjxkvbb #1

Performance

There doesn't seem to be much of a difference between `batch_first=True` and `batch_first=False`. See the script below:

import time

import torch

def time_measure(batch_first: bool):
    torch.cuda.synchronize()
    layer = torch.nn.RNN(10, 20, batch_first=batch_first).cuda()
    if batch_first:
        inputs = torch.randn(100000, 7, 10).cuda()
    else:
        inputs = torch.randn(7, 100000, 10).cuda()

    torch.cuda.synchronize()
    start = time.perf_counter()

    for chunk in torch.chunk(inputs, 100000 // 64, dim=0 if batch_first else 1):
        _, last = layer(chunk)

    torch.cuda.synchronize()
    return time.perf_counter() - start

print(f"Time taken for batch_first=False: {time_measure(False)}")
print(f"Time taken for batch_first=True: {time_measure(True)}")

On my device (GTX 1050 Ti), with PyTorch 1.6.0 and CUDA 11.0, the results were:

Time taken for batch_first=False: 0.3275816479999776
Time taken for batch_first=True: 0.3159054920001836

(And it varied either way between runs, so the result is inconclusive.)

Code readability

`batch_first=True` is simpler when you want to use other PyTorch layers that require the batch as the 0th dimension (which is the case for almost all `torch.nn` layers, e.g. `torch.nn.Linear`).
With `batch_first=False`, you would have to `permute` the returned tensor anyway in that case.
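A small sketch of that extra step (shapes here are made up for illustration): an RNN run time-major returns `(seq_len, batch, hidden)`, so code expecting the batch in dim 0 needs a `permute` first.

```python
import torch

seq_len, batch_size, input_size, hidden = 7, 32, 10, 20
rnn = torch.nn.RNN(input_size, hidden, batch_first=False)
inputs = torch.randn(seq_len, batch_size, input_size)  # time-major input

output, _ = rnn(inputs)                 # output: (seq_len, batch, hidden)
batch_major = output.permute(1, 0, 2)   # (batch, seq_len, hidden)

assert batch_major.shape == (batch_size, seq_len, hidden)
assert not batch_major.is_contiguous()  # permute returns a non-contiguous view
```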

Machine translation

It should be better there, since the tensor stays contiguous and no data copy is needed. Slicing with `[1:]` instead of `[:,1:]` also looks cleaner.
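The contiguity claim can be verified directly (a quick sketch with the shapes from the question):

```python
import torch

seq_len, batch_size, embedding = 12, 128, 50
time_major = torch.randn(seq_len, batch_size, embedding)   # batch_first=False layout
batch_major = torch.randn(batch_size, seq_len, embedding)  # batch_first=True layout

# Slicing the leading dim only moves the storage offset, so the
# result stays contiguous; slicing an inner dim does not.
assert time_major[1:].is_contiguous()
assert not batch_major[:, 1:].is_contiguous()
```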
