d3.js: How can I visualize attention vectors as a text heatmap?

Posted by hsvhsicv on 2022-11-12 in Other

I'm working on an NLP research project and want to visualize the output of attention vectors.
For example, the data looks like this:

def sample_data():

    sent = '''the USS Ronald Reagan - an aircraft carrier docked in Japan - during his tour of the region, vowing to "defeat any attack and meet any use of conventional or nuclear weapons with an overwhelming and effective American response".'''

    words    = sent.split()   # simple whitespace tokenization
    word_num = len(words)
    # dummy scores rising linearly from 100/word_num up to 100
    attention = [(x+1.)/word_num*100 for x in range(word_num)]

    return {'text': words, 'attention': attention}

which returns:

{'text': ['the', 'USS', 'Ronald', 'Reagan', '-', 'an', 'aircraft', 'carrier', 'docked', 'in', 'Japan', '-', 'during', 'his', 'tour', 'of', 'the', 'region,', 'vowing', 'to', '"defeat', 'any', 'attack', 'and', 'meet', 'any', 'use', 'of', 'conventional', 'or', 'nuclear', 'weapons', 'with', 'an', 'overwhelming', 'and', 'effective', 'American', 'response".'], 'attention': [2.564102564102564, 5.128205128205128, 7.6923076923076925, 10.256410256410255, 12.82051282051282, 15.384615384615385, 17.94871794871795, 20.51282051282051, 23.076923076923077, 25.64102564102564, 28.205128205128204, 30.76923076923077, 33.33333333333333, 35.8974358974359, 38.46153846153847, 41.02564102564102, 43.58974358974359, 46.15384615384615, 48.717948717948715, 51.28205128205128, 53.84615384615385, 56.41025641025641, 58.97435897435898, 61.53846153846154, 64.1025641025641, 66.66666666666666, 69.23076923076923, 71.7948717948718, 74.35897435897436, 76.92307692307693, 79.48717948717949, 82.05128205128204, 84.61538461538461, 87.17948717948718, 89.74358974358975, 92.3076923076923, 94.87179487179486, 97.43589743589743, 100.0]}

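Each token is assigned a float value (its attention score). What options are there for visualizing this data? Are there any libraries or tools, in any language (R, Python, or JS), that can do this?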

Answer #1 (eimct9ow):

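One solution that copes with long sentences is to print the sentence in color in the console. You can do this by printing ANSI escape sequences: \033[38;2;255;0;0m test \033[0m prints test in red (RGB code (255, 0, 0)) in the console.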
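To make the structure of that escape sequence explicit, here is a minimal sketch that assembles it from its parts. The helper name `ansi_rgb` is hypothetical. `38;2;R;G;B` selects a 24-bit foreground color, `48;2;R;G;B` selects the background instead (closer to the look of a heatmap cell), and `\033[0m` resets all attributes. This assumes a terminal with 24-bit ("truecolor") support; other terminals may approximate or ignore the colors.

```python
ESC = "\033"  # the escape character that starts every ANSI sequence


def ansi_rgb(text, rgb, background=False):
    """Wrap `text` in a 24-bit ANSI color sequence (hypothetical helper).

    38;2;R;G;B sets the foreground color, 48;2;R;G;B the background;
    ESC[0m resets all attributes afterwards.
    """
    r, g, b = (int(round(c)) for c in rgb)
    mode = 48 if background else 38
    return f"{ESC}[{mode};2;{r};{g};{b}m{text}{ESC}[0m"


# Red text, then text on a yellow background
print(ansi_rgb("test", (255, 0, 0)))
print(ansi_rgb("test", (255, 255, 0), background=True))
```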
Using this idea, we can build a gradient from green to red (low to high attention) and print the text:

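import numpy as np

data = sample_data()

def colorFader(c1, c2, mix=0):
    # Linear interpolation between two RGB colors: mix=0 -> c1, mix=1 -> c2
    return (1-mix)*np.array(c1) + mix*np.array(c2)

def colored(c, text):
    # Wrap text in a 24-bit ANSI foreground color code, then reset with \033[0m
    return "\033[38;2;{};{};{}m{} \033[0m".format(int(c[0]), int(c[1]), int(c[2]), text)

# Scale scores into [0, 1] by dividing by the largest one
normalizer = max(data["attention"])
output = ""
for word, attention in zip(data["text"], data["attention"]):
    # Green for low attention, red for high attention
    color = colorFader([0, 255, 0], [255, 0, 0], mix=attention/normalizer)
    output += colored(color, word)

print(output)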
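`colorFader` is plain per-channel linear interpolation, `(1 - mix) * c1 + mix * c2`. Dividing only by the maximum means the smallest sample score (about 2.56) still gets mix ≈ 0.026, so it comes out almost pure green. That works for this sample, but if all scores were large and close together (say 90 to 100), every word would look nearly red. Below is a stdlib-only sketch that applies min-max normalization instead; this is my own variation rather than what the answer uses, and `lerp_color` and `min_max` are hypothetical names.

```python
def lerp_color(c1, c2, mix):
    """Linearly interpolate two RGB tuples; mix=0 gives c1, mix=1 gives c2."""
    return tuple((1 - mix) * a + mix * b for a, b in zip(c1, c2))


def min_max(scores):
    """Rescale scores to [0, 1]; a constant list maps to all zeros."""
    lo, hi = min(scores), max(scores)
    span = hi - lo
    return [(s - lo) / span if span else 0.0 for s in scores]


GREEN, RED = (0, 255, 0), (255, 0, 0)

scores = [90.0, 95.0, 100.0]
for s, m in zip(scores, min_max(scores)):
    # Truncate channels to ints, as the answer's colored() helper does
    r, g, b = (int(c) for c in lerp_color(GREEN, RED, m))
    print(s, (r, g, b))
```

This prints `90.0 (0, 255, 0)`, `95.0 (127, 127, 0)` and `100.0 (255, 0, 0)`: the lowest score is now pure green and the highest pure red, whatever the absolute range.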

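In the console, this prints the sentence with each word colored on a gradient from green (low attention) to red (high attention).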

I find this an effective visualization, but the fact that it only works in a console may be a drawback.
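If the console is the limitation, the same per-word coloring can be emitted as HTML and opened in a browser or shown in a Jupyter notebook. A minimal stdlib-only sketch, assuming a white-to-red background scale with min-max normalization; `to_html_heatmap` is a hypothetical name, and its parameters are named `text` and `attention` to match the keys returned by `sample_data()`:

```python
import html


def to_html_heatmap(text, attention):
    """Render words as <span>s whose background gets redder with higher attention."""
    lo, hi = min(attention), max(attention)
    span = (hi - lo) or 1.0  # avoid division by zero for constant scores
    parts = []
    for word, score in zip(text, attention):
        mix = (score - lo) / span            # 0 = lowest score, 1 = highest
        fade = int(round(255 * (1 - mix)))   # white (255,255,255) -> red (255,0,0)
        parts.append(
            f'<span style="background-color: rgb(255, {fade}, {fade})">'
            f"{html.escape(word)}</span>"    # escape <, >, &, quotes in tokens
        )
    return "<p>" + " ".join(parts) + "</p>"


print(to_html_heatmap(["the", "USS", "Ronald"], [1.0, 2.0, 3.0]))
```

With the question's data this could be called as `to_html_heatmap(**sample_data())`; the resulting string can be written to an `.html` file, or displayed inline in Jupyter with `IPython.display.HTML`.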
Another approach is to plot a heatmap:

import matplotlib.pyplot as plt

data = sample_data()

# Create a pyplot figure

fig, ax = plt.subplots(1, 1)

# Creating the heatmap image with the <plasma> colormap

img = ax.imshow([data["attention"]], cmap='plasma', aspect='auto', extent=[-1,1,-1,1])

# Setting the x_ticks position to be in the middle of the corresponding color

ax.set_xticks([-1 + (i+0.5)*2/len(data["text"]) for i in range(len(data["text"]))])

# Use the words as x tick labels, rotated 80° to save space

ax.set_xticklabels(data["text"], rotation=80)

# Display the heatmap

plt.show()

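This produces a single-row heatmap with one colored cell per word, labeled underneath. From there you can adjust parameters such as height, width, colormap, etc.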
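The tick positions in the code follow from `extent=[-1,1,-1,1]`: the image spans x ∈ [-1, 1], so each of the n words gets a cell of width 2/n, and the i-th cell (counting from 0) is centered at

```latex
x_i = -1 + \left(i + \tfrac{1}{2}\right)\frac{2}{n}, \qquad i = 0, \dots, n-1.
```

For the 39-word sample, the ticks therefore run from -1 + 1/39 ≈ -0.974 to 1 - 1/39 ≈ 0.974. Without the `extent` argument, `imshow` centers cell i at x = i, so `ax.set_xticks(range(n))` would place the same ticks more simply.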

If your sentence is long, this may not be the best solution, because the tick labels become hard to read.
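One possible workaround for long sentences (an assumption on my part, not part of the answer) is to wrap the sentence into several rows of a 2-D grid, padding the last row with NaN, and label each cell with text instead of relying on tick labels. A stdlib-only sketch of the reshaping step; `wrap_rows` is a hypothetical helper:

```python
import math


def wrap_rows(text, attention, width):
    """Split parallel word/score lists into rows of `width` cells.

    The last row is padded with "" labels and NaN scores so every row has
    the same length, which is what a 2-D heatmap grid needs.
    """
    n_rows = math.ceil(len(text) / width)
    pad = n_rows * width - len(text)
    words = list(text) + [""] * pad
    scores = list(attention) + [math.nan] * pad
    label_rows = [words[r * width:(r + 1) * width] for r in range(n_rows)]
    score_rows = [scores[r * width:(r + 1) * width] for r in range(n_rows)]
    return label_rows, score_rows


labels, grid = wrap_rows(["a", "b", "c", "d", "e"], [1, 2, 3, 4, 5], width=2)
print(labels)  # [['a', 'b'], ['c', 'd'], ['e', '']]
```

The grid could then be drawn with `ax.imshow(grid, cmap='plasma')` and each word placed with `ax.text(col, row, labels[row][col], ha='center', va='center')`; matplotlib masks NaN cells and draws them in the colormap's "bad" color, which is transparent by default.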
