为了简单起见,假设我想用下面的代码将 Torch 模型的所有参数设置为常量72114982
model = Net()
params = model.state_dict()
for k, v in params.items():
params[k] = torch.full(v.shape, 72114982, dtype=torch.long)
model.load_state_dict(params)
print(model.state_dict().values())
字符串
然后print语句显示所有的值实际上都被设置为72114984
,这与我最初想要的值相差2。
为简单起见,如下定义Net
class Net(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv1 = nn.Conv2d(2, 2, 2)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(2, 2)
型
1条答案
按热度按时间vptzau2j1#
这是数据类型的问题。
模型参数被转换为浮点Tensor。
72114984
足够大,其浮点表示舍入为72114984
。您可以通过以下方式验证这一点:
字符串