scipy 为什么RegularGridInterpolator不能返回多个值(对于以$R^d$输出的函数)

yeotifhr  于 2023-02-12  发布在  其他
关注(0)|答案(1)|浏览(178)

MRE(有工作输出和不工作的输出,尽管我希望它工作,因为这将是直观的事情):

import numpy as np
from scipy.interpolate import RegularGridInterpolator, griddata

def f(x1, x2, x3):
    return x1 + 2*x2 + 3*x3, x1**2, x2

# Define the input points
xi = [np.linspace(0, 1, 5), np.linspace(0, 1, 5), np.linspace(0, 1, 5)]

# Mesh grid
x1, x2, x3 = np.meshgrid(*xi, indexing='ij')

# Outputs
y = f(x1, x2, x3)
assert (y[0][1][1][3] == (0.25 + 2*0.25 + 3*0.75))
assert (y[1][1][1][3] == (0.25**2))
assert (y[2][1][1][3] == 0.25)

#### THIS WORKS BUT I CAN ONLY GET THE nth (with n integer in [1, d]) VALUE RETURNED BY f
# Interpolate at point 0.3, 0.3, 0.4
interp = RegularGridInterpolator(xi, y[0])
print(interp([0.3, 0.3, 0.4]))  # outputs 2.1 as expected

#### THIS DOESN'T WORK (I WOULD'VE EXPECTED A LIST OF TUPLES FOR EXAMPLE)
# Interpolate at point 0.3, 0.3, 0.4
interp = RegularGridInterpolator(xi, y)
print(interp([0.3, 0.3, 0.4]))  # doesn't output array([2.1, 0.1, 0.3])

有趣的是griddata确实支持以R^d为单位输出值的函数

# Same with griddata
grid_for_griddata = np.array([x1.flatten(), x2.flatten(), x3.flatten()]).T
assert (grid_for_griddata.shape == (125, 3))
y_for_griddata = np.array([y[0].flatten(), y[1].flatten(), y[2].flatten()]).T
assert (y_for_griddata.shape == (125, 3))
griddata(grid_for_griddata, y_for_griddata, [0.3, 0.3, 0.4], method='linear')[0]  # outputs array([2.1, 0.1, 0.3]) as expected

我使用RegularGridInterpolator的方式是否错误?
我知道有些人可能会说“只要使用griddata“,但是因为我的数据是在直线网格中,所以我应该使用RegularGridInterpolator,这样会更快,对吗?
证明它更快:

zte4gxcn

zte4gxcn1#

如果我定义一个y,最后一个维度是3:

In [196]: yarr = np.stack(y,axis=3); yarr.shape
Out[196]: (5, 5, 5, 3)

设置工作正常(没有关于3与5不匹配的投诉):

In [197]: interp = RegularGridInterpolator(xi, yarr)

和插值:

In [198]: interp([.3,.3,.4])
Out[198]: array([[2.1, 0.1, 0.3]])

对于多个点:

In [202]: interp([[.3,.3,.4],[.31,.31,.41],[.5,.4,.4]])
Out[202]: 
array([[2.1   , 0.1   , 0.3   ],
       [2.16  , 0.1075, 0.31  ],
       [2.5   , 0.25  , 0.4   ]])

虽然以上只是一个猜测,但我看到这些文档可以这样解释:

values: array_like, shape (m1, …, mn, …)

最后的...表示数组可以有0个或更多的尾部维度(超过与点维度匹配的n)。但是这种灵活性可能更多地应用于linearnearest方法。其他方法似乎有问题。
这在__call__的文档页面上更清楚:
https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.RegularGridInterpolator.call.html#scipy.interpolate.RegularGridInterpolator.call

Returns
    values_x : ndarray, shape xi.shape[:-1] + values.shape[ndim:]

interpn也记录了这一点。

相关问题