matplotlib 如何在add_subplot中设置列和行?

y0u0uwnf  于 2022-11-15  发布在  其他
关注(0)|答案(1)|浏览(141)

我试图生成16个子图,我的最终目标是有一个8 x 2大小的最终图,我的代码看起来像这样:

def visualize_t2t(token_dict, scores):

  fig = plt.figure(figsize=(50, 50))

  for idx, scores in enumerate(scores):
      scores_np = np.array(scores)

      ax = fig.add_subplot(12, 12, idx+1)
      # append the attention weights
      im = ax.imshow(scores, cmap='viridis')

      fontdict = {'fontsize': 3}

      ax.set_xticks(range(len(all_tokens)))
      ax.set_yticks(range(len(all_tokens)))
      ax.set_xticklabels(all_tokens, fontdict=fontdict, rotation=90)
      ax.set_yticklabels(all_tokens, fontdict=fontdict)
      ax.set_xlabel('{} {}'.format('label_name', idx+1))

      fig.colorbar(im, fraction=0.046, pad=0.04)
  
  plt.tight_layout()
  name_f = str(uuid.uuid4())
  plt.savefig(f'{name_f}.pdf',
  bbox_inches='tight',
    dpi=350)

输入数据

all_tokens = ['[CLS]',
 'what',
 'type',
 'of',
 'heart',
 'issue',
 'does',
 'the',
 'Person',
 'have',
 '[CLS]']

dummy_input = np.random.uniform(-1, 1, [16, len(all_tokens), len(all_tokens)])
visualize_t2t(all_tokens, dummy_input)

但结果看起来是这样的:

如何设置行和列,使一行中有8个子图,而另一行中有8个子图?

z6psavjg

z6psavjg1#

只需将ax = fig.add_subplot(12, 12, idx+1)替换为ax = fig.add_subplot(2, 8, idx+1)即可。

相关问题