PyTorch nn.模块无法取消批处理操作

iqxoj9l9  于 2022-12-13  发布在  其他
关注(0)|答案(1)|浏览(148)

我有一个nn.模块,它的forward函数有两个输入,在函数中,我将其中一个输入x1乘以一组可训练参数,然后将它们与另一个输入x2连接起来。

class ConcatMe(nn.Module):
    def __init__(self, pad_len, emb_size):
        super(ConcatMe, self).__init__()
        self.W = nn.Parameter(torch.randn(pad_len, emb_size).to(DEVICE), requires_grad=True)
        self.emb_size = emb_size
     
    def forward(self, x1: Tensor, x2: Tensor):
        cat = self.W * torch.reshape(x2, (1, -1, 1))
        return torch.cat((x1, cat), dim=-1)

据我所知,人们应该能够在PyTorch的nn.模块中编写操作,就像我们对批处理大小为1的输入所做的那样。由于某种原因,事实并非如此。我得到一个错误,表明PyTorch仍在考虑batch_size。

x1 =  torch.randn(100,2,512)
x2 = torch.randint(10, (2,1))
concat = ConcatMe(100, 512)
concat(x1, x2)

-----------------------------------------------------------------------------------
File "/home/my/file/path.py, line 0, in forward
    cat = self.W * torch.reshape(x2, (1, -1, 1))
RuntimeError: The size of tensor a (100) must match the size of tensor b (2) at non-singleton dimension 1

我做了一个for循环来修补这个问题,如下所示:

class ConcatMe(nn.Module):
    def __init__(self, pad_len, emb_size):
        super(ConcatMe, self).__init__()
        self.W = nn.Parameter(torch.randn(pad_len, emb_size).to(DEVICE), requires_grad=True)
        self.emb_size = emb_size
     
    def forward(self, x1: Tensor, x2: Tensor):
        batch_size = x2.shape[0]
        cat = torch.ones(x1.shape).to(DEVICE)

        for i in range(batch_size):
            cat[:, i, :] = self.W * x2[i]

        return torch.cat((x1, cat), dim=-1)

但是我觉得有一个更优雅的解决方案。这是否与我在nn.Module中创建参数有关?如果是,我可以实现什么解决方案而不需要for循环。

00jrzges

00jrzges1#

据我所知,人们应该能够在PyTorch的nn.Module中编写操作,就像我们对批量大小为 1 的输入所做的那样。
我不确定你是从哪里得到这个假设的,它肯定是 * 不 * 真的--恰恰相反:您始终需要以能够处理任意批次维度的一般情况的方式编写它们。
从你的第二个实现来看,你似乎在尝试将两个维数不相容的Tensor相乘。

self.W = torch.nn.Parameter(torch.randn(pad_len, 1, emb_size), requires_grad=True)

为了更好地理解这类事情,了解broadcasting会有所帮助。

相关问题