pandas 如何创建具有连接点的Swarmplot,与使用色调的Boxplot一致

t1qtbnec  于 2023-04-28  发布在  其他
关注(0)|答案(2)|浏览(89)

由于我的数据的性质,我有两个年龄组进行了两次测试。重要的是,我有一种方法来可视化整个样本的行为(箱线图)以及每个个体在会话之间的变化(Swarmplot/Lineplot)。
当不使用色调或组时,只需连续使用这三个函数,或者跳过linepot,就像这里(Swarmplot with connected dots);但由于我使用色调来区分不同的组,我还没有设法将每个主题的前和后的数据点连接起来。
到目前为止,我已经实现了绘制线条,但它们没有与箱线图对齐,而是与“Pre”和“Post”条件的刻度对齐:
下图显示了四个箱线图(pre_young、pre_old和post_young、post_old),数据点与每个箱线图对齐,但线与“Pre”和“Post”的刻度对齐,而不是与实际数据点或箱线图的中间对齐。

我是通过这个代码得到的:

fig, ax = plt.subplots(figsize=(7,5))
sns.boxplot(data=test_data, 
            x="Session", 
            y="Pre_Post", 
            hue="Age", 
            palette="pastel", 
            boxprops=boxprops, 
            ax=ax)

sns.swarmplot(data=test_data, 
              x="Session", 
              y="Pre_Post", 
              hue="Age", 
              dodge=True, 
              palette="dark", 
              ax=ax)
    
sns.lineplot(data=test_data, 
                 x="Session", 
                 y="Pre_Post", 
                 hue="Age", 
                 estimator=None, 
                 units="Subject", 
                 style="Age", 
                 markers=True, 
                 palette="dark", 
                 ax=ax)

plt.title("Test")
plt.xlabel("Session")
plt.ylabel("Score")

# Move the legend outside the plot
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)

plt.show()

我也试着通过以下方法得到点的坐标:

points = ax.collections[0]
offsets = points.get_offsets()
x_coords = offsets[:, 0]
y_coords = offsets[:, 1]

但我无法将每个坐标与它们所涉及的主题联系起来。
我正在添加一个数据集的样本,如果它能帮助你帮助我。它是csv格式的:

'Session,Subject,Age,Pre_Post\nPre,SY01,young,14.0\nPre,SY02,young,14.0\nPre,SY03,young,13.0\nPre,SY04,young,13.0\nPre,SY05,young,13.0\nPre,SY06,young,15.0\nPre,SY07,young,14.0\nPre,SY08,young,14.0\nPre,SA01,old,5.0\nPre,SA02,old,1.0\nPre,SA03,old,10.0\nPre,SA04,old,3.0\nPre,SA05,old,9.0\nPre,SA06,old,5.0\nPre,SA07,old,13.0\nPre,SA08,old,13.0\nPost,SY01,young,14.0\nPost,SY02,young,13.0\nPost,SY03,young,14.0\nPost,SY04,young,13.0\nPost,SY05,young,15.0\nPost,SY06,young,13.0\nPost,SY07,young,15.0\nPost,SY08,young,14.0\nPost,SA01,old,6.0\nPost,SA02,old,2.0\nPost,SA03,old,10.0\nPost,SA04,old,7.0\nPost,SA05,old,8.0\nPost,SA06,old,11.0\nPost,SA07,old,14.0\nPost,SA08,old,11.0\n'
dced5bon

dced5bon1#

这将工作:

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

