在python中这种修复插值可以更快吗?

kq4fsx7k  于 2021-09-08  发布在  Java
关注(0)|答案(0)|浏览(232)

根据本文garcia et al.(2012),有一个用matlab编写的修复函数(inpaintn),使用离散余弦变换填充多维数据集中的缺失值。我尝试将此代码(inpaintn.m)移植到python中,如下所示:,

import numpy as np
from scipy.ndimage import distance_transform_edt
from scipy.fft import idctn, dctn
from tqdm import tqdm

def fill_nd(data, invalid=None):
    if invalid is None: invalid = np.isnan(data)

    ind = distance_transform_edt(invalid, return_distances=False, return_indices=True)
    return data[tuple(ind)]

def InitialGuess(y, I):
    z = fill_nd(y)
    s0 = 3
    return z, s0

def idctnn(y):
    return idctn(y, norm='ortho')

def dctnn(y):
    return dctn(y, norm='ortho')

def inpaint(xx, y0=[], n=100, m=2, verbose=False):
    x = xx.copy() #as it changes x itself, so copying it to another variable.

    sizx = np.shape(x)
    d = np.ndim(x)
    Lambda = np.zeros(sizx, dtype='float')

    for i in range(0, d):
        siz0 = np.ones(d, dtype='int')
        siz0[i] = sizx[i]
        Lambda = Lambda + np.cos(np.pi * np.reshape(np.arange(1, sizx[i] + 0.1) - 1, siz0) / sizx[i])

    Lambda = 2 * (d - Lambda)

    # Initial condition
    W = np.isfinite(x)
    if len(y0) == len(x):
        y = y0
        s0 = 3  # note: s = 10 ^ s0
    else:
        if np.any(~W):
            if verbose: print('Initial Guess as Nearest Neighbors')
            y, s0 = InitialGuess(x, np.isfinite(x).astype('bool'))
        else:
            y = x
            s0 = 3
            # return x
    x[~W] = 0.

    # Smoothness parameters: from high to negligible
    s = np.logspace(s0, -6, n)

    RF = 2.  # Relaxation Factor
    Lambda = Lambda**m

    if verbose: print('Inpainting .......')

    for i in tqdm(range(n)):
        Gamma = 1. / (1 + s[i] * Lambda)
        y = RF * idctnn(Gamma * dctnn((W * (x - y)) + y)) + (1 - RF) * y

    y[W] = x[W]

    return y

代码运行得很好,但我一直在努力寻找使代码运行得更快的方法,特别是因为我的数据集很大。使用这种类型的插值的优点是,我可以为整个3d数据集(带有时间和栅格坐标)填充缺少的值,而不是为每个时间坐标进行填充。
下面是一个使用python的示例数据集

import numpy as np

# A 3D dataset with dimensions (time, latitude, longitude)

X = np.random.randn(1000,180,360)

# Randomly choosing indices to insert 64800 NaN values (say).

# NaNs can also be present as blocks in the data, not randomly dispersed as below.

index_nan = np.random.choice(X.size, 64800, replace=False)

# Inserting NaNs.

X.ravel()[index_nan] = np.nan

我试过一些方法,但没有成功,
使用麻木
jit装饰器使其速度变慢,即使有如下选项 parallel/fastmath/vectorize,nopython=True .
使用cython
我试着排版这些函数中使用的所有变量,但仍然比本机python实现慢。而且,在我的机器上编译cython代码很麻烦。
使用numpy矢量化
我已经将离散余弦变换函数及其逆函数替换为 scipy 函数,但我似乎想不出如何将内部for循环矢量化以使其快速,以及是否可能。我试着分析我的代码,瓶颈似乎是使用 scipy . 还有其他瓶颈,但对我来说没有意义。我已经附加了一个图像的剖析以及。

如果有可行的方法来加速这段代码,那将真的很有帮助。我在python方面并不是很先进,但我可以从中学到很多东西,特别是我的问题的可行性。

暂无答案!

目前还没有任何答案,快来回答吧!

相关问题