PyTorch二维卷积的高效伪逆算法

bf1o4zei  于 2022-11-09  发布在  其他
关注(0)|答案(1)|浏览(258)

背景:

感谢大家的关注!我正在学习二维卷积、线性代数和PyTorch的基础知识,遇到了卷积算子的伪逆的实现问题,具体来说,我不知道如何高效地实现它,具体请看下面的问题陈述,欢迎大家的帮助/提示/建议。
(非常感谢您的关注!)

原始问题:

我有一个形状为[b,c,h,w]的图像特征x和一个形状为[c,c,3,3]的卷积核3x3,还有一个y = K * x,如何在y上高效地实现相应的伪逆?
有[y = K * x = Ax],如何实现[x_hat =(A^+)y]?
我想应该有一些操作使用torch.fft,但是我还不知道如何实现,不知道以前有没有实现。

import torch
import torch.nn.functional as F

c = 32
K = torch.randn(c, c, 3, 3)
x = torch.randn(1, c, 128, 128)
y = F.conv2d(x, K, padding=1)

print(y.shape)

# How to implement pseudo-inverse for y = K * x in an efficient way?

我的一些努力:

我可能知道二维卷积是一个线性算子,它相当于一个“矩阵积”算子。我们实际上可以写出卷积的矩阵形式,并计算它的伪逆。但是,我认为这种类型的运算效率很低。而且我不知道如何以有效的方式实现它。
根据Wikipedia,伪逆可以满足A(A_pinv(x))=x的性质,其中A是卷积算子,A_pinv是它的伪逆,并且x可以是任何图像特征。
(再次感谢您阅读这么长的帖子!)

zd287kbt

zd287kbt1#

卷积本身是一个线性运算,你可以确定运算的矩阵,直接求解最小二乘问题1(https://dsp.stackexchange.com/a/84405/55891),或者像你提到的那样计算伪逆,然后应用于不同的输出,预测输入的投影。
我正在将您的代码更改为使用padding=0

import torch
import torch.nn.functional as F

# your code

c = 32
K = torch.randn(c, c, 1, 1)
x = torch.randn(4, c, 128, 128)
y = F.conv2d(x, K, bias=torch.zeros((c,)))

此外,正如您可能已经建议的那样,卷积可以计算为ifft(fft(h)*fft(x))。但是,conv2d函数是互相关函数,因此您必须对导致ifft(fft(h)*fft(x))的滤波器进行共轭,还必须将其应用于两个轴,并且必须确保使用相同的表示法计算FFT(大小),由于数据是真实的,我们可以使用multi-dimensional real FFT。完整地说,conv2d适用于多个通道,因此我们必须计算卷积的总和。由于FFT是线性的,我们可以使用einsum简单地计算频域上的总和。

s = y.shape[-2:]
K_f = torch.fft.rfftn(K, s)
x_f = torch.fft.rfftn(x, s)
y_f = torch.einsum('jkxy,ikxy->ijxy', K_f.conj(), x_f)
y_hat = torch.fft.irfftn(y_f, s)

除了边界之外,它应该是准确的(记住FFT计算循环卷积)。

torch.max(abs(y_hat[:,:,:-2,:-2] - y[:,:,:,:]))

现在,注意einsum上的模式jk,ik->ij,这意味着y_f[i,j] = sum(K_f[j,k] * x_f[i,k]) = x_f @ K_f.T,如果@是前两个维度上的矩阵积。所以要反转这个运算,我们必须将前两个维度解释为矩阵。函数pinv将计算后两个轴上的伪逆。所以为了使用它,我们必须置换轴。如果我们将输出乘以转置的K_f的伪逆,我们应该反转这个操作。

s = 128,128
K_f = torch.fft.rfftn(K, s)
K_f_inv = torch.linalg.pinv(K_f.T).T
y_f = torch.fft.rfftn(y_hat, s)
x_f = torch.einsum('jkxy,ikxy->ijxy', K_f_inv.conj(), y_f)
x_hat = torch.fft.irfftn(x_f, s)
print(torch.mean((x - x_hat)**2) / torch.mean((x)**2))

注意,我使用的是全卷积,但conv2d实际上裁剪了图像。

y_hat[:,:,128-(k-1):,:] = 0
y_hat[:,:,:,128-(k-1):] = 0

重复计算,你会发现输入不再准确,所以你必须小心处理卷积,但在某些情况下,你可以让它工作,它实际上是有效的。

s = 128,128
K_f = torch.fft.rfftn(K, s)
K_f_inv = torch.linalg.pinv(K_f.T).T
y_f = torch.fft.rfftn(y_hat, s)
x_f = torch.einsum('jkxy,ikxy->ijxy', K_f_inv.conj(), y_f)
x_hat = torch.fft.irfftn(x_f, s)
print(torch.mean((x - x_hat)**2) / torch.mean((x)**2))

相关问题