在PyTorch中向二维Tensor追加零行

yzckvree  于 2022-11-29  发布在  其他
关注(0)|答案(2)|浏览(597)

假设我有一个2DTensorx,形状为(n,m)。如何通过指定零行在生成的Tensor中的位置索引,在x中添加零行来扩展Tensor的第一维?举个具体的例子:

x = torch.tensor([[1,1,1],
                  [2,2,2],
                  [3,3,3],
                  [4,4,4]])

我想附加2个零行,这样它们的行索引在结果Tensor中分别为1、3,也就是说,在本例中,结果为

X = torch.tensor([1,1,1],
                 [0,0,0],
                 [2,2,2],
                 [0,0,0],
                 [3,3,3],
                 [4,4,4]])

我尝试使用F.padreshape

wvt8vs2t

wvt8vs2t1#

您可以使用torch.tensor.index_add_

import torch

zero_index = [1, 3]
size = (6, 3)

x = torch.tensor([[1,1,1],
                  [2,2,2],
                  [3,3,3],
                  [4,4,4]])

t = torch.zeros(size, dtype=torch.int64)
index = torch.tensor([i for i in range(size[0]) if i not in zero_index])
# index -> tensor([0, 2, 4, 5])

t.index_add_(0, index, x)
print(t)

输出量:

tensor([[1, 1, 1],
        [0, 0, 0],
        [2, 2, 2],
        [0, 0, 0],
        [3, 3, 3],
        [4, 4, 4]])
f1tvaqid

f1tvaqid2#

您可以使用torch.cat

def insert_zeros(x, all_j):
    zeros_ = torch.zeros_like(x[:1])
    pieces = []
    i      = 0
    for j in all_j + [len(x)]:
        pieces.extend([x[i:j],
                       zeros_])
        i = j
    return torch.cat(pieces[:-1],
                      dim=0     )

# insert_zeros(x, [1,2])
# tensor([[1, 1, 1],
#         [0, 0, 0],
#         [2, 2, 2],
#         [0, 0, 0],
#         [3, 3, 3],
#         [4, 4, 4]])

此代码与反向传播兼容,因为Tensor不会就地修改。
更多信息:What's the difference between torch.stack() and torch.cat()?

相关问题