对可变数量的参数使用scipy curve_fit

agyaoht7  于 2022-11-09  发布在  其他
关注(0)|答案(4)|浏览(260)

我有一个拟合函数,其形式为:

def fit_func(x_data, a, b, c, N)

其中,a、B、c是长度为N的列表,其中的每个条目都是要在scipy.optimiz.curve_fit()中优化的可变参数,并且N是用于循环索引控制的固定数。
this question之后,我想我能够修复N,但是我现在调用curve_fit如下:

params_0 = [a_init, b_init, c_init]
popt, pcov = curve_fit(lambda x, a, b, c: fit_func(x, a, b, c, N), x_data, y_data, p0=params_0)

我收到一个错误:lambda()正好接受Q个参数(给定P)
Q和P的变化取决于我的设置方式。
那么:对于初学者来说,这是可能的吗?我可以将列表作为参数传递给curve_fit,并具有我所希望的行为,即它将列表元素作为单独的参数处理吗?假设答案是肯定的,我在函数调用中做错了什么?

vaqhlq81

vaqhlq811#

这里的解决方案是编写一个 Package 器函数,它接受你的参数列表,并将其转换为fit函数理解的变量。这实际上是唯一必要的,因为我正在处理别人的代码,在一个更直接的应用程序中,这将不需要 Package 器层。基本上

def wrapper_fit_func(x, N, *args):
    a, b, c = list(args[0][:N]), list(args[0][N:2*N]), list(args[0][2*N:3*N])
    return fit_func(x, a, b, c, N)

要固定N,必须在curve_fit中调用它,如下所示:

popt, pcov = curve_fit(lambda x, *params_0: wrapper_fit_func(x, N, params_0), x, y, p0=params_0)

其中

params_0 = [a_1, ..., a_N, b_1, ..., b_N, c_1, ..., c_N]
jexiocij

jexiocij2#

我可以用稍微不同的方法解决同样的问题。我使用scip.optimize.least_squares来解决,而不是curv_fit。我已经在链接-https://stackoverflow.com/a/60409667/11253983下讨论了我的解决方案

unftdfkk

unftdfkk3#

请看这篇文章https://stackoverflow.com/a/73951825/20160627,我建议使用scipy.optimize.curve_fit,其中包含任意数量和位置的参数,以便在列表中进行拟合或修复

q0qdq0h2

q0qdq0h24#

可能不是一个最优的解决方案,但我遇到了类似的问题。我解决它的方法是硬编码“传递”函数,字面上输出我的广义函数,但需要特定数量的参数。


# Defines General Temperature Function

def general_temp_function(t, a, *params):
   # Seperates params arguments in respective lists, i.e.
   # general_temp_regress(t, a, b1, tau1, b2, tau2...)
   # b_list = [b1, b2, ...]
   # tau_list = [tau1, tau2, ...]
   b_list = []
   tau_list = []
   for i in range(0, len(params) // 2):
      pair = [params[i * 2], params[i * 2 + 1]]
      b_list.append(pair[0])
      tau_list.append(pair[1])

   # Calculates
   N = len(params) // 2
   sum_ = 0
   for term in range(0, N):
      sum_ += b_list[term] * np.exp(-t / tau_list[term])

   return a + sum_ 

# Defines Passing Functions

def pass1(t, a, b1, tau1):
   return general_temp_function(t, a, b1, tau1)
def pass2(t, a, b1, tau1, b2, tau2):
   return general_temp_function(t, a, b1, tau1, b2, tau2)
def pass3(t, a, b1, tau1, b2, tau2, b3, tau3):
   return general_temp_function(t, a, b1, tau1, b2, tau2, b3, tau3)
def pass4(t, a, b1, tau1, b2, tau2, b3, tau3, b4, tau4):
   return general_temp_function(t, a, b1, tau1, b2, tau2, b3, tau3, b4, tau4)
def pass5(t, a, b1, tau1, b2, tau2, b3, tau3, b4, tau4, b5, tau5):
   return general_temp_function(t, a, b1, tau1, b2, tau2, b3, tau3, b4, tau4, b5, tau5)

T_fit1_N1params, T_fit1_N1covars = curve_fit(pass1, t, T_data1)

相关问题