使用PyTorch将模型保存为.h5格式

vlju58qv  于 2023-02-12  发布在  其他
关注(0)|答案(1)|浏览(646)

我使用PyTorch将模型导出为.h5,如下所示。

torch.save(model.state_dict(), 'model.h5')

但是当我在TensorFlow中加载这个model.h5时,我收到了这个错误。

File "h5py/h5f.pyx", line 106, in h5py.h5f.open
OSError: Unable to open file (file signature not found)

我尝试了两种保存模型的方法。

# 1.
torch.save(model.state_dict(), 'model.h5')

# 2.
torch.save(model, 'model.h5')

在这两种情况下,我都得到了相同的错误。

qxsslcnc

qxsslcnc1#

你不能这么做!
根据官方的TF documentation,TensorFlow使用HDF5文件格式,而PyTorch不使用。Pytorch使用Python的pickle来存储权重。
仅仅传递一个随机的文件扩展名并不能神奇地将其转换为该文件格式!你必须明确地将你的权重保存为hdf 5。

相关问题