matplotlib 如何分隔图形中的最后一个子图线和最后一个子图列

b4wnujal  于 2023-04-06  发布在  其他
关注(0)|答案(2)|浏览(194)

使用matplotlib,我绘制了n x n个子图(在这个例子中n=8):

fig, ax = plt.subplots(8,8, figsize=(18,10), sharex='col', sharey='row')
fig.subplots_adjust(hspace=0, wspace=0)

我想把最后一行和最后一列分开,得到这样的结果:

我必须使用Gridspec(如在答案https://stackoverflow.com/a/54747473/7462275中)与for循环,因为n是一个参数,或者有更简单的解决方案吗?谢谢回答。

编辑1

在回答https://stackoverflow.com/a/54747473/7462275之后,我写了这段代码。

import matplotlib
n=7
m=5
outer_gs = matplotlib.gridspec.GridSpec(2,2, height_ratios=[n,1], width_ratios=[m,1], hspace=0.1, wspace=0.1)
inner_gs = matplotlib.gridspec.GridSpecFromSubplotSpec(n,m, subplot_spec=outer_gs[0,0], hspace=0, wspace=0)
line_gs = matplotlib.gridspec.GridSpecFromSubplotSpec(n,1, subplot_spec=outer_gs[0,1], hspace=0, wspace=0.1)
col_gs = matplotlib.gridspec.GridSpecFromSubplotSpec(1,m, subplot_spec=outer_gs[1,0], hspace=0.1, wspace=0)
ax_11=[]
ax_12=[]
ax_21=[]
ax_22=[]
fig = plt.figure(figsize=(18,10))
for i in range (n):
    ax_11.append([])
    for j in range(m):
        ax_11[i].append(fig.add_subplot(inner_gs[i,j]))
        
for i in range(n):
        ax_12.append(fig.add_subplot(line_gs[i,0]))
        
for i in range(m):
        ax_21.append(fig.add_subplot(col_gs[0,i]))

ax_22 = fig.add_subplot(outer_gs[1,1])

我得到了

但是,现在我必须删除内部轴标签,只保留外部轴标签。

编辑2

它使用以下代码:

import matplotlib
n=7
m=5
outer_gs = matplotlib.gridspec.GridSpec(2,2, height_ratios=[n,1], width_ratios=[m,1], hspace=0.1, wspace=0.1)
inner_gs = matplotlib.gridspec.GridSpecFromSubplotSpec(n,m, subplot_spec=outer_gs[0,0], hspace=0, wspace=0)
line_gs = matplotlib.gridspec.GridSpecFromSubplotSpec(n,1, subplot_spec=outer_gs[0,1], hspace=0, wspace=0.1)
col_gs = matplotlib.gridspec.GridSpecFromSubplotSpec(1,m, subplot_spec=outer_gs[1,0], hspace=0.1, wspace=0)
ax_11=[]
ax_12=[]
ax_21=[]
ax_22=[]
fig = plt.figure(figsize=(18,10))
for i in range (n):
    ax_11.append([])
    for j in range(m):
        ax_11[i].append(fig.add_subplot(inner_gs[i,j]))
        if not ax_11[i][j].get_subplotspec().is_first_col():
            ax_11[i][j].axes.yaxis.set_visible(False)
        ax_11[i][j].axes.xaxis.set_visible(False)
        
for i in range(n):
        ax_12.append(fig.add_subplot(line_gs[i,0]))
        ax_12[i].axes.yaxis.set_visible(False)
        ax_12[i].axes.xaxis.set_visible(False)
        
for i in range(m):
        ax_21.append(fig.add_subplot(col_gs[0,i]))
        if not ax_21[i].get_subplotspec().is_first_col():
            ax_21[i].axes.yaxis.set_visible(False)

ax_22 = fig.add_subplot(outer_gs[1,1])
ax_22.axes.yaxis.set_visible(False);

但是,我想知道这是否是获得这样的图的最简单的方法。我正在发现matplotlib,我不确定这段代码是否好,即使它给出了预期的结果。

jljoyd4f

jljoyd4f1#

另一种方法是继续使用plt.subplots()创建轴,但差距添加额外的列和行,可以在创建后删除。
请注意,此方法意味着您的轴将继续像以前一样共享,并且不需要为特定轴打开或关闭记号标签。
使用gridspec_kwheight_ratioswidth_ratios选项,我们可以控制这些间隙的高度和宽度。在下面的示例中,我将它们设置为常规轴尺寸的五分之一,但这很容易根据您的需要进行更改。
我们可以使用.remove()方法移除虚拟轴,然后它们不会妨碍我们的ax数组,我们可以使用numpy.delete移除这些切片。
例如:

import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots(9, 9, figsize=(18, 10), sharex='col', sharey='row',
                       gridspec_kw={
                           'height_ratios': (1, 1, 1, 1, 1, 1, 1, 0.2, 1), 
                           'width_ratios': (1, 1, 1, 1, 1, 1, 1, 0.2, 1),
                           'hspace': 0, 'wspace': 0
                           })

# remove 8th column
for dax in ax[:, 7]:
    dax.remove()
ax = np.delete(ax, 7, 1)

# remove 8th row
for dax in ax[7, :]:
    dax.remove()
ax = np.delete(ax, 7, 0)

plt.show()

py49o6xq

py49o6xq2#

基于此answerdocumentation,您可以使用tick_params更改刻度和刻度标签的外观。
我在代码中添加了一些if语句和tick_params。

import matplotlib.pyplot as plt
import matplotlib
n=7
m=5
outer_gs = matplotlib.gridspec.GridSpec(2,2, height_ratios=[n,1],
width_ratios=
[m,1], hspace=0.1, wspace=0.1)
inner_gs = matplotlib.gridspec.GridSpecFromSubplotSpec(n,m, 
subplot_spec=outer_gs[0,0], hspace=0, wspace=0)
line_gs = matplotlib.gridspec.GridSpecFromSubplotSpec(n,1, 
subplot_spec=outer_gs[0,1], hspace=0, wspace=0.1)
col_gs = matplotlib.gridspec.GridSpecFromSubplotSpec(1,m, 
subplot_spec=outer_gs[1,0], hspace=0.1, wspace=0)
ax_11=[]
ax_12=[]
ax_21=[]
ax_22=[]
fig = plt.figure(figsize=(18,10))
for i in range (n):
    ax_11.append([])
    for j in range(m):
        ax_11[i].append(fig.add_subplot(inner_gs[i,j]))
        plt.tick_params(labelbottom=False, length=0)
        if j!=0:
            plt.tick_params(labelleft=False, length=0)

for i in range(n):
        ax_12.append(fig.add_subplot(line_gs[i,0]))
        plt.tick_params(labelleft=False, length=0)
        plt.tick_params(labelbottom=False, length=0)

for i in range(m):
    ax_21.append(fig.add_subplot(col_gs[0,i]))
    if i!=0:
        plt.tick_params(labelleft=False, length=0)

ax_22 = fig.add_subplot(outer_gs[1,1])
plt.tick_params(labelleft=False, length=0)

plt.show()

下面是结果。

相关问题