如何在pytorch / numpy中删除for循环?

nbysray5  于 2023-05-22  发布在  其他
关注(0)|答案(1)|浏览(121)

这是一个典型的问题,虽然我没有解决它的线索。设timesteps = [12, 5, 6, 7]或一维向量。设noise是一个形状为[5, 1, 1, 28]的矩阵。我想创建形状为[1, 1, 28]的矩阵,其值为timesteps,例如。12 * torch.ones([1, 1, 28]) .最后我把它们联系起来了。我创建了这个代码

for scalar in range(int(timesteps.shape[0])):
    list_mask.append(
        scalar * torch.ones((1, *noise.shape[1:]))
    )

mask = torch.cat(list_mask)

你知道如何不使用循环但得到相同的结果吗?

tf7tbtn2

tf7tbtn21#

由于只有四个时间步长,我假设噪声实际上是4x1x1x28。使用广播,您可以将形状为4x1x1x28的一数组乘以整形的时间步长作为4x1x1x1数组。广播将沿着后三个轴重复时间步长。

noise = np.ones((4, 1, 1, 28))*np.array(timesteps)[:, None, None, None]

相关问题