s = 'Session,Subject,Age,Pre_Post\nPre,SY01,young,14.0\nPre,SY02,young,14.0\nPre,SY03,young,13.0\nPre,SY04,young,13.0\nPre,SY05,young,13.0\nPre,SY06,young,15.0\nPre,SY07,young,14.0\nPre,SY08,young,14.0\nPre,SA01,old,5.0\nPre,SA02,old,1.0\nPre,SA03,old,10.0\nPre,SA04,old,3.0\nPre,SA05,old,9.0\nPre,SA06,old,5.0\nPre,SA07,old,13.0\nPre,SA08,old,13.0\nPost,SY01,young,14.0\nPost,SY02,young,13.0\nPost,SY03,young,14.0\nPost,SY04,young,13.0\nPost,SY05,young,15.0\nPost,SY06,young,13.0\nPost,SY07,young,15.0\nPost,SY08,young,14.0\nPost,SA01,old,6.0\nPost,SA02,old,2.0\nPost,SA03,old,10.0\nPost,SA04,old,7.0\nPost,SA05,old,8.0\nPost,SA06,old,11.0\nPost,SA07,old,14.0\nPost,SA08,old,11.0'

a = np.array([r.split(',') for r in s.split('\n')])

test_data = pd.DataFrame(a[1:, :], columns = a[0])
test_data['Pre_Post'] = test_data['Pre_Post'].apply(float)
def encode_session(x):
  if x=='Pre':
    return 0
  else:
    return 1
test_data['Session'] = test_data['Session'].apply(encode_session)

test_data2 = test_data.copy()
def offset_session(row):
  if row['Age']=='young':
    return row['Session']-0.2
  else:
    return row['Session']+0.2
test_data2['Session'] = test_data2.apply(offset_session, axis=1)

fig, ax = plt.subplots(figsize=(7,5))
sns.boxplot(data=test_data, 
            x="Session", 
            y="Pre_Post", 
            hue="Age", 
            palette="pastel", 
            #boxprops=boxprops, 
            ax=ax)

sns.swarmplot(data=test_data, 
              x="Session", 
              y="Pre_Post", 
              hue="Age", 
              dodge=True, 
              palette="dark", 
              ax=ax)
    
sns.lineplot(data=test_data2, 
                 x="Session", 
                 y="Pre_Post", 
                 hue="Age", 
                 estimator=None, 
                 units="Subject", 
                 style="Age", 
                 markers=True, 
                 palette="dark", 
                 ax=ax)

plt.title("Test")
plt.xlabel("Session")
plt.ylabel("Score")

# Move the legend outside the plot
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)

plt.xticks([0,1],['Pre', 'Post'])

plt.show()

我们可以讨论这个图的优点。它肯定是混乱的,可能会更好地在两个单独的轴上分裂,减少数据相互重叠。我个人不认为条形图更好。之前/之后的线图可以很好地讲述故事。例如,在one below that I found on google中,我更喜欢在条形图中看这个~40对条形图:

bnlyeluc

bnlyeluc2#

  • 可视化的目的是使从数据中提取意义变得更容易。
  • 通常将swarmplot放在boxplot上,因为它提供了有关发行版的其他信息。
  • 可以,但不应该在分布图上放置趋势线。这是两种类型的图,它们传达有关数据的不同信息,并且图变得难以解释。
  • 由于重点是显示数据的分布,清楚地显示了每个'Subject''Score'变化,一个barplot更合适。
  • 这也是一个更清晰的可视化来分离'Age'组。
  • 如另一个答案所示:
  • a trendline can be shown between associated markers。但是,结果很难读;几乎不可能辨别出相关的标记
  • 请求是为每个'Age'的每个标记添加从'Pre''Post'的趋势线,这会导致plot难以读取**,即使是数据的一个小子集**。
  • 当有许多标记时,趋势线将只指向中心标记,因为lineplot无法与从swarmplot开始的标记对齐。

导入和数据

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# data string
s = 'Session,Subject,Age,Pre_Post\nPre,SY01,young,14.0\nPre,SY02,young,14.0\nPre,SY03,young,13.0\nPre,SY04,young,13.0\nPre,SY05,young,13.0\nPre,SY06,young,15.0\nPre,SY07,young,14.0\nPre,SY08,young,14.0\nPre,SA01,old,5.0\nPre,SA02,old,1.0\nPre,SA03,old,10.0\nPre,SA04,old,3.0\nPre,SA05,old,9.0\nPre,SA06,old,5.0\nPre,SA07,old,13.0\nPre,SA08,old,13.0\nPost,SY01,young,14.0\nPost,SY02,young,13.0\nPost,SY03,young,14.0\nPost,SY04,young,13.0\nPost,SY05,young,15.0\nPost,SY06,young,13.0\nPost,SY07,young,15.0\nPost,SY08,young,14.0\nPost,SA01,old,6.0\nPost,SA02,old,2.0\nPost,SA03,old,10.0\nPost,SA04,old,7.0\nPost,SA05,old,8.0\nPost,SA06,old,11.0\nPost,SA07,old,14.0\nPost,SA08,old,11.0'

