pytorch 将bfloat16保存为二进制格式

utugiqy6  于 2023-08-05  发布在  其他
关注(0)|答案(1)|浏览(232)

bfloattorch.tensor作为原始二进制文件保存到磁盘的惯用方法是什么?下面的代码将抛出错误,因为numpy不支持bfloat16

import torch
import numpy as np

tensor = torch.tensor([1, 2, 3, 4, 5]).bfloat16()

# TypeError: Got unsupported ScalarType BFloat16
arr = tensor.numpy()
arr.tofile("output.bin")

字符串

iyfamqjs

iyfamqjs1#

如果我们假设在Python中任何对象都可以被序列化,你可以简单地用Python的pickile注册你的arr对象。
第一个月
with open(""output.bin", "wb") as me:
pickle.dump(arr, me)

相关问题