pytorch 基于新旧值之间的元素Map转换二维Tensor的元素?

bq9c1y66  于 2023-01-30  发布在  其他
关注(0)|答案(1)|浏览(173)

我有一个二维Tensor,它包含了其他Tensor的索引。

old = torch.Tensor([
    [1, 2, 12, 12],
    [0, 1, 12, 12],
    [3, 5, 12, 12],
    [7, 8, 12, 12],
    [6, 7, 12, 12],
    [9, 11, 12, 12]])

我有另一个Tensor,它表示oldTensor中元素到newTensor的Map

mapping = torch.Tensor([
    [0, 0],
    [1, 6],
    [2, 1],
    [3, 6],
    [4, 2],
    [5, 6],
    [6, 3],
    [7, 6],
    [8, 4],
    [9, 6],
    [10, 5],
    [11, 6],
    [12, 6]])

也就是说,mapping[:, 0]列表示在old中找到的值,并且[:, 1]表示要转换的值。

new_or_desired = torch.Tensor([
    [6, 1, 6, 6],
    [0, 6, 6, 6],
    [6, 6, 6, 6],
    [6, 4, 6, 6],
    [3, 6, 6, 6],
    [6, 6, 6, 6]])

我已经尝试了很多次迭代,但应用此Map的最佳方法是

old[old == mapping[:, 0]] = mapping[:, 1]

但形状明显不匹配。**如何应用mappingold元素转换为new元素值?**我认为我应该使用scatter_,但我不知道如何正确应用它。

dwthyt8l

dwthyt8l1#

这可以通过torch.take实现

new = torch.take(mapping[:, 1], old.long())

相关问题