PyTorch load_state_dict() does not load exact values

f45qwnt8 · asked 11 months ago

For simplicity, suppose I want to set every parameter of a torch model to the constant 72114982 with the following code:

import torch

model = Net()
params = model.state_dict()

# replace every entry in the state dict with a constant long tensor
for k, v in params.items():
    params[k] = torch.full(v.shape, 72114982, dtype=torch.long)

model.load_state_dict(params)
print(model.state_dict().values())

The print statement then shows that all of the values were actually set to 72114984, which is off by 2 from the value I originally wanted.
For simplicity, Net is defined as follows:

import torch.nn as nn

class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(2, 2, 2)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(2, 2)


vptzau2j · answer 1

This is a data type issue.
The model's parameters are stored as float tensors, so the long values you load are cast to float32. 72114982 is large enough that its float32 representation rounds to 72114984: at this magnitude, consecutive float32 values are 8 apart, so 72114982 is not exactly representable and rounds to the nearest representable value.
You can verify this as follows:

import torch

x = torch.tensor(72114982, dtype=torch.long)
y = x.float() # y will actually be `72114984.0`

# this returns `True` because x is cast to a float before evaluating
x == y
> tensor(True)

# for the same reason, this returns 0.
y - x
> tensor(0.)

# this returns `False` because the tensors have different values and we don't cast to float
x == y.long()
> tensor(False)

# as longs, the difference correctly evaluates to 2
y.long() - x
> tensor(2)
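
To see the rounding granularity directly, here is a minimal sketch (my addition, not part of the original answer) using torch.nextafter: near 7.2e7 a float32 has a 23-bit fraction and an exponent of 26, so consecutive representable values are 2^(26-23) = 8 apart.

import torch

x = torch.tensor(72114984.0, dtype=torch.float32)
# nextafter gives the next representable float32 toward +inf;
# the gap of 8 means 72114982 has no exact float32 representation
nxt = torch.nextafter(x, torch.tensor(float('inf'), dtype=torch.float32))
print(nxt.item() - x.item())
> 8.0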

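If you need the exact integer to survive loading, one workaround (a sketch of mine, assuming double precision is acceptable for your use case, not something from the answer above) is to convert the model to float64 before loading: float64 represents every integer up to 2^53 exactly, so 72114982 round-trips unchanged.

import torch
import torch.nn as nn

model = nn.Linear(2, 2).double()  # float64 parameters
params = model.state_dict()
for k, v in params.items():
    # full_like keeps v's dtype (float64), which holds 72114982 exactly
    params[k] = torch.full_like(v, 72114982)

model.load_state_dict(params)
print(model.weight[0, 0].item())
> 72114982.0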
