numpy bimodal distribution - how do I find the X values that split the distribution into 3 groups?

kdfy810k asked 12 months ago

I want to find the X values that split a bimodal distribution into 3 groups. For example, in my code below, judging from the bimodal plot the approximate cut-offs are x less than 5, x between 5 and 90, and x greater than 90.
But I am not getting those values. Here is my code:

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde
from scipy.signal import find_peaks
    
# Generate bimodal data
dist1 = np.random.normal(loc=0, scale=1, size=100) 
dist2 = np.random.normal(loc=90, scale=1, size=100)
bimodal = np.concatenate((dist1, dist2))
    
# Fit KDE 
kde = gaussian_kde(bimodal)
xgrid = np.linspace(0,100)
pdf = kde.evaluate(xgrid)
    
plt.hist(bimodal, bins=100, density=True)
plt.title("Bimodal distribution")
    
# Find peaks 
peaks, _ = find_peaks(pdf)
peak1, peak2 = xgrid[peaks]
    
# Find valley
pdf_min = np.min(pdf[(xgrid > peak1) & (xgrid < peak2)])
valley = xgrid[(pdf == pdf_min) & (xgrid > peak1) & (xgrid < peak2)]
    
# Create group labels
groups = np.ones(len(bimodal), dtype=int)
groups[bimodal < valley] = 0 
groups[(bimodal >= valley) & (bimodal <= peak1)] = 1
groups[bimodal > peak1] = 2
    
# Plot
plt.hist(bimodal)
plt.vlines([valley, peak1], 0, 100, colors='r')
plt.title("Bimodal distribution clustered into 3 groups")
plt.show()
    
print(groups)
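
For reference, two things trip up the snippet above: gaussian_kde picks its bandwidth from the spread of the whole sample (roughly 45 here), which heavily oversmooths the two narrow modes, and np.linspace(0, 100) has only 50 points and starts right at the left mode, so the left peak can end up on the grid edge where find_peaks does not report it. A minimal sketch of the same peak-and-valley idea with a hand-picked bandwidth (bw_method=0.05 is an arbitrary choice for this toy data) and a denser grid:

import numpy as np
from scipy.stats import gaussian_kde
from scipy.signal import find_peaks

# Same toy data as in the question
rng = np.random.default_rng(0)
bimodal = np.concatenate((rng.normal(0, 1, 100), rng.normal(90, 1, 100)))

# Narrow bandwidth keeps the two modes sharp; dense grid spans all the data
kde = gaussian_kde(bimodal, bw_method=0.05)
xgrid = np.linspace(bimodal.min(), bimodal.max(), 2000)
pdf = kde(xgrid)

# Two tallest peaks of the density ...
peaks, _ = find_peaks(pdf)
peak_lo, peak_hi = np.sort(xgrid[peaks[np.argsort(pdf[peaks])[-2:]]])

# ... and the valley (density minimum) between them
between = (xgrid > peak_lo) & (xgrid < peak_hi)
valley = xgrid[between][np.argmin(pdf[between])]

print(f"peaks ~ {peak_lo:.1f} and {peak_hi:.1f}, valley ~ {valley:.1f}")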
aamkag61 (answer #1)

This is basically a Gaussian mixture model. There are plenty of sophisticated ways to fit one; I'm showing a very simple approach. If your data looks like the example parameters you wrote, it will work fine.

import matplotlib.pyplot as plt
import numpy as np
import scipy.optimize
import scipy.stats

def generate_opaque(rand: np.random.Generator) -> np.ndarray:
    # return rand.normal(loc=(0, 90), scale=1, size=(100, 2)).ravel()
    return np.concatenate((
        rand.normal(loc=5, scale=2, size=450),
        rand.normal(loc=45, scale=10, size=150),
    ))

def bigaussian(
    x: np.ndarray,
    loc_a: float, scale_a: float,
    loc_b: float, scale_b: float,
    bal: float,
) -> np.ndarray:
    norm_a = scipy.stats.norm(loc_a, scale_a)
    norm_b = scipy.stats.norm(loc_b, scale_b)
    cdf = bal*norm_a.cdf(x) + (1-bal)*norm_b.cdf(x)
    return cdf

def main() -> None:
    rand = np.random.default_rng(seed=0)
    data = np.sort(generate_opaque(rand))

    # Empirical CDF: this is the reference against which the fit is compared
    ecdf = scipy.stats.ecdf(data).cdf.probabilities

    # Mean of whole dataset. So long as modes are balanced, this will
    # be a roughly sensible estimate of midpoint between the modes.
    mean_est = data.mean()

    # Rough left-side-dominated and right-side-dominated random variable section
    lhs = data[data < mean_est]
    rhs = data[data > mean_est]

    # Rough normal fits for both modes. These are only sensible estimates
    # if inter-mode distance is high and scales are low.
    loc_est_a, scale_est_a = scipy.stats.norm.fit(lhs)
    loc_est_b, scale_est_b = scipy.stats.norm.fit(rhs)

    # Estimated proportion of left normal to entire random variable
    bal_est = lhs.size / data.size

    popt, _ = scipy.optimize.curve_fit(
        f=bigaussian,
        xdata=data, ydata=ecdf,
        p0=(loc_est_a, scale_est_a, loc_est_b, scale_est_b, bal_est),
        bounds=(
            (data.min(),      0, data.min(),      0, 0),
            (data.max(), np.inf, data.max(), np.inf, 1),
        ),
    )
    (loc_a, scale_a, loc_b, scale_b, bal) = popt
    ecdf_fit = bigaussian(data, *popt)

    # One workable definition of separation:
    # the empirical x at which the fitted CDF reaches `bal`, i.e. the point below
    # which the left component's share of the probability mass sits
    i_mid = ecdf_fit.searchsorted(v=bal)
    x_mid = (data[i_mid] + data[i_mid-1])/2

    print(f'lhs est     x ~ {loc_est_a:5.2f} ±{scale_est_a:.2f}, {bal_est:.1%}')
    print(f'lhs refined x ~ {loc_a:5.2f} ±{scale_a:.2f}, {bal:.1%}')
    print(f'rhs est     x ~ {loc_est_b:5.2f} ±{scale_est_b:.2f}, {1-bal_est:.1%}')
    print(f'rhs refined x ~ {loc_b:5.2f} ±{scale_b:.2f}, {1-bal:.1%}')
    print(f'Sep est     {mean_est:.2f}')
    print(f'Sep refined {x_mid:.2f}')

    x_hires = np.linspace(start=data.min(), stop=data.max(), num=1000)
    fit_pdf = (
        scipy.stats.norm(loc_a, scale_a).pdf(x_hires)*bal +
        scipy.stats.norm(loc_b, scale_b).pdf(x_hires)*(1-bal)
    )
    ax: plt.Axes
    fig, ax = plt.subplots()
    ax.hist(data, bins=40, density=True, label='empirical')
    ax.plot(x_hires, fit_pdf, label='fit')
    ax.axvline(x=x_mid, linestyle='--', color='black', label='midpoint')
    ax.legend()

    plt.show()

if __name__ == '__main__':
    main()
Output:

lhs est     x ~  4.98 ±2.01, 75.2%
lhs refined x ~  4.95 ±2.03, 74.9%
rhs est     x ~ 44.67 ±9.16, 24.8%
rhs refined x ~ 44.46 ±9.47, 25.1%
Sep est     14.84
Sep refined 15.54
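
The "more sophisticated ways" mentioned above usually mean an off-the-shelf mixture-model fit. A minimal sketch along those lines with scikit-learn's GaussianMixture (assuming scikit-learn is available), taking the crossover of the posterior responsibilities as the separation point:

import numpy as np
from sklearn.mixture import GaussianMixture

# Same kind of toy data as generate_opaque() above
rng = np.random.default_rng(0)
data = np.concatenate((
    rng.normal(loc=5, scale=2, size=450),
    rng.normal(loc=45, scale=10, size=150),
)).reshape(-1, 1)                     # scikit-learn expects 2-D input

gmm = GaussianMixture(n_components=2, random_state=0).fit(data)
labels = gmm.predict(data)            # per-point group labels come for free
print(np.bincount(labels))            # how many points fall in each group

means = gmm.means_.ravel()
stds = np.sqrt(gmm.covariances_.ravel())
for m, s, w in zip(means, stds, gmm.weights_):
    print(f"component: x ~ {m:5.2f} ±{s:.2f}, {w:.1%}")

# Separation point: where the two posterior responsibilities cross over
xgrid = np.linspace(data.min(), data.max(), 2000).reshape(-1, 1)
resp = gmm.predict_proba(xgrid)
x_sep = xgrid[np.argmin(np.abs(resp[:, 0] - resp[:, 1])), 0]
print(f"separation ~ {x_sep:.2f}")

Note that "separation" here (equal posterior responsibility) is defined differently from the CDF-mass split used above, so the number it reports need not coincide with the curve_fit result.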
