python pytorch是如何自动知道我的模型参数的?

gg0vcinb  于 2023-02-07  发布在  Python
关注(0)|答案(1)|浏览(136)

我定义了定制类,如下所示:

class MLPClassifier(nn.Module):
    """
    A basic multi-layer perceptron classifier with 3 layers.
    """
  
    def __init__(self, input_size, hidden_size, num_classes):
        """
        The constructor for the MLPClassifier class.
        """
        super(MLPClassifier, self).__init__()

        self.fc1 = nn.Linear(input_size, hidden_size)  # weights & biases for the input-to-hidden layer
        self.ac1 = nn.ReLU()                           # non-linear activation for the input-to-hidden layer
        self.fc2 = nn.Linear(hidden_size, num_classes) # weights & biases for the hidden-to-output layer
        self.ac2 = nn.Softmax(dim=1)                   # non-linear activation for the hidden-to-output layer

当我运行下面的脚本时,我得到了这个:

hyper_param_input_size  = 4
hyper_param_hidden_size = 64
hyper_param_num_classes = 3

model = MLPClassifier(hyper_param_input_size, hyper_param_hidden_size, hyper_param_num_classes)

for p in model.parameters():
    print(p.shape)

>>> torch.Size([64, 4])
>>> torch.Size([64])
>>> torch.Size([3, 64])
>>> torch.Size([3])

PyTorch究竟是如何自动知道我内部定义的属性的呢?我从来没有明确地告诉过它,它是否循环遍历类中的所有内容,并检查是否为isinstance(self, nn.Layer)或其他内容?

eqqqjvef

eqqqjvef1#

nn.Module.parameters函数将递归遍历父模块的所有子模块并返回其所有参数。它与MLPClassifier模块的实际结构无关。当在__init__中定义新的子模块属性时,父模块会将其注册为子模块,这样它们的参数(如果有,* 例如 * 您的nn.ReLUnn.Softmax没有任何...)以后可以通过parameters调用访问。

相关问题