matplotlib: custom error bars for a catplot with grouped bars in facets

Posted by mmvthczy on 2023-04-12 in Other

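pandas 1.5.3, seaborn 0.12.2
My code and part of the data are shown below. I am trying to plot error bars that are precomputed in the dataframe (val_lo, val_hi). It seems that sns.catplot with kind='bar' supports the errorbar parameter, as mentioned here. How do I get it to work? Alternatively, any guidance on how to use matplotlib error bars would help.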

import pandas as pd
import re
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.ticker import PercentFormatter

df = pd.DataFrame([
    ['C', 'G1', 'gbt',    'auc', 0.7999, 0.7944, 0.8032],
    ['C', 'G1', 'gbtv2',  'auc', 0.8199, 0.8144, 0.8232],
    ['C', 'G1', 'gbt',  'pr@2%', 0.0883, 0.0841, 0.0909],
    ['C', 'G1', 'gbt', 'pr@10%', 0.0430, 0.0416, 0.0435],
    ['C', 'G2', 'gbt',    'auc', 0.7554, 0.7506, 0.7573],
    ['C', 'G2', 'gbt',  'pr@2%', 0.0842, 0.0795, 0.0872],
    ['C', 'G2', 'gbt', 'pr@10%', 0.0572, 0.0556, 0.0585],
    ['C', 'G3', 'gbt',    'auc', 0.7442, 0.7404, 0.7460],
    ['C', 'G3', 'gbt',  'pr@2%', 0.0894, 0.0836, 0.0913],
    ['C', 'G3', 'gbt', 'pr@10%', 0.0736, 0.0714, 0.0742],
    ['E', 'G1', 'gbt',    'auc', 0.7988, 0.7939, 0.8017],
    ['E', 'G1', 'gbt',  'pr@2%', 0.0810, 0.0770, 0.0832],
    ['E', 'G1', 'gbt', 'pr@10%', 0.0354, 0.0342, 0.0361],
    ['E', 'G1', 'gbtv3','pr@10%',0.0454, 0.0442, 0.0461],
    ['E', 'G2', 'gbt',    'auc', 0.7296, 0.7253, 0.7311],
    ['E', 'G2', 'gbt',  'pr@2%', 0.1071, 0.1034, 0.1083],
    ['E', 'G2', 'gbt', 'pr@10%', 0.0528, 0.0508, 0.0532],
    ['E', 'G3', 'gbt',    'auc', 0.6958, 0.6914, 0.6978],
    ['E', 'G3', 'gbt',  'pr@2%', 0.1007, 0.0961, 0.1030],
    ['E', 'G3', 'gbt', 'pr@10%', 0.0536, 0.0518, 0.0541],
  ], columns=["src","grp","model","metric","val","val_lo","val_hi"])

sns.reset_defaults()
sns.set(style="whitegrid", font_scale=1.)
g = sns.catplot(data=df, x="grp", y="val", hue="model", 
  col="metric", row="src", kind="bar", sharey=False)
for ax in g.axes.flat:
  ax.yaxis.set_major_formatter(PercentFormatter(1))
  if re.search("metric = auc",ax.get_title(),re.IGNORECASE):
    _ = ax.set_ylim((.5,1.))
plt.show()
Answer

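  • Using ax.set_ylim((.5, 1.)) is a terrible way to present bars.
  • Bar plots should always share a common baseline at zero.
  • Truncating the baseline is a common way data gets misrepresented, because it exaggerates the differences between the bars being compared.
  • Therefore, that part of the plot is not reproduced here.
  • There are ways to manually add lines to a seaborn plot as error bars, but that defeats the purpose of using seaborn and is cumbersome.
  • seaborn is a high-level API for matplotlib, which makes some things easier to implement; however, if your plot needs customization, using matplotlib directly may be the better option.
  • g.map(plt.errorbar, 'grp', 'val', 'yerr', marker='none', color='r', ls='none') does not correctly dodge the error bars to align with the bars, as shown here.
  • pandas.DataFrame.plot uses matplotlib as the default plotting backend.
  • The same plot can be created directly with pandas and matplotlib.pyplot.subplots.
    * Tested in python 3.11.2, pandas 2.0.0, matplotlib 3.7.1, seaborn 0.12.2
  • If 'val_lo' and 'val_hi' are symmetric about the bar top
  • Use df.val_hi.sub(df.val_lo) to calculate the interval width, then add the error bars with the yerr= parameter of pandas.DataFrame.plot.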
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.ticker import PercentFormatter
import numpy as np

# given the DataFrame in the OP

# making metric an ordered Categorical ensures the subplot order follows df.metric.unique(); if that is not the desired order, pass a list of the categories in the desired order instead
df.metric = pd.Categorical(df.metric, df.metric.unique(), ordered=True)

# calculate a single metric for the errorbars
df['yerr'] = df.val_hi.sub(df.val_lo)

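# create the figure and subplots; constrained layout makes room for the 'outside' figure legend
fig, axes = plt.subplots(2, 3, figsize=(10, 8), sharex=True, sharey=True, layout='constrained')

# flatten the axes into a 1-D array for easy access
axes = axes.ravel()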

# get a set of the unique model values
models = set(df.model.unique())

# iterate through the axes and groupby objects
for ax, ((src, metric), data) in zip(axes, df.groupby(['src', 'metric'], observed=True)):
    
    # pivot the val and yerr columns for plotting;
    # matplotlib draws ±yerr around each bar top, so pass half of the val_lo to val_hi width
    yerr = data.pivot(index='grp', columns='model', values='yerr').div(2)
    data = data.pivot(index='grp', columns='model', values='val')
    
    # add any missing model columns (as 0) so every subplot has the same set of models
    cols = list(models.difference(set(data.columns)))
    data[cols] = 0
    
    # sort the columns so bars are plotted in the same position in each axes
    data = data.sort_index(axis=1)
    
    # plot the bars for data
    data.plot(kind='bar', yerr=yerr, ax=ax, rot=0, yticks=np.arange(0, 1.1, .1), title=f'src: {src} | metric: {metric}')
    
    # change the yaxis to percent
    ax.yaxis.set_major_formatter(PercentFormatter(1))
    
    # remove the right and top spines to match catplot
    ax.spines[['right', 'top']].set_visible(False)

# extract the axes level legend properties
handles, labels = axes[-1].get_legend_handles_labels()

# remove all the axes level legends
for ax in axes:
    ax.legend().remove()

# add a figure level legend
fig.legend(handles, labels, title='Model', loc='outside right center', frameon=False)
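Regarding the yerr argument: matplotlib (and pandas.DataFrame.plot, which hands yerr to matplotlib) treats a 1-D yerr as symmetric and draws each error bar from val - yerr to val + yerr. A symmetric bar that spans val_lo to val_hi therefore needs half of val_hi - val_lo. For intervals that are not symmetric, matplotlib also accepts a (2, N) array of distances below and above each bar top. The sketch below shows both forms. It uses plain matplotlib and the src C / metric auc / model gbt rows from the question. The variable names are only for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend for this sketch; omit in a notebook
import matplotlib.pyplot as plt
import numpy as np

# src C / metric auc / model gbt rows from the question (G1, G2, G3)
grp = ["G1", "G2", "G3"]
val = np.array([0.7999, 0.7554, 0.7442])
val_lo = np.array([0.7944, 0.7506, 0.7404])
val_hi = np.array([0.8032, 0.7573, 0.7460])

# symmetric form: matplotlib draws val ± yerr, so use half the interval width
yerr_sym = (val_hi - val_lo) / 2

# asymmetric form: row 0 = distance below the bar top, row 1 = distance above it
yerr_asym = np.vstack([val - val_lo, val_hi - val])

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 3), sharey=True)
ax1.bar(grp, val, yerr=yerr_sym, capsize=4)
ax1.set_title("symmetric yerr")
ax2.bar(grp, val, yerr=yerr_asym, capsize=4)
ax2.set_title("asymmetric yerr (val_lo, val_hi)")
```

For a pandas.DataFrame.plot call on a DataFrame with M columns and N rows, the pandas documentation describes the asymmetric equivalent as an M x 2 x N array.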

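  • If 'val_lo' and 'val_hi' are not symmetric about the bar top
  • Use .vlines to draw a vertical line as each error bar
  • Pass 'val_lo' and 'val_hi' as ymin and ymax, respectively
  • Call .get_center on the bars of the corresponding bar container to extract their x positions, which can be passed to x
  • For additional details on this method, see How to draw vertical lines on a given plot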
# given the DataFrame in the OP

# making metric an ordered Categorical ensures the subplot order follows df.metric.unique(); if that is not the desired order, pass a list of the categories in the desired order instead
df.metric = pd.Categorical(df.metric, df.metric.unique(), ordered=True)

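# create the figure and subplots; constrained layout makes room for the 'outside' figure legend
fig, axes = plt.subplots(2, 3, figsize=(20, 20), sharex=True, sharey=True, dpi=300, layout='constrained')

# flatten the axes into a 1-D array for easy access
axes = axes.ravel()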

# get a set of the unique model values
models = set(df.model.unique())

# iterate through the axes and groupby objects
for ax, ((src, metric), data) in zip(axes, df.groupby(['src', 'metric'], observed=True)):
    
    # get the error columns, sorted by grp to match the bar order after pivoting
    error_data = data[['grp', 'model', 'src', 'val_lo', 'val_hi']].sort_values('grp')
    
    # pivot the val column for plotting
    data = data.pivot(index='grp', columns='model', values='val')
    
    # add any missing model columns (as 0) so every subplot has the same set of models
    cols = list(models.difference(set(data.columns)))
    data[cols] = 0
    
    # sort the columns so bars are plotted in the same position in each axes
    data = data.sort_index(axis=1)

    # plot the bars for data
    data.plot(kind='bar', ax=ax, rot=0, yticks=np.arange(0, 1.1, .1), title=f'src: {src} | metric: {metric}')
    
    # iterate through each bar container
    for c in ax.containers:
        # get the label of the bar
        label = c.get_label()
        
        # select the appropriate error data
        eb = error_data[error_data.model.eq(label)]
        
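        # get the center x value of each bar that has a value; zero-height (or NaN)
        # placeholder bars for models missing from this group are skipped (assumes positive metrics)
        x = [bar.get_center()[0] for bar in c if bar.get_height() > 0]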
        
        # if eb isn't empty for the current label, add the vertical lines
        if not eb.empty:
            ax.vlines(x, ymin=eb.val_lo, ymax=eb.val_hi, color='k')

    # change the yaxis to percent
    ax.yaxis.set_major_formatter(PercentFormatter(1))
    
    # remove the right and top spines to match catplot
    ax.spines[['right', 'top']].set_visible(False)
    
# extract the axes level legend properties
handles, labels = axes[-1].get_legend_handles_labels()

# remove all the axes level legends
for ax in axes:
    ax.legend().remove()

# add a figure level legend
fig.legend(handles, labels, title='Model', loc='outside right center', frameon=False)
  • This image was saved with a very large size and dpi because some of the error bars are very small and would otherwise be barely visible.
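The vlines approach above draws plain lines. One variant, which is not part of the original answer, calls ax.errorbar at the bar centers with an asymmetric (2, N) yerr. This adds caps and skips groups where a model has no value. The sketch below covers a single subplot (src C, metric auc) and assumes the same pandas bar plotting as above. The names sub, lo, hi, and centers are only for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend for this sketch; omit in a notebook
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# src C / metric auc rows from the question (gbtv2 only exists in G1)
sub = pd.DataFrame(
    [["G1", "gbt", 0.7999, 0.7944, 0.8032],
     ["G1", "gbtv2", 0.8199, 0.8144, 0.8232],
     ["G2", "gbt", 0.7554, 0.7506, 0.7573],
     ["G3", "gbt", 0.7442, 0.7404, 0.7460]],
    columns=["grp", "model", "val", "val_lo", "val_hi"],
)

# wide tables with grp as the index and one column per model (missing -> NaN)
val = sub.pivot(index="grp", columns="model", values="val")
lo = sub.pivot(index="grp", columns="model", values="val_lo")
hi = sub.pivot(index="grp", columns="model", values="val_hi")

fig, ax = plt.subplots(layout="constrained")
val.plot(kind="bar", ax=ax, rot=0)

# copy the list first: ax.errorbar appends an ErrorbarContainer to ax.containers
bar_containers = list(ax.containers)

centers = {}
for container, model in zip(bar_containers, val.columns):
    # x position of every bar in this model's container, in grp order
    x = np.array([bar.get_center()[0] for bar in container])
    y = val[model].to_numpy()
    keep = ~np.isnan(y)  # skip groups where this model has no value
    # row 0: distance below the bar top, row 1: distance above it
    yerr = np.vstack([y - lo[model].to_numpy(), hi[model].to_numpy() - y])
    ax.errorbar(x[keep], y[keep], yerr=yerr[:, keep], fmt="none", ecolor="k", capsize=3)
    centers[model] = x[keep]
```

Masking with the pivoted values has one advantage over filtering bars by height: it does not depend on how the missing bars were drawn, so a genuine value of 0 would still get its error bar.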

df

   src grp  model  metric     val  val_lo  val_hi    yerr
0    C  G1    gbt     auc  0.7999  0.7944  0.8032  0.0088
1    C  G1  gbtv2     auc  0.8199  0.8144  0.8232  0.0088
2    C  G1    gbt   pr@2%  0.0883  0.0841  0.0909  0.0068
3    C  G1    gbt  pr@10%  0.0430  0.0416  0.0435  0.0019
4    C  G2    gbt     auc  0.7554  0.7506  0.7573  0.0067
5    C  G2    gbt   pr@2%  0.0842  0.0795  0.0872  0.0077
6    C  G2    gbt  pr@10%  0.0572  0.0556  0.0585  0.0029
7    C  G3    gbt     auc  0.7442  0.7404  0.7460  0.0056
8    C  G3    gbt   pr@2%  0.0894  0.0836  0.0913  0.0077
9    C  G3    gbt  pr@10%  0.0736  0.0714  0.0742  0.0028
10   E  G1    gbt     auc  0.7988  0.7939  0.8017  0.0078
11   E  G1    gbt   pr@2%  0.0810  0.0770  0.0832  0.0062
12   E  G1    gbt  pr@10%  0.0354  0.0342  0.0361  0.0019
13   E  G1  gbtv3  pr@10%  0.0454  0.0442  0.0461  0.0019
14   E  G2    gbt     auc  0.7296  0.7253  0.7311  0.0058
15   E  G2    gbt   pr@2%  0.1071  0.1034  0.1083  0.0049
16   E  G2    gbt  pr@10%  0.0528  0.0508  0.0532  0.0024
17   E  G3    gbt     auc  0.6958  0.6914  0.6978  0.0064
18   E  G3    gbt   pr@2%  0.1007  0.0961  0.1030  0.0069
19   E  G3    gbt  pr@10%  0.0536  0.0518  0.0541  0.0023
