python 调用matplotlib子图时,图例打印两次

a1o7rhls  于 2022-12-28  发布在  Python
关注(0)|答案(1)|浏览(185)

我用matplotlib编写了一段代码,在一个子图网格中绘制多个直方图。但是,当我在最后调用fig.legend()函数时,每个图的图例条目都显示了两次。如果您能就如何解决这个问题给出指导,我将不胜感激 :) 以下是我的代码:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
# Select seaborn's 'darkgrid' style before any figure is created.
sns.set_style('darkgrid')
def get_cmap(n, name='hsv'):
    """Return the colormap ``name`` resampled to ``n`` discrete colors.

    Args:
        n: Number of entries in the colormap's lookup table.
        name: Name of a registered matplotlib colormap.

    Returns:
        The colormap object; calling it with an integer index gives a color.
    """
    # ``matplotlib.cm.get_cmap`` was deprecated in matplotlib 3.7 and removed
    # in 3.9; ``pyplot.get_cmap`` accepts the same (name, lut) arguments.
    return plt.get_cmap(name, n)
def isSqrt(n):
    """Return True when ``n`` is a perfect square, False otherwise.

    Negative numbers are never perfect squares, so they return False
    instead of failing on the NaN that ``np.sqrt`` yields for them.
    """
    if n < 0:
        return False
    sq_root = int(np.sqrt(n))
    return (sq_root * sq_root) == n
# Reproduction script from the question: draws one histogram per numeric
# column on a grid of subplots and then adds a single figure-level legend.
# The question reports that every legend entry shows up twice.

# Load two CSV files and copy the 'salary' column of the second one into a
# new 'miles' column of the first.
df = pd.read_csv('mpg.csv')
df2 = pd.read_csv('dm_office_sales.csv')
df['miles'] = df2['salary']
# Keep only the columns whose dtype is one of the listed numeric dtypes.
numericClassifier = ['int16', 'int32', 'int64', 'float16', 'float32', 'float64']
newdf = df.select_dtypes(numericClassifier)
columns = newdf.columns.tolist()
n = len(columns)
# Colormap with n entries; indexed by the running column counter below.
cmap = get_cmap(n)
# Grid shape: square when n is a perfect square; otherwise ncols is
# floor(sqrt(n)) and nrows is the first i in [ncols, 50) with ncols*i >= n.
# NOTE(review): nrows stays unassigned if no such i is found below 50.
if(isSqrt(n)):
    nrows = ncols = int(np.sqrt(n))
else:
    ncols = int(np.sqrt(n))
    for i in range(ncols,50):
        if ncols*i >= n:
            nrows = i
            break
        else:
            pass
# NOTE(review): ax is indexed as ax[i,j] below, which presumably requires
# both nrows and ncols to be greater than 1 -- confirm for small n.
fig,ax = plt.subplots(nrows,ncols)
count = 0
print(nrows,ncols)
# Walk the grid row by row; `count` is the index of the next column to plot.
for i in range(0,nrows,1):
    for j in range(0,ncols,1):
        print('ncols = {}'.format(j),'nrows = {}'.format(i),'count = {}'.format(count))
        if count<=n-1:
            # Histogram with KDE for the current column; the column name is
            # passed as the label.
            plt_new = sns.histplot(df[columns[count]],ax=ax[i,j],facecolor=cmap(count),kde=True,edgecolor='black',label=df[columns[count]].name)
            # Set alpha on every child artist of the returned axes.
            patches = plt_new.get_children()
            for patch in patches:
                patch.set_alpha(0.8)
            # `color` is not read again in the visible code.
            color = patches[0].get_facecolor()
            ax[i,j].set_xlabel('{}'.format(df[columns[count]].name))
            ax[i,j].xaxis.label.set_fontsize(10)
            # NOTE(review): 'ariel' looks like a typo for 'Arial' -- confirm.
            ax[i,j].xaxis.label.set_fontname('ariel')
            # NOTE(review): this resets the x-label configured just above --
            # confirm that hiding the label is intended.
            ax[i,j].set(xlabel=None)
            ax[i,j].tick_params(axis='y', labelsize=8)
            count+=1
        else:
            # All columns are plotted; this only leaves the inner loop.
            break
    
# Delete the grid cells that received no data.
for i in range(0,nrows,1):
    for j in range(0,ncols,1):
        if not ax[i,j].has_data():
            fig.delaxes(ax[i,j])
        else:
            pass

plt.suptitle('Histograms').set_fontname('ariel')
plt.tight_layout()
# Figure-level legend; this is the call the question reports as listing
# each label twice.
fig.legend(loc='upper right')
plt.show()

下面是输出:

eblbsuwk

eblbsuwk1#

sns.histplot似乎创建了两个条形容器(bar container):先是一个虚拟容器,然后才是真实的容器(用seaborn 0.12.1测试;在其他版本中可能不同)。因此,标签同时被分配给了虚拟容器和真实容器。一个变通方案是移除虚拟条形容器的标签。
以下是修改后的代码。这里使用Seaborn自带的mpg数据集作为易于复现的示例。由于hsv颜色映射表的第一个和最后一个颜色都是红色,get_cmap(n + 1)可确保选出n种不同的颜色。一些多余的代码已被删除。

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

def get_cmap(n, name='hsv'):
    """Return the colormap ``name`` resampled to ``n`` discrete colors.

    Args:
        n: Number of entries in the colormap's lookup table.
        name: Name of a registered matplotlib colormap.

    Returns:
        The colormap object; calling it with an integer index gives a color.
    """
    # ``matplotlib.cm.get_cmap`` was deprecated in matplotlib 3.7 and removed
    # in 3.9; ``pyplot.get_cmap`` accepts the same (name, lut) arguments.
    return plt.get_cmap(name, n)

# Draw one histogram (with KDE) per numeric column of seaborn's 'mpg' dataset
# on a grid of subplots, with a single figure-level legend in which every
# column appears exactly once.
sns.set_style('darkgrid')
df = sns.load_dataset('mpg')
numericClassifier = ['int16', 'int32', 'int64', 'float16', 'float32', 'float64']
newdf = df.select_dtypes(numericClassifier)
columns = newdf.columns.tolist()
n = len(columns)
# The hsv colormap starts and ends on red, so request one extra entry to get
# n distinct colors for indices 0..n-1.
cmap = get_cmap(n + 1)
ncols = int(np.sqrt(n))
nrows = int(np.ceil(n / ncols))
# squeeze=False keeps `ax` two-dimensional even when the grid has a single
# row or column, so ax[i, j] is always valid.
fig, ax = plt.subplots(nrows, ncols, squeeze=False)
count = 0
print(nrows, ncols)
for i in range(nrows):
    for j in range(ncols):
        if count < n:
            sns.histplot(df[columns[count]], ax=ax[i, j], facecolor=cmap(count), kde=True, edgecolor='black',
                         label=df[columns[count]].name)
            # With `facecolor=`, seaborn (observed with 0.12.1) seems to add a
            # dummy bar container before the real one, and both get the label.
            # Blank the label on every container except the last, so nothing
            # is cleared when no dummy container exists.
            for container in ax[i, j].containers[:-1]:
                container.set_label('')
            # get_children() yields every child artist of the axes.
            for artist in ax[i, j].get_children():
                artist.set_alpha(0.8)
            ax[i, j].tick_params(axis='y', labelsize=8)
            count += 1
# Remove the grid cells that received no histogram.
for i in range(nrows):
    for j in range(ncols):
        if not ax[i, j].has_data():
            fig.delaxes(ax[i, j])

# 'Arial' is the actual font family name ('ariel' was a typo that matplotlib
# cannot resolve).
plt.suptitle('Histograms').set_fontname('Arial')
fig.legend(loc='upper right')
plt.tight_layout()
plt.subplots_adjust(right=0.75)  # make extra space for the legend
plt.show()

经进一步调查,似乎在使用color=而非facecolor=调用sns.histplot时,不会创建虚拟条形容器。
代码也可以写得更“Pythonic”一点。这意味着(除其他方面外)要尽量避免重复代码和显式索引。要做到这一点,zip是一个重要的帮手。除了避免重复,代码也会变得更短、更容易修改。一旦习惯了这种写法,代码读起来和推理起来也会更容易。
主要部分可能如下所示:

# squeeze=False makes `axs` a 2-D array for every grid shape, so `axs.flat`
# also works when the grid is a single cell.
fig, axs = plt.subplots(nrows=nrows, ncols=ncols, squeeze=False)
# Pair each column with its own axes and color; zip stops after the last
# column, leaving any surplus axes empty.
for column, ax, color in zip(columns, axs.flat, cmap(range(n))):
    # using `color=` instead of `facecolor=` seems to avoid the creation of dummy bars
    sns.histplot(df[column], ax=ax, color=color, kde=True, edgecolor='black', label=column)
    # get_children() yields every child artist of the axes.
    for artist in ax.get_children():
        artist.set_alpha(0.8)
    ax.tick_params(axis='y', labelsize=8)
# Remove the grid cells that received no histogram.
for ax in axs.flat:
    if not ax.has_data():
        fig.delaxes(ax)

相关问题