pytorch 如何创建具有固定形状和dtype的高维Tensor?

cidc1ykv  于 2023-03-30  发布在  其他
关注(0)|答案(3)|浏览(171)

我想返回一个固定形状大小的Tensor,比如torch.Size([1,345])

import torch
pt1 = torch.tensor(data=(1, 345), dtype=torch.int64)

它只返回torch.Size([2])
我跟着Tensor教程试着

pt1 = torch.tensor(1, 345, dtype=torch.int64)
pt1 = torch.tensor((1, 345), dtype=torch.int64)
pt1 = torch.tensor(shape=(1, 345), dtype=torch.int64)

它仍然显示像tensor() takes 1 positional argument but 2 were given这样的错误,我知道一些方法意味着数据是(1,345)不是形状,但是...我是pytorch的新手,仍然没有找到它的解决方案...

zqdjd7g9

zqdjd7g91#

torch.tensordata参数应该是 “Tensor的初始数据”,而不是它的形状。
所有Tensor都有固定的形状。你可以用torch.empty定义一个未初始化的Tensor,或者用特定的值(torch.zerostorch.onestorch.full,...)来填充它们。例如:

>>> pt1 = torch.empty(1, 345, dtype=torch.int64)
o8x7eapl

o8x7eapl2#

你可以使用torch.zeroses来创建一个新的Tensor

import torch

a = torch.zeros((1,345), dtype=torch.int64)
print(a.shape)
torch.Size([1, 345])
j8yoct9x

j8yoct9x3#

根据pytorch关于torch.tensor的文档,您的torch.tensor(data=(1, 345), dtype=torch.int64)代码正在创建一个2元素的Tensor[1,345](因此它的大小为[2])。要创建您想要的Tensor,请用途:

import torch

pt1 = torch.zeros((1, 345), dtype=torch.int64)
pt1.size()
# torch.Size([1, 345])

相关问题