matplotlib 使用seaborn创建分组条形图

vlf7wbxs  于 2023-03-30  发布在  其他
关注(0)|答案(2)|浏览(163)

我有一个dataframe如下

category val1 val2 val3
A       2    3     2
A       3    4     1
B       4    5     2
C       3    3     2
B       4    5     2
C       3    3     2

我试图创建一个分组的酒吧视觉,有类别的x轴,和val1,val2,val3作为y轴。
我的代码类似于这样:

plt.bar(df['category'], df['var1'])
plt.bar(df['category'], df['var2'])
plt.bar(df['category'], df['var3'])

但是,我没有得到一个分组条形图。它是这样的。有什么办法来解决这个问题吗?

ctrmrzij

ctrmrzij1#

我不确定这是你想要的,但你可以尝试以下两种选择:

import pandas as pd
from io import StringIO
import seaborn as sns

data = """category val1 val2 val3
A       2    3     2
A       3    4     1
B       4    5     2
C       3    3     2
B       4    5     2
C       3    3     2"""

df = pd.read_csv(StringIO(data), sep="\s+")

g = sns.barplot(
    data=df.melt(id_vars = ["category"], value_vars=["val1", "val2", "val3"]),
    y="value", x="variable", hue="category", ci=None
)

g = sns.catplot(
    data=df.melt(id_vars = ["category"], value_vars=["val1", "val2", "val3"]),
    kind="bar",
    y="value", x="variable", col="category", ci=None
)
g.set_axis_labels("", "")

这些方法的关键是使用melt来取消数据透视。
还要注意的是,上面并没有处理重复的类别。如果你想让你的值是唯一的,你可以在绘图之前将dfcategory分组并聚合值。

wwodge7n

wwodge7n2#

您的类别不是唯一的。您希望在x轴上看到什么?
假设它们是唯一的,那么你可以简单地做:

import pandas as pd
import seaborn as sns
sns.set()

df = pd.DataFrame()
df['category'] = ['A','B','C','D','E','F']
df['val1'] = [2,3,4,3,4,3]
df['val2'] = [3,4,5,3,5,3]
df['val3'] = [2,1,2,2,2,2]

df.set_index('category').plot(kind='bar', stacked=True)

编辑:seaborn本身并不支持堆叠条形图,但如果你需要(或者其他人正在寻找标题中的实际内容),这里有一个黑客方法。

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

df = pd.DataFrame()
df['category'] = ['A','B','C','D','E','F']
df['val1'] = [2,3,4,3,4,3]
df['val2'] = [3,4,5,3,5,3]
df['val3'] = [2,1,2,2,2,2]

# create total columns
df['v1+v2+v3'] = df.val1 + df.val2 + df.val3
df['v1+v2'] = df.val1 + df.val2

# plot total
g = sns.barplot(x=df['category'], y=df['v1+v2+v3'], color='green', label='val3')

# plot middle values
g = sns.barplot(x=df['category'], y=df['v1+v2'], color='orange', label='val2')

# plot bottom values
g = sns.barplot(x=df['category'], y=df['val1'], color='blue', label='val1')

# rename axes
g.set_ylabel('vals')
g.set_xlabel('categories')

plt.legend()

相关问题