import torch
import torch.nn as nn

B = 2   # batch size
T = 5   # number of tokens (sequence length)
in_features = 27
out_features = 10
# Create an input tensor of shape (B, T, in_features)
x = torch.randn(B, T, in_features)
# Create a linear layer that transforms from in_features to out_features
linear_layer = nn.Linear(in_features, out_features, bias=False)
# nn.Linear stores its weight with shape (out_features, in_features)
w = linear_layer.weight
assert w.shape == (out_features, in_features)
# Compute result by matrix multiplication
matmul_result = (x.reshape(-1, in_features) @ w.T).reshape(*x.shape[:-1], out_features)
# Compute result by calling the layer
call_result = linear_layer(x)
# Assert they are equal
assert torch.equal(matmul_result, call_result)
2 Answers
ohfgkhjo1#
nn.Linear learns in_features * out_features parameters in linear.weight and out_features parameters in linear.bias. You can think of nn.Linear as doing the following:

1. Reshape the tensor to some (N', in_features), where N' is the product of N and all the dimensions described by *: input_2d = input.reshape(-1, in_features)
2. Apply a standard matrix-matrix multiplication: output_2d = linear.weight @ input_2d
3. Add the bias: output_2d += linear.bias.reshape(1, out_features) (note that it is broadcast across all N' rows)
4. Reshape the output so that it has the same dimensions as input, except for the last one: output = output_2d.reshape(*input.shape[:-1], out_features)
5. return output
So the leading dimension N is treated the same way as the * dimensions. The documentation spells N out to make it explicit that the input must be *at least* 2d, but it can have as many dimensions as you need.
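Here is a minimal runnable sketch of the five steps above (the helper name manual_linear is mine, not part of PyTorch). One detail is adjusted so the shapes actually line up: the multiplication is written as input_2d @ linear.weight.T rather than linear.weight @ input_2d, as the next answer also points out:

import torch
import torch.nn as nn

def manual_linear(input, linear):
    # Hypothetical helper replaying the five steps above for a given nn.Linear.
    in_features = linear.in_features
    out_features = linear.out_features
    # 1. Collapse N and all * dimensions into a single leading dimension N'.
    input_2d = input.reshape(-1, in_features)              # (N', in_features)
    # 2. Matrix-matrix multiplication; the weight is stored as
    #    (out_features, in_features), hence the transpose.
    output_2d = input_2d @ linear.weight.T                 # (N', out_features)
    # 3. Add the bias, broadcast across all N' rows.
    output_2d = output_2d + linear.bias.reshape(1, out_features)
    # 4. Restore the original leading dimensions, replacing the last one.
    output = output_2d.reshape(*input.shape[:-1], out_features)
    # 5. Return the result.
    return output

linear = nn.Linear(27, 10)
x = torch.randn(2, 5, 27)  # arbitrary leading dimensions (N, *)
assert torch.allclose(manual_linear(x, linear), linear(x), atol=1e-6)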
ccrfmcuu2#

I know this is an old question, but I was trying to work this out for myself today and put together some working code based on Jatentaki's answer that helped it click for me. I'm sharing it here in case it helps someone.

One remark: the answer above applies the multiplication as output_2d = linear.weight @ input_2d, but I found that the matrix dimensions do not line up that way. The equivalent input_2d @ linear.weight.T does work for me (and it is how the documentation describes the linear transformation :)).
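A quick shape check makes the remark concrete (a sketch, reusing the sizes from the code at the top):

import torch
import torch.nn as nn

linear = nn.Linear(27, 10)
x = torch.randn(2, 5, 27)
input_2d = x.reshape(-1, 27)       # (10, 27)

# linear.weight has shape (10, 27), so linear.weight @ input_2d attempts
# (10, 27) @ (10, 27) and raises a shape-mismatch RuntimeError.
out = input_2d @ linear.weight.T   # (10, 27) @ (27, 10) -> (10, 10)
assert torch.allclose(out + linear.bias, linear(x).reshape(-1, 10), atol=1e-6)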