numpy 如何在散点图上绘制多个类别的多项式模型

fnvucqvd  于 12个月前  发布在  其他
关注(0)|答案(1)|浏览(102)

我正在使用 seaborn 自带的标准 diamonds(钻石)数据集,需要绘制如下类型的图:按 `cut` 分类着色的 carat–price 散点图,并为每个类别叠加一条多项式拟合曲线。

目前我只做到了以下两种:

1)

import seaborn as sns
import matplotlib.pyplot as plt

# Fetch seaborn's example "diamonds" table.
df = sns.load_dataset('diamonds')

plt.figure(figsize=(12, 8), dpi=200)

# Both layers share the same data and encoding, so build the keyword set once.
encoding = dict(data=df, x='carat', y='price', hue='cut', palette='viridis')

ax = sns.scatterplot(**encoding)
# Draw the per-cut lines onto the very same Axes as the scatter. lineplot
# aggregates price at each carat value (mean by default); it does not fit
# a polynomial model.
sns.lineplot(**encoding, ax=ax)

ax.set_xlabel('Carat')
ax.set_ylabel('Price')
ax.set_title('Scatter Plot of Price vs. Carat with Curved Lines (Viridis Palette)')

# Park the legend outside the plotting area, to the right of the Axes.
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

plt.show()

2)

plt.figure(figsize=(12, 8), dpi=200)

# regplot handles one group per call, so loop over the cut levels
# (in order of first appearance); each call draws on the current Axes.
for level in df['cut'].unique():
    subset = df.loc[df['cut'] == level]
    sns.regplot(
        data=subset,
        x='carat',
        y='price',
        scatter_kws={'s': 10},
        label=level,
    )

ax = plt.gca()
ax.set_xlabel('Carat')
ax.set_ylabel('Price')
ax.set_title('Regression Plot of Price vs. Carat by Cut')

ax.legend(title='Cut')

plt.show()

我怎样才能得到一个多项式拟合的图形?

wbgh16ku

wbgh16ku1#

数据和导入

import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt

# Pull seaborn's example "diamonds" dataset (fetched over the network on first use, then cached).
df = sns.load_dataset(name='diamonds')

   carat      cut color clarity  depth  table  price     x     y     z
0   0.23    Ideal     E     SI2   61.5   55.0    326  3.95  3.98  2.43
1   0.21  Premium     E     SI1   59.8   61.0    326  3.89  3.84  2.31
2   0.23     Good     E     VS1   56.9   65.0    327  4.05  4.07  2.31
3   0.29  Premium     I     VS2   62.4   58.0    334  4.20  4.23  2.63
4   0.31     Good     J     SI2   63.3   58.0    335  4.34  4.35  2.75

使用 `np.polyfit` 拟合多项式系数,再用 `np.poly1d` 把这些系数包装成可调用的多项式模型:

# Degree of the polynomial fitted to each cut.
degree = 5

# create figure and Axes
fig, ax = plt.subplots(figsize=(12, 8))

# Fix the category order once and use it for BOTH the scatter and the curves.
# seaborn orders a categorical hue by its categories (load_dataset makes 'cut'
# categorical), whereas df.cut.unique() is in order of first appearance;
# zipping colours against unique() therefore gives some curves the wrong colour.
cut_order = list(df['cut'].cat.categories)

# One explicit viridis colour per cut, shared by the scatter and the curves.
colors = sns.color_palette('viridis', n_colors=len(cut_order))

# plot the scatter points
sns.scatterplot(
    data=df, x='carat', y='price',
    hue='cut', hue_order=cut_order, palette=colors,
    s=10, alpha=0.4, ec='none', ax=ax,
)

# iterate through the cuts in the same order as the palette
for cut, color in zip(cut_order, colors):

    # select the data for a given cut
    data = df[df['cut'].eq(cut)]
    if data.empty:
        continue  # np.polyfit cannot fit an empty category

    # least-squares polynomial coefficients, wrapped as a callable model
    p = np.poly1d(np.polyfit(data['carat'], data['price'], degree))

    # evaluate the model on a dense grid spanning this cut's carat range
    xp = np.linspace(data['carat'].min(), data['carat'].max(), 1000)

    # plot the model
    sns.lineplot(x=xp, y=p(xp), color=color, ax=ax, ls=':')

sns.move_legend(ax, bbox_to_anchor=(1, 0.5), loc='center left', frameon=False)

sns.lmplot

  • 如果order大于1,则使用numpy.polyfit估计多项式回归。
  • 使用hue参数分隔类别。
# Degree shared by lmplot's fit and the manual fit, so the two stay comparable.
degree = 5

# plot the polynomial model (lmplot uses np.polyfit when order > 1)
g = sns.lmplot(data=df, x='carat', y='price', hue='cut', palette='viridis', order=degree, truncate=True, ci=None, scatter_kws={'s': 10, 'alpha': 1}, height=8, aspect=1.25)

# lmplot returns a FacetGrid; without row/col faceting it holds a single Axes
ax = g.axes.flat[0]

# Overlay the manual fit in black for comparison. The colour is fixed, so this
# loop needs neither a palette nor any particular category order.
for cut in df.cut.unique():
    data = df[df.cut.eq(cut)]
    p = np.poly1d(np.polyfit(data.carat, data.price, degree))
    xp = np.linspace(data.carat.min(), data.carat.max(), 1000)
    sns.lineplot(x=xp, y=p(xp), color='k', ax=ax, ls=':', legend=False)

sns.regplot

  • 必须通过 `order=` 指定多项式阶数;建议设置 `ci=None`,以跳过耗时的自举(bootstrap)置信区间计算。
  • 与 `lmplot` 不同,`regplot` 没有 `hue` 参数,因此需要按类别循环调用。
fig, ax = plt.subplots(figsize=(12, 8))

# regplot has no hue parameter, so fit each cut separately. order=5 makes
# regplot estimate a polynomial regression; ci=None skips the bootstrap band.
for cut in df['cut'].unique():
    data = df[df['cut'] == cut]
    sns.regplot(data=data, x='carat', y='price', scatter_kws={'s': 10}, label=cut, order=5, ci=None, ax=ax)

# Without a legend the per-cut labels are never shown and the curves cannot be told apart.
ax.legend(title='Cut')

相关问题