scipy python中的指数曲线拟合参数没有意义--拟合本身看起来很棒

yiytaume  于 2022-12-13  发布在  Python
关注(0)|答案(2)|浏览(162)

我正在python中使用scipy.curve_fit进行曲线拟合,拟合本身看起来很棒,但是生成的参数没有意义。
方程是(ax)^B + cx,但是使用参数python发现a = -c和b = 1,所以对于每个x值,整个方程都等于0。
这是图(https://i.stack.imgur.com/fBfg7.png)](https://i.stack.imgur.com/fBfg7.png
以下是我使用的实验原始数据:https://pastebin.com/CR2BCJji

xdata = cfu_u
ydata = OD_u

min_cfu = 0.1
max_cfu = 9.1
x_vec = pow(10,np.arange(min_cfu,max_cfu,0.1))

def func(x,a, b, c):
  return (a*x)**b + c*x 

popt, pcov = curve_fit(func, xdata, ydata)

plt.plot(x_vec, func(x_vec, *popt), label = 'curve fit',color='slateblue',linewidth = 2.2)
plt.plot(cfu_u,OD_u,'-',label = 'experimental data',marker='.',markersize=8,color='deepskyblue',linewidth = 1.4)
plt.legend(loc='upper left',fontsize=12)
plt.ylabel("Y",fontsize=12)
plt.xlabel("X",fontsize=12)
plt.xscale("log")
plt.gcf().set_size_inches(7, 5)
plt.show()

print(popt)
[ 1.44930871e+03  1.00000000e+00 -1.44930871e+03]

我用scipy的curve_fit函数对一些数据进行了指数曲线拟合,拟合看起来非常好,所以那部分是成功的。
但是,curve_fit函数输出的参数没有意义,使用这些参数求解f(x)会导致对于x的每个值f(x)=0,这显然不是曲线中发生的情况。

tyu7yeag

tyu7yeag1#

修改模型以显示实际发生的情况:

def func(x: np.ndarray, a: float, b: float, c: float) -> np.ndarray:
    return (a*x)**(1 - b) + (c - a)*x

生产优化参数

[3.49003332e-04 6.60420171e-06 3.13366557e-08]

这在数值上可能不稳定。请尝试在日志域中进行优化。

ie3xauqp

ie3xauqp2#

当我运行你的例子时(在添加导入等之后),我得到了popt的NaNs,我最终意识到您允许一般的实数b与负数x。(见下文),但也许您需要将b限制为整数,以适合整个集合。我不确定如何在Scipy中实现这一点(我假设您需要混合整数-实数优化,我还没有研究Scipy是否支持这一点)。

编码:

import numpy as np
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt

cfu_u, OD_u = np.loadtxt('data.txt', skiprows=1).T

# fit to positive x only
posmask = cfu_u > 0
xdata = cfu_u[posmask]
ydata = OD_u[posmask]

def func(x, a, b, c):
  return (a*x)**b + c*x 

popt, pcov = curve_fit(func, xdata, ydata, p0=[1000,2,1])

x_vec = np.geomspace(xdata.min(), xdata.max())

plt.plot(x_vec, func(x_vec, *popt), label = 'curve fit',color='slateblue',linewidth = 2.2)
plt.plot(cfu_u,OD_u,'-',label = 'experimental data', marker='x',markersize=8,color='deepskyblue',linewidth = 1.4)
plt.legend(loc='upper left',fontsize=12)
plt.ylabel("Y",fontsize=12)
plt.xlabel("X",fontsize=12)
plt.yscale("log")
plt.xscale("symlog")
plt.show()

print(popt)
#[ 1.44930871e+03  1.00000000e+00 -1.44930871e+03]

相关问题