根据本文garcia et al.(2012),有一个用matlab编写的修复函数(inpaintn),使用离散余弦变换填充多维数据集中的缺失值。我尝试将此代码(inpaintn.m)移植到python中,如下所示:,
import numpy as np
from scipy.ndimage import distance_transform_edt
from scipy.fft import idctn, dctn
from tqdm import tqdm
def fill_nd(data, invalid=None):
if invalid is None: invalid = np.isnan(data)
ind = distance_transform_edt(invalid, return_distances=False, return_indices=True)
return data[tuple(ind)]
def InitialGuess(y, I):
z = fill_nd(y)
s0 = 3
return z, s0
def idctnn(y):
return idctn(y, norm='ortho')
def dctnn(y):
return dctn(y, norm='ortho')
def inpaint(xx, y0=[], n=100, m=2, verbose=False):
x = xx.copy() #as it changes x itself, so copying it to another variable.
sizx = np.shape(x)
d = np.ndim(x)
Lambda = np.zeros(sizx, dtype='float')
for i in range(0, d):
siz0 = np.ones(d, dtype='int')
siz0[i] = sizx[i]
Lambda = Lambda + np.cos(np.pi * np.reshape(np.arange(1, sizx[i] + 0.1) - 1, siz0) / sizx[i])
Lambda = 2 * (d - Lambda)
# Initial condition
W = np.isfinite(x)
if len(y0) == len(x):
y = y0
s0 = 3 # note: s = 10 ^ s0
else:
if np.any(~W):
if verbose: print('Initial Guess as Nearest Neighbors')
y, s0 = InitialGuess(x, np.isfinite(x).astype('bool'))
else:
y = x
s0 = 3
# return x
x[~W] = 0.
# Smoothness parameters: from high to negligible
s = np.logspace(s0, -6, n)
RF = 2. # Relaxation Factor
Lambda = Lambda**m
if verbose: print('Inpainting .......')
for i in tqdm(range(n)):
Gamma = 1. / (1 + s[i] * Lambda)
y = RF * idctnn(Gamma * dctnn((W * (x - y)) + y)) + (1 - RF) * y
y[W] = x[W]
return y
代码运行得很好,但我一直在努力寻找使代码运行得更快的方法,特别是因为我的数据集很大。使用这种类型的插值的优点是,我可以为整个3d数据集(带有时间和栅格坐标)填充缺少的值,而不是为每个时间坐标进行填充。
下面是一个使用python的示例数据集
import numpy as np
# A 3D dataset with dimensions (time, latitude, longitude)
X = np.random.randn(1000,180,360)
# Randomly choosing indices to insert 64800 NaN values (say).
# NaNs can also be present as blocks in the data, not randomly dispersed as below.
index_nan = np.random.choice(X.size, 64800, replace=False)
# Inserting NaNs.
X.ravel()[index_nan] = np.nan
我试过一些方法,但没有成功,
使用麻木
jit装饰器使其速度变慢,即使有如下选项 parallel/fastmath/vectorize,nopython=True
.
使用cython
我试着排版这些函数中使用的所有变量,但仍然比本机python实现慢。而且,在我的机器上编译cython代码很麻烦。
使用numpy矢量化
我已经将离散余弦变换函数及其逆函数替换为 scipy
函数,但我似乎想不出如何将内部for循环矢量化以使其快速,以及是否可能。我试着分析我的代码,瓶颈似乎是使用 scipy
. 还有其他瓶颈,但对我来说没有意义。我已经附加了一个图像的剖析以及。
如果有可行的方法来加速这段代码,那将真的很有帮助。我在python方面并不是很先进,但我可以从中学到很多东西,特别是我的问题的可行性。
暂无答案!
目前还没有任何答案,快来回答吧!