python 对4参数回归曲线拟合应用权重

mm5n2pyu  于 2023-06-04  发布在  Python
关注(0)|答案(1)|浏览(290)

下面的代码生成一个图和4PL曲线拟合,但在较低的值下拟合较差。这个错误通常可以通过1/y^2加权来解决,但我不知道在这种情况下如何处理。在fit中添加sigma=1/Y_data**2只会让情况变得更糟。

from scipy.optimize import curve_fit
import matplotlib.pyplot as plt

def fourPL(x, A, B, C, D):
    return ((A-D) / (1.0 + np.power(x / C, B))) + D

X_data = np.array([700,200,44,11,3,0.7,0.2,0])
Y_data = np.array([600000,140000,30000,8000,2100,800,500,60])

popt, pcov = curve_fit(fourPL, X_data, Y_data)

fig, ax = plt.subplots()    
ax.scatter(X_data, Y_data, label='Data')
X_curve = np.linspace(min(X_data[np.nonzero(X_data)]), max(X_data), 5000)
Y_curve = fourPL(X_curve, *popt)
ax.plot(X_curve, Y_curve)

ax.set_xscale('log')
ax.set_yscale('log')

plt.show()

pcww981p

pcww981p1#

不加平方反比权重;适合日志域。始终添加边界。在本例中,curve_fit不能很好地完成工作;相反考虑minimize

import numpy as np
from scipy.optimize import curve_fit, minimize
import matplotlib.pyplot as plt

def fourPL(x: np.ndarray, a: float, b: float, c: float, d: float) -> np.ndarray:
    return (a - d)/(1 + (x / c)**b) + d

def estimated(x: np.ndarray, a: float, b: float, c: float, d: float) -> np.ndarray:
    return np.log(fourPL(x, a, b, c, d))

def sqerror(abcd: np.ndarray) -> float:
    y = np.log(fourPL(x_data, *abcd)) - np.log(y_data)
    return y.dot(y)

x_data = np.array([700, 200, 44, 11, 3, 0.7, 0.2, 0])
y_data = np.array([600000, 140000, 30000, 8000, 2100, 800, 500, 60])
guess = (500, 1.05, 1e6, 1e9)
bounds = np.array((
    (1, 0.1, 1, 0),
    (np.inf, 10, np.inf, np.inf),
))

popt, _ = curve_fit(
    f=estimated, xdata=x_data, ydata=np.log(y_data), p0=guess,
    bounds=bounds,
)
print('popt:', popt)
result = minimize(
    fun=sqerror, x0=guess, bounds=bounds.T, tol=1e-9,
)
assert result.success
print('minimize x:', result.x)

x_curve = 10**np.linspace(-1, 3, 1000)

fig, ax = plt.subplots()
ax.scatter(x_data, y_data, label='Data')
ax.plot(x_curve, fourPL(x_curve, *popt), label='curve_fit')
ax.plot(x_curve, fourPL(x_curve, *result.x), label='minimize')
ax.plot(x_curve, fourPL(x_curve, *guess), label='guess')
ax.set_xscale('log')
ax.set_yscale('log')
ax.legend()
plt.show()

相关问题