将bfloat
torch.tensor
作为原始二进制文件保存到磁盘的惯用方法是什么?下面的代码将抛出错误,因为numpy不支持bfloat16
。
import torch
import numpy as np
tensor = torch.tensor([1, 2, 3, 4, 5]).bfloat16()
# TypeError: Got unsupported ScalarType BFloat16
arr = tensor.numpy()
arr.tofile("output.bin")
字符串
1条答案
按热度按时间iyfamqjs1#
如果我们假设在Python中任何对象都可以被序列化,你可以简单地用Python的pickile注册你的arr对象。
第一个月
with open(""output.bin", "wb") as me:
个pickle.dump(arr, me)
个