# split the data into separate components
data = [v.split(',') for v in s.split('\n')]

# load the list of lists into a dataframe
df = pd.DataFrame(data=data[1:], columns=data[0])

# rename the column
df.rename({'Pre_Post': 'Score'}, axis=1, inplace=True)

# convert the column from a string to a float
df['Score'] = df['Score'].apply(float)

# create separate groups of data for the ages
(_, old), (_, young) = df.groupby('Age')

old

Session Subject  Age  Score
8      Pre    SA01  old    5.0
9      Pre    SA02  old    1.0
10     Pre    SA03  old   10.0
11     Pre    SA04  old    3.0
12     Pre    SA05  old    9.0
13     Pre    SA06  old    5.0
14     Pre    SA07  old   13.0
15     Pre    SA08  old   13.0
24    Post    SA01  old    6.0
25    Post    SA02  old    2.0
26    Post    SA03  old   10.0
27    Post    SA04  old    7.0
28    Post    SA05  old    8.0
29    Post    SA06  old   11.0
30    Post    SA07  old   14.0
31    Post    SA08  old   11.0

young

Session Subject    Age  Score
0      Pre    SY01  young   14.0
1      Pre    SY02  young   14.0
2      Pre    SY03  young   13.0
3      Pre    SY04  young   13.0
4      Pre    SY05  young   13.0
5      Pre    SY06  young   15.0
6      Pre    SY07  young   14.0
7      Pre    SY08  young   14.0
16    Post    SY01  young   14.0
17    Post    SY02  young   13.0
18    Post    SY03  young   14.0
19    Post    SY04  young   13.0
20    Post    SY05  young   15.0
21    Post    SY06  young   13.0
22    Post    SY07  young   15.0
23    Post    SY08  young   14.0

绘图

  • 真实的数据可能有更多的观测值,因此增加figsize元组中的第二个数字以增加图的长度,并调整height_ratios中的第二个数字以使条形图使用更多的图形。
# create the figure using height_ratios to make the bottom subplots larger than the top subplots
fig, axes = plt.subplots(2, 2, figsize=(11, 11), height_ratios=[1, 2])

# flatten the axes for easy access
axes = axes.flat

# plot the boxplots
sns.boxplot(data=young, x="Session", y="Score", ax=axes[0])
sns.boxplot(data=old, x="Session", y="Score", ax=axes[1])

# plot the swarmplots
sns.swarmplot(data=young, x="Session", y="Score", hue='Session', edgecolor='k', linewidth=1, legend=None, ax=axes[0])
sns.swarmplot(data=old, x="Session", y="Score", hue='Session', edgecolor='k', linewidth=1, legend=None, ax=axes[1])

# add a title
axes[0].set_title('Age: Young', fontsize=15)
axes[1].set_title('Age: Old', fontsize=15)

# add the barplots
sns.barplot(data=young, x='Score', y='Subject', hue='Session', ax=axes[2])
sns.barplot(data=old, x='Score', y='Subject', hue='Session', ax=axes[3])

# extract the axes level legend properties
handles, labels = axes[3].get_legend_handles_labels()

# iterate through the bottom axes
for ax in axes[2:]:
    # removed the axes legend
    ax.legend().remove()
    
    # iterate through the containers
    for c in ax.containers:
        
        # annotate the bars
        ax.bar_label(c, label_type='center')
    
# add a figure level legend
_ = fig.legend(handles, labels, title='Session', loc='outside right center', frameon=False)

可视化易读

相关问题