我定义了定制类,如下所示:
class MLPClassifier(nn.Module):
"""
A basic multi-layer perceptron classifier with 3 layers.
"""
def __init__(self, input_size, hidden_size, num_classes):
"""
The constructor for the MLPClassifier class.
"""
super(MLPClassifier, self).__init__()
self.fc1 = nn.Linear(input_size, hidden_size) # weights & biases for the input-to-hidden layer
self.ac1 = nn.ReLU() # non-linear activation for the input-to-hidden layer
self.fc2 = nn.Linear(hidden_size, num_classes) # weights & biases for the hidden-to-output layer
self.ac2 = nn.Softmax(dim=1) # non-linear activation for the hidden-to-output layer
当我运行下面的脚本时,我得到了这个:
hyper_param_input_size = 4
hyper_param_hidden_size = 64
hyper_param_num_classes = 3
model = MLPClassifier(hyper_param_input_size, hyper_param_hidden_size, hyper_param_num_classes)
for p in model.parameters():
print(p.shape)
>>> torch.Size([64, 4])
>>> torch.Size([64])
>>> torch.Size([3, 64])
>>> torch.Size([3])
PyTorch究竟是如何自动知道我内部定义的属性的呢?我从来没有明确地告诉过它,它是否循环遍历类中的所有内容,并检查是否为isinstance(self, nn.Layer)
或其他内容?
1条答案
按热度按时间eqqqjvef1#
nn.Module.parameters
函数将递归遍历父模块的所有子模块并返回其所有参数。它与MLPClassifier
模块的实际结构无关。当在__init__
中定义新的子模块属性时,父模块会将其注册为子模块,这样它们的参数(如果有,* 例如 * 您的nn.ReLU
和nn.Softmax
没有任何...)以后可以通过parameters
调用访问。