如何在PyTorch中计算三个Tensor的批量协方差?

pgky5nke  于 12个月前  发布在  其他
关注(0)|答案(1)|浏览(107)

假设我们有3个大小为(B, C, H, W)的Tensor,其中B是批量大小,C是通道维度。我希望计算这3个Tensor沿着通道维度的协方差。
我尝试了以下代码:

x1_mean = x1.mean(dim=1).unsqueeze(dim=1)
x2_mean = x2.mean(dim=1).unsqueeze(dim=1)
x3_mean = x3.mean(dim=1).unsqueeze(dim=1)
out = torch.matmul(torch.matmul(x1 - x1_mean, x2 - x2_mean), x3 - x3_mean)

字符串
只是想知道我的代码是否有意义。有没有其他方法来计算协方差?任何帮助都将不胜感激。

cgyqldqp

cgyqldqp1#

您计算通道维沿着三个Tensor的协方差的方法是一个很好的开始,但似乎对如何计算协方差存在误解,特别是对于多个Tensor。您计算通道维上的平均值,然后将其解压缩以匹配原始Tensor的形状。这部分对于Tensor的平均居中是正确的。但是,在代码中使用torch.matmul不能正确计算协方差。协方差通常涉及两组变量之间的成对计算,而不是三组变量之间的成对计算,并且计算方法不同。
综上所述,假设你想在三个Tensor之间做某种形式的 * 多变量协方差 *,下面的代码可以为你完成这项工作:

# step 1: reshape tensors
B, C, H, W = x1.shape
x_combined = torch.cat([x1, x2, x3], dim=1)  # Shape: (B, C*3, H, W)
x_combined = x_combined.reshape(B, C*3, H*W) # Shape: (B, C*3, H*W)

# step 2: mean centering
mean_centered = x_combined - x_combined.mean(dim=2, keepdim=True)

# step 3: covariance matrix calculation
cov_matrix = torch.matmul(mean_centered, mean_centered.transpose(1, 2)) / (H*W - 1)

字符串
此代码片段将生成一批协方差矩阵,每个协方差矩阵的大小为(C3,C3),表示每个空间位置上所有三个Tensor的每对通道之间的协方差。每个矩阵都是批次中一个示例的多变量协方差矩阵。

相关问题