import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.fc = nn.Linear(20, 20)
def forward(self, t: float, y: torch.Tensor) -> torch.Tensor:
y = self.fc(y)
y = F.relu(y)
y = t * y
return y
model = Model()
y = torch.randn((10, 20)) # batch of size 10
t = 0.5 # scalar
output = model(t, y)
1条答案
按热度按时间nhn9ugyo1#
在PyTorch中,确实可以将标量数据和批处理数据都馈送到
forward
函数中。前者通过在PyTorch中广播的方式自动扩展到后者的维度。字符串