Pytorch中NN.线性层在附加维数上的应用

4dc9hkyq  于 2023-10-20  发布在  其他
关注(0)|答案(2)|浏览(133)

pytorch中的全连接层(nn.Linear)是如何应用在“附加维度”上的?文档说,它可以应用于连接Tensor(N,*,in_features)(N,*,out_features),其中N在一个批次中的示例数量中,所以它是无关紧要的,而*是那些“额外”的维度。这是否意味着使用额外维度中所有可能的切片来训练单个层,或者为每个切片训练单独的层,或者其他不同的层?

ohfgkhjo

ohfgkhjo1#

linear.weight中学习in_features * out_features参数,在linear.bias中学习out_features参数。您可以将nn.Linear视为
1.将Tensor重塑为某个(N', in_features),其中N'N*描述的所有维度的乘积:input_2d = input.reshape(-1, in_features)
1.应用标准矩阵-矩阵乘法output_2d = linear.weight @ input_2d
1.添加偏置output_2d += linear.bias.reshape(1, in_features)(请注意,我们将其广播到所有N'维)
1.修改输出,使其具有与input相同的尺寸,除了最后一个:output = output_2d.reshape(*input.shape[:-1], out_features)

  1. return output
    因此,前导维度N*维度的处理方式相同。文档中明确说明了N的输入必须 * 至少 * 为2d,但可以根据您的需要任意多维。
ccrfmcuu

ccrfmcuu2#

我知道这是一个老问题,但我今天试图自己弄清楚这一点,并提出了一些基于Jatentakianswer的工作代码,帮助它为我点击。我想分享一下也许会有帮助。
备注:

  • 为了简单起见,我省略了偏见
  • Jatentaki说矩阵乘法看起来像output_2d = linear.weight @ input_2d,但我发现矩阵维度并不这样排列。但是input_2d @ linear.weight.T的等价物对我来说是有效的(这就是文档描述线性变换的方式:

    )。
B = 2 # batches
T = 5 # tokens

in_features = 27 
out_features = 10 

# Create an input tensor of shape (B, T, in_features)
x = torch.randn(B, T, in_features)

# Create a linear layer that transforms from in_features to out_features
linear_layer = nn.Linear(in_features, out_features, bias=False)

# Check the shape of the weight matrix
w = linear_layer.weight
assert w.shape == (out_features, in_features)

# Compute result by matrix multiplication
matmul_result = (x.reshape(-1, in_features) @ w.T).reshape(*x.shape[:-1], out_features)

# Compute result by calling the layer
call_result = linear_layer(x)

# Assert they are equal 
assert torch.all(matmul_result == call_result).item() is True

相关问题