pytorch 如何根据另一个数组的唯一索引拆分多维数组?

thigvfpy  于 2022-11-23  发布在  其他
关注(0)|答案(2)|浏览(156)

我有两个Tensortorchb

import torch
torch.manual_seed(0) # for reproducibility

a = torch.rand(size = (5, 10, 1))
b = torch.tensor([3, 3, 1, 5, 3, 1, 0, 2, 1, 2])

我想根据b中的唯一值拆分a的第二维(在Python中为dim = 1)。
到目前为止,我已经尝试过:

# find the unique values and unique indices of b
unique_values, unique_indices = torch.unique(b, return_inverse = True)

# split a in where dim = 1, based on unique indices
l = torch.tensor_split(a, unique_indices, dim = 1)

我希望l是一个n个Tensor的列表,其中nb中唯一值的数量。我还希望Tensor具有以下形状(5,与unique_values对应的元素数量,1)。
但是,我得到以下结果:

print(l)

(tensor([[[0.8198],
         [0.9971],
         [0.6984]],

        [[0.7262],
         [0.7011],
         [0.2038]],

        [[0.1147],
         [0.3168],
         [0.6965]],

        [[0.0340],
         [0.9442],
         [0.8802]],

        [[0.6833],
         [0.7529],
         [0.8579]]]), tensor([], size=(5, 0, 1)), tensor([], size=(5, 0, 1)), tensor([[[0.9971],
         [0.6984],
         [0.5675]],

        [[0.7011],
         [0.2038],
         [0.6511]],

        [[0.3168],
         [0.6965],
         [0.9143]],

        [[0.9442],
         [0.8802],
         [0.0012]],

        [[0.7529],
         [0.8579],
         [0.6870]]]), tensor([], size=(5, 0, 1)), tensor([], size=(5, 0, 1)), tensor([], size=(5, 0, 1)), tensor([[[0.8198],
         [0.9971]],

        [[0.7262],
         [0.7011]],

        [[0.1147],
         [0.3168]],

        [[0.0340],
         [0.9442]],

        [[0.6833],
         [0.7529]]]), tensor([], size=(5, 0, 1)), tensor([[[0.9971]],

        [[0.7011]],

        [[0.3168]],

        [[0.9442]],

        [[0.7529]]]), tensor([[[0.6984],
         [0.5675],
         [0.8352],
         [0.2056],
         [0.5932],
         [0.1123],
         [0.1535],
         [0.2417]],

        [[0.2038],
         [0.6511],
         [0.7745],
         [0.4369],
         [0.5191],
         [0.6159],
         [0.8102],
         [0.9801]],

        [[0.6965],
         [0.9143],
         [0.9351],
         [0.9412],
         [0.5995],
         [0.0652],
         [0.5460],
         [0.1872]],

        [[0.8802],
         [0.0012],
         [0.5936],
         [0.4158],
         [0.4177],
         [0.2711],
         [0.6923],
         [0.2038]],

        [[0.8579],
         [0.6870],
         [0.0051],
         [0.1757],
         [0.7497],
         [0.6047],
         [0.1100],
         [0.2121]]]))

为什么我会得到像tensor([], size=(5, 0, 1))这样的空Tensor,我如何实现我想要实现的呢?

guykilcj

guykilcj1#

根据您对所需结果的描述:
我还期望Tensor的形状是(5, number of elements corresponding to unique_values, 1)
我相信你正在寻找唯一值的 count(或 frequency)。如果你想继续使用torch.unique,那么你可以提供return_counts参数并调用torch.cumsum
应该可以这样做:

>>> indices = torch.cumsum(counts, dim=0)
>>> splits = torch.tensor_split(a, indices[:-1], dim = 1)

让我们看一看:

>>> for x in splits:
...     print(x.shape)
torch.Size([5, 1, 1])
torch.Size([5, 3, 1])
torch.Size([5, 2, 1])
torch.Size([5, 3, 1])
torch.Size([5, 1, 1])
vdzxcuhz

vdzxcuhz2#

您是否正在寻找index_select方法?
您已在unique_values中正确获取唯一
现在您需要做的是:

l = a.index_select(1, unique_values)

相关问题