matplotlib 如何使用ax正确编辑图表图例

pod7payv  于 2023-03-09  发布在  其他
关注(0)|答案(1)|浏览(172)

我有一些代码输出以下图像:

我想编辑左下子图中的图例，把它改成上下两行：
H_0：“星号”
H_1：“绿色圆点”
其中“星号”和“绿色圆点”要替换为图中实际使用的对应标记符号。

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d
import matplotlib.colors as colors

# Define the function to plot
def f(x, y):
    """Return the height of the radial ripple surface at (x, y).

    Computes sin(r), where r = sqrt(x**2 + y**2) is the distance from the
    origin. The operations are elementwise, so both scalars and NumPy
    arrays (for example meshgrid coordinate arrays) are accepted.
    """
    radius = np.sqrt(x**2 + y**2)
    return np.sin(radius)

# Generate data for the x, y, and z coordinates
x = np.linspace(-6, 6, 100)
y = np.linspace(-6, 6, 100)
X, Y = np.meshgrid(x, y)
Z = f(X, Y)
# Single-colour colormap: every filled sublevel band is drawn in black
cmap = colors.ListedColormap(['black'])

# Create a 2x2 grid: 3D surface, sublevel-set contour, persistence diagram, barcodes
fig = plt.figure(figsize=(10, 8))
ax1 = fig.add_subplot(221, projection='3d')
ax2 = fig.add_subplot(222)
ax3 = fig.add_subplot(223)
ax4 = fig.add_subplot(224)

# Birth heights of the features drawn in the persistence diagram (all plotted at death = 1)
Zero_dim_births = np.array([-1, -.25, .75])
One_dim_births = np.array([-1, .75])

i = 1  # initialize the extent of the contour band; grows by level_set_speed each frame
level_set_speed = .075  # how quickly the level sets expand
plane_speed = .05  # how quickly the plane moves up

# Loop-invariant x/y data for the diagonal line of the persistence diagram
diagonal = np.arange(-1.1, 1.1, .1)

# Barcode bars as (birth, row, colour); every bar ends at 1
barcodes = [
    (-1, 1, 'green'),    # 0D component 1
    (-1, 2, 'red'),      # 1D component 1
    (-.25, 3, 'green'),  # 0D component 2
    (.75, 4, 'green'),   # 0D component 3
    (.75, 5, 'red'),     # 1D component 4
]

for a in np.arange(-1, 1.05, plane_speed):  # a is the current height of the plane
    i += level_set_speed

    # Plot the plane moving up the surface on the left
    ax1.cla()
    plane = np.zeros_like(X) + a
    ax1.plot_surface(X, Y, Z, cmap='jet')
    ax1.plot_wireframe(X, Y, plane, color='black')

    # Plot the sublevel-set contour on the right. Clear first: otherwise every
    # frame stacks one more filled-contour set on ax2 and redraws keep slowing.
    ax2.cla()
    contour_levels = np.arange(Z.min(), Z.min() + i, i / 2)
    ax2.contourf(X, Y, Z, levels=contour_levels, cmap=cmap, extend='min')

    # Plot the persistence diagram: diagonal, current plane height, and the
    # birth points (linestyle='' so the points are not joined by a line)
    ax3.cla()
    ax3.plot(diagonal, diagonal, color='blue')
    ax3.axhline(y=a, color='black', linestyle='-')
    ax3.plot(Zero_dim_births, np.ones_like(Zero_dim_births), linestyle='',
             marker='o', markersize=7, markeredgecolor='green', markerfacecolor='green')
    ax3.plot(One_dim_births, np.ones_like(One_dim_births), linestyle='',
             marker='*', markersize=5, markeredgecolor='red', markerfacecolor='red')

    # Plot the barcodes
    ax4.cla()
    for birth, row, colour in barcodes:
        ax4.plot([birth, 1], [row, row], color=colour)
    ax4.axvline(x=a, color='black', linestyle='-')

    # Labels for all the plots (re-applied every frame because cla() resets them)
    # plot 1
    ax1.set_xlabel('x')
    ax1.set_ylabel('y')
    ax1.set_zlabel('z')
    ax1.set_title(r'$f(x, y) = \sin(\sqrt{x^2 + y^2})$')

    # plot 2
    ax2.set_title('Sublevel set filtration')

    # plot 3
    ax3.set_xlabel('Birth (height)')
    ax3.set_ylabel('Death')
    ax3.set_title('Persistence Diagram')
    # TODO: placeholder legend -- this is the part the question asks how to edit.
    # A bare list of labels is matched to ax3's artists in drawing order, so this
    # single label ends up on the first line drawn (the blue diagonal).
    ax3.legend(['NEED TO EDIT'], loc='lower right')

    # plot 4
    ax4.set_xlabel('Persistence')
    ax4.set_ylabel('Components')
    ax4.set_yticks([])
    ax4.set_title('Barcodes')

    # snapshot: update the figure and wait 0.1 s before the next frame
    plt.pause(.1)

