pytorch `torch.nn.module`前向传递,带非批处理输出

kokeuurv  于 2023-08-05  发布在  其他
关注(0)|答案(1)|浏览(104)

我想让Pytorch网络作为它的前向传递接受一些批处理输入和非批处理输入。我想知道是否有一个美学的方式来做这件事?

Model(nn.Module):
   ...

def forward(self, t : float, y : Tensor) -> Tensor :
    ...

字符串
我想知道是否可以让forward方法将形状为[batch, ...]的批量数据y和标量t作为输入。除了将t广播到形状[batch, 1]之外,还有什么方法可以做到这一点吗?

nhn9ugyo

nhn9ugyo1#

在PyTorch中,确实可以将标量数据和批处理数据都馈送到forward函数中。前者通过在PyTorch中广播的方式自动扩展到后者的维度。

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc = nn.Linear(20, 20)

    def forward(self, t: float, y: torch.Tensor) -> torch.Tensor:
        y = self.fc(y)
        y = F.relu(y)
        y = t * y 
        return y

model = Model()
y = torch.randn((10, 20))  # batch of size 10
t = 0.5  # scalar

output = model(t, y)

字符串

相关问题