matplotlib 基于现有颜色系列添加图例

ttcibm8c  于 2023-05-18  发布在  其他
关注(0)|答案(2)|浏览(103)

我使用散点图绘制了一些数据,并将其指定为:

plt.scatter(rna.data['x'], rna.data['y'], s=size,
                    c=rna.data['colors'], edgecolors='none')

并且rna.data对象是Pandas数据框架,其被组织成使得每行表示数据点(“x”和“y”表示坐标,“colors”是0-5之间的整数,表示点的颜色)。我将数据点分成6个不同的聚类,编号为0-5,并将聚类编号放在每个聚类的平均坐标处。
这将输出以下图形:

我想知道如何在这个图中添加一个图例,指定颜色和它对应的簇号。plt.legend()要求样式代码的格式为red_patch,但它似乎不接受数值(或数字字符串)。那么我如何使用matplotlib添加这个图例呢?有没有办法将我的数值颜色代码转换为plt.legend()采用的格式?多谢了!

agyaoht7

agyaoht71#

可以使用颜色基于散点图的颜色Map和归一化,使用空图创建图例控制柄。

import pandas as pd
import numpy as np; np.random.seed(1)
import matplotlib.pyplot as plt

x = [np.random.normal(5,2, size=20), np.random.normal(10,1, size=20),
     np.random.normal(5,1, size=20), np.random.normal(10,1, size=20)]
y = [np.random.normal(5,1, size=20), np.random.normal(5,1, size=20),
     np.random.normal(10,2, size=20), np.random.normal(10,2, size=20)]
c = [np.ones(20)*(i+1) for i in range(4)]

df = pd.DataFrame({"x":np.array(x).flatten(), 
                   "y":np.array(y).flatten(), 
                   "colors":np.array(c).flatten()})

size=81
sc = plt.scatter(df['x'], df['y'], s=size, c=df['colors'], edgecolors='none')

lp = lambda i: plt.plot([],color=sc.cmap(sc.norm(i)), ms=np.sqrt(size), mec="none",
                        label="Feature {:g}".format(i), ls="", marker="o")[0]
handles = [lp(i) for i in np.unique(df["colors"])]
plt.legend(handles=handles)
plt.show()

或者,您可以通过颜色列中的值过滤数据框,例如:使用groubpy,并为每个要素绘制一个散点图:

import pandas as pd
import numpy as np; np.random.seed(1)
import matplotlib.pyplot as plt

x = [np.random.normal(5,2, size=20), np.random.normal(10,1, size=20),
     np.random.normal(5,1, size=20), np.random.normal(10,1, size=20)]
y = [np.random.normal(5,1, size=20), np.random.normal(5,1, size=20),
     np.random.normal(10,2, size=20), np.random.normal(10,2, size=20)]
c = [np.ones(20)*(i+1) for i in range(4)]

df = pd.DataFrame({"x":np.array(x).flatten(), 
                   "y":np.array(y).flatten(), 
                   "colors":np.array(c).flatten()})

size=81
cmap = plt.cm.viridis
norm = plt.Normalize(df['colors'].values.min(), df['colors'].values.max())

for i, dff in df.groupby("colors"):
    plt.scatter(dff['x'], dff['y'], s=size, c=cmap(norm(dff['colors'])), 
                edgecolors='none', label="Feature {:g}".format(i))

plt.legend()
plt.show()

两种方法生成相同的图:

lnvxswe2

lnvxswe22#

Altair可以是一个很好的选择。

连续类

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

df = pd.DataFrame(40*np.random.randn(10, 3), columns=['A', 'B','C'])

from altair import *
Chart(df).mark_circle().encode(x='A',y='B', color='C').configure_cell(width=200, height=150)

离散类

df = pd.DataFrame(10*np.random.randn(40, 2), columns=['A', 'B'])
df['C'] = np.random.choice(['alpha','beta','gamma','delta'], size=40)

from altair import *
Chart(df).mark_circle().encode(x='A',y='B', color='C').configure_cell(width=200, height=150)

相关问题