# Show the plot
plt.tight_layout()
plt.show()

有谁能告诉我如何按上述要求修改这个图例吗？

cigdeys3

cigdeys31#

这里的技巧是用 matplotlib.lines 中的 Line2D 创建“假”线条（即不含数据的代理图元），再用它们作为句柄来生成图例。代码如下：

import matplotlib.lines as mlines
# Create proxy lines (no data, marker only) to serve as legend handles.
# ls='' draws no connecting line, so each legend entry shows just its marker.
# markersize and colours match the symbols plotted on ax3 (stars: 5, red;
# dots: 7, green), so the legend looks like the plot. r'$H_0$' renders the
# label with a subscript via mathtext.
star = mlines.Line2D([], [], color='white', marker='*', markersize=5,
                     markerfacecolor='red', markeredgecolor='red', ls='', label=r'$H_0$')
dot = mlines.Line2D([], [], color='white', marker='o', markersize=7,
                    markerfacecolor='green', markeredgecolor='green', ls='', label=r'$H_1$')
# Add legend; entries appear in handle order, so H_0 sits above H_1
ax3.legend(handles=[star, dot], loc='lower right')

color='white'ls=''的存在使得图例中仅显示标记(星星、绿色),后面没有线。

以下是完整的代码,以防万一:

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d
import matplotlib.colors as colors
import matplotlib.lines as mlines

# Define the function to plot
def f(x, y):
    """Evaluate the ripple surface sin(sqrt(x**2 + y**2)).

    The argument of the sine is the Euclidean distance from the origin.
    Everything is elementwise, so scalars and NumPy arrays both work.
    """
    distance = np.sqrt(x**2 + y**2)
    return np.sin(distance)

# Generate data for the x, y, and z coordinates
x = np.linspace(-6, 6, 100)
y = np.linspace(-6, 6, 100)
X, Y = np.meshgrid(x, y)
Z = f(X, Y)
# Single-colour colormap: every filled sublevel band is drawn in black
cmap = colors.ListedColormap(['black'])

# Create a 2x2 grid: 3D surface, sublevel-set contour, persistence diagram, barcodes
fig = plt.figure(figsize=(10, 8))
ax1 = fig.add_subplot(221, projection='3d')
ax2 = fig.add_subplot(222)
ax3 = fig.add_subplot(223)
ax4 = fig.add_subplot(224)

# Birth heights of the features drawn in the persistence diagram (all plotted at death = 1)
Zero_dim_births = np.array([-1, -.25, .75])
One_dim_births = np.array([-1, .75])

i = 1  # initialize the extent of the contour band; grows by level_set_speed each frame
level_set_speed = .075  # how quickly the level sets expand
plane_speed = .05  # how quickly the plane moves up

# Loop-invariant x/y data for the diagonal line of the persistence diagram
diagonal = np.arange(-1.1, 1.1, .1)

# Barcode bars as (birth, row, colour); every bar ends at 1
barcodes = [
    (-1, 1, 'green'),    # 0D component 1
    (-1, 2, 'red'),      # 1D component 1
    (-.25, 3, 'green'),  # 0D component 2
    (.75, 4, 'green'),   # 0D component 3
    (.75, 5, 'red'),     # 1D component 4
]

# Legend proxy artists ("fake" lines with no data). ls='' draws no connecting
# line, so each legend entry shows only its marker; marker size and colour match
# the symbols plotted on ax3. The proxies are never added to an axes, so the same
# two objects can be reused for the fresh legend built after every ax3.cla().
# NOTE(review): labels follow the question (H_0 = star, H_1 = green dot), but the
# stars plot One_dim_births and the dots plot Zero_dim_births; if H_k means
# k-dimensional homology the two labels should be swapped -- confirm intent.
star = mlines.Line2D([], [], color='white', marker='*', markersize=5,
                     markerfacecolor='red', markeredgecolor='red',
                     ls='', label=r'$H_0$')
dot = mlines.Line2D([], [], color='white', marker='o', markersize=7,
                    markerfacecolor='green', markeredgecolor='green',
                    ls='', label=r'$H_1$')

for a in np.arange(-1, 1.05, plane_speed):  # a is the current height of the plane
    i += level_set_speed

    # Plot the plane moving up the surface on the left
    ax1.cla()
    plane = np.zeros_like(X) + a
    ax1.plot_surface(X, Y, Z, cmap='jet')
    ax1.plot_wireframe(X, Y, plane, color='black')

    # Plot the sublevel-set contour on the right. Clear first: otherwise every
    # frame stacks one more filled-contour set on ax2 and redraws keep slowing.
    ax2.cla()
    contour_levels = np.arange(Z.min(), Z.min() + i, i / 2)
    ax2.contourf(X, Y, Z, levels=contour_levels, cmap=cmap, extend='min')

    # Plot the persistence diagram: diagonal, current plane height, and the
    # birth points (linestyle='' so the points are not joined by a line)
    ax3.cla()
    ax3.plot(diagonal, diagonal, color='blue')
    ax3.axhline(y=a, color='black', linestyle='-')
    ax3.plot(Zero_dim_births, np.ones_like(Zero_dim_births), linestyle='',
             marker='o', markersize=7, markeredgecolor='green', markerfacecolor='green')
    ax3.plot(One_dim_births, np.ones_like(One_dim_births), linestyle='',
             marker='*', markersize=5, markeredgecolor='red', markerfacecolor='red')

    # Plot the barcodes
    ax4.cla()
    for birth, row, colour in barcodes:
        ax4.plot([birth, 1], [row, row], color=colour)
    ax4.axvline(x=a, color='black', linestyle='-')

    # Labels for all the plots (re-applied every frame because cla() resets them)
    # plot 1
    ax1.set_xlabel('x')
    ax1.set_ylabel('y')
    ax1.set_zlabel('z')
    ax1.set_title(r'$f(x, y) = \sin(\sqrt{x^2 + y^2})$')

    # plot 2
    ax2.set_title('Sublevel set filtration')

    # plot 3
    ax3.set_xlabel('Birth (height)')
    ax3.set_ylabel('Death')
    ax3.set_title('Persistence Diagram')
    # Add legend from the proxy handles; entries follow handle order, so H_0 is on top
    ax3.legend(handles=[star, dot], loc='lower right')

    # plot 4
    ax4.set_xlabel('Persistence')
    ax4.set_ylabel('Components')
    ax4.set_yticks([])
    ax4.set_title('Barcodes')

    # snapshot: update the figure and wait 0.1 s before the next frame
    plt.pause(.1)

# Show the plot
plt.tight_layout()
plt.show()

希望这对你有帮助,干杯!

相关问题