tensorflow Seaborn Heatmap -仅当值高于给定阈值时显示热图

ssm49v7z  于 2023-08-06  发布在  其他
关注(0)|答案(2)|浏览(125)

下面的python代码显示句子相似度,它使用通用句子编码器来实现相同的功能。

from absl import logging

import tensorflow as tf

import tensorflow_hub as hub
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import re
import seaborn as sns

module_url = "https://tfhub.dev/google/universal-sentence-encoder/4" 
model = hub.load(module_url)
print ("module %s loaded" % module_url)
def embed(input):
  return model(input)

def plot_similarity(labels, features, rotation):
  corr = np.inner(features, features)
  print(corr)
  sns.set(font_scale=2.4)
  plt.subplots(figsize=(40,30))
  g = sns.heatmap(
      corr,
      xticklabels=labels,
      yticklabels=labels,
      vmin=0,
      vmax=1,
      cmap="YlGnBu",linewidths=1.0)
  g.set_xticklabels(labels, rotation=rotation)
  g.set_title("Semantic Textual Similarity")

def run_and_plot(messages_):
  message_embeddings_ = embed(messages_)
  plot_similarity(messages_, message_embeddings_, 90)

messages = [
"I want to know my savings account balance",
"Show my bank balance",
"Show me my account",
"What is my bank balance",
"Please Show my bank balance"    

]

run_and_plot(messages)

字符串
输出显示为热图,如下所示,还打印值

我只想关注那些看起来很相似的句子,但是当前的热图显示了所有的值。
所以呢
1.有没有一种方法可以让我只查看范围大于0.6且小于0.999的值的热图?
1.是否可以打印位于给定范围内的匹配值对,即0.6和0.99?谢了,拉黑特

dxxyhpgq

dxxyhpgq1#

根据您的问题更新,这里是一个修订版本。显然,在网格中,不能删除单个单元格。但是我们可以大幅减少热图,只显示相关的值对。热图中存在的随机分散的显著值越多,这种效果就越不明显。

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from copy import copy
import seaborn as sns

#semi-random data generation 
labels = list("ABCDE")
np.random.seed(123)
df = pd.DataFrame(np.random.randint(1, 100, (20, 5)))
df.columns = labels
df.A = df.B - df.D
df.C = df.B + df.A
df.E = df.A + df.C

#your correlation array
corr = df.corr().to_numpy()
print(corr)

#conditions for filtering 0.6<=r<=0.9
val_min = 0.6
val_max = 0.99

#plotting starts here
sns.set(font_scale=2.4)
#two axis objects just for comparison purposes
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15,8))

#define the colormap with clipping values
my_cmap = copy(plt.cm.YlGnBu)
my_cmap.set_over("white")
my_cmap.set_under("white")

#ax1 - full set of conditions as in the initial version 
g1 = sns.heatmap(corr,
    xticklabels=labels,
    yticklabels=labels,
    vmin=val_min,
    vmax=val_max,
    cmap=my_cmap,
    linewidths=1.0,
    linecolor="grey",
    ax=ax1)

g1.set_title("Entire heatmap")

#ax2 - remove empty rows/columns
# use only lower triangle
corr = np.tril(corr)

#delete columns where all elements do not fulfill the conditions
ind_x,  = np.where(np.all(np.logical_or(corr<val_min, corr>val_max), axis=0))
corr = np.delete(corr, ind_x, 1)
#update x labels
map_labels_x = [item for i, item in enumerate(labels) if i not in ind_x]
    
#now the same for rows 
ind_y, = np.where(np.all(np.logical_or(corr<val_min, corr>val_max), axis=1))
corr = np.delete(corr, ind_y, 0)
#update y labels
map_labels_y = [item for i, item in enumerate(labels) if i not in ind_y]

#plot heatmap
g2 = sns.heatmap(corr,
    xticklabels=map_labels_x,
    yticklabels=map_labels_y,
    vmin=val_min,
    vmax=val_max,
    cmap=my_cmap,
    linewidths=1.0,
    linecolor="grey", ax=ax2)

g2.set_title("Reduced heatmap")

plt.show()

字符串
样品输出:
x1c 0d1x的数据
左图,原始方法显示热图的所有元素。对,只保留相关的配对。该问题(以及因此代码)排除了显著的负相关性,例如-0.95。如果不打算这样做,则应使用np.abs()

初始答案

我很惊讶还没有人提供一个独立的解决方案,所以这里有一个:

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from copy import copy
import seaborn as sns

labels = list("ABCDE")
#semi-random data
np.random.seed(123)
df = pd.DataFrame(np.random.randint(1, 100, (20, 5)))
df.columns = labels
df.A = df.B - df.D
df.E = df.A + df.C

corr = df.corr()
sns.set(font_scale=2.4)
plt.subplots(figsize=(10,8))

#define the cmap with clipping values
my_cmap = copy(plt.cm.YlGnBu)
my_cmap.set_over("white")
my_cmap.set_under("white")

g = sns.heatmap(corr,
    xticklabels=labels,
    yticklabels=labels,
    vmin=0.5,
    vmax=0.9,
    cmap=my_cmap,
    linewidths=1.0,
    linecolor="grey")

g.set_xticklabels(labels, rotation=60)
g.set_title("Important!")

plt.show()


样品输出:

waxmsbnn

waxmsbnn2#

所提供的代码作为@Mr.T here提出的概念的重新实现。然而,这个特定的实现不需要创建标签,因为它只对pandas dataframe对象的操作进行操作,与@Mr.T的解决方案相反,它主要涉及numpy数组对象的操作。

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# semi-random data generation
labels = list("ABCDE")
np.random.seed(123)
df = pd.DataFrame(np.random.randint(1, 100, (20, 5)))
df.columns = labels
df.A = df.B - df.D
df.C = df.B + df.A
df.E = df.A + df.C

val_min = 0.6
val_max = 0.999

# Calculate the correlation
corr = df.corr()

# Mask values that is not fall between min and max value
corr_selected = corr.mask(((corr < val_min) | (corr > val_max)), float("NaN"))

# Get the upper triangular matrix
# Use `tril` instead of `triu` if the lower triangular matrix is needed
# Use `np.bool_` instead of `np.bool` if you using NumPy >= 1.20
corr_selected = corr_selected.where(
    np.triu(np.ones(corr_selected.shape)).astype(np.bool_)
)

# Remove rows that contains only NaN
corr_selected = corr_selected.dropna(
    axis=0,
    how="all",
)

# Remove columns that contains only NaN
corr_selected = corr_selected.dropna(
    axis=1,
    how="all",
)

selected = sns.heatmap(
    corr_selected,
    xticklabels=1,
    yticklabels=1,
    vmin=val_min,
    vmax=val_max,
    linewidths=1.0,
    linecolor="grey",
    annot=True,
)

original = sns.heatmap(
    corr,
    xticklabels=1,
    yticklabels=1,
    vmin=val_min,
    vmax=val_max,
    linewidths=1.0,
    linecolor="grey",
    annot=True,
)

plt.show()

字符串

选中的热图

x1c 0d1x的数据

原始热图


相关问题