pytorch 不清楚以下代码中.repeat()的用法

dgsult0t  于 2023-01-09  发布在  其他
关注(0)|答案(1)|浏览(193)

在我以前的同事给我的一个python代码中,有一行我看不懂,那就是,

single_channel_depth_ch = single_channel_depth.unsqueeze(1).repeat(1,19,1,1)

其中,single_channel_depth是一个维数为(1408,376,1)的数组,我知道unsqueeze()会删除值为1的维数,并生成一个(1408,376)的数组,但我不明白.repeat(1,19,1,1)的含义。
正如我所检查的,我发现repeat()将用作为参数传递给它的元素替换数组元素。
我的理解是正确的吗?或者上面提到的代码行有其他含义吗?

aij0ehis

aij0ehis1#

unsqueeze方法 * 向数组添加 * 一个 * 单例 *(大小为1)维度。
第一个参数是Tensor,第二个参数是插入单元素维的索引

x = torch.tensor([1, 2, 3, 4])
torch.unsqueeze(x, 0)

Tensor1,2,3,4
x一个一个一个一个x一个一个二个x
关于repeat方法,您的理解是正确的,它沿着给定的维度复制数组的元素。
传入该方法的参数是沿 * each * 维重复该Tensor的次数。
一个一个三个一个一个一个一个一个四个一个一个一个一个一个五个一个
手电筒。大小([4,2,3])
在您提供的代码中,repeat方法沿第二维复制数组19次,沿第三维和第四维复制数组一次。
这将产生具有形状(1408,19,376,1)的新阵列,该阵列是通过沿着第二维度将原始阵列复制19次而创建的,第三和第四维度保持不变。

相关问题