matplotlib TSNE图很快消失

7cjasjjr  于 2022-11-15  发布在  其他
关注(0)|答案(1)|浏览(115)

我想在mnist数据集上使用t-SNE算法进行降维,稍后我想使用降维后的数据进行可视化(可能的聚类或分类),以下是我的代码:

`import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import  seaborn as sns
from sklearn.preprocessing import StandardScaler
df =pd.read_csv('mnist_train.csv')
y =df['label']
X =df.drop('label',axis=1)
standardized_X =StandardScaler().fit_transform(X)
data_1000 = standardized_X[0:1000, :]
labels_1000 = y[0:1000]
model =TSNE(n_components=2,random_state=1)
transformed =model.fit_transform(data_1000)
tsne =np.vstack((transformed.T,labels_1000)).T
tsne_df = pd.DataFrame(data = tsne,
     columns =("Dim_1", "Dim_2", "label"))
#print(tsne_df.head())
sns.FacetGrid(tsne_df,hue='label',height=6).map(plt.scatter,'Dim_1','Dim_2')
plt.legend()
plt.show(block=False)

但是当我运行这段代码的时候,plt的图形很快就消失了,我怎么才能让图形停止这样的动作呢?是内存问题还是我应该添加一些小的行来保持图形打开?谢谢

dsekswqp

dsekswqp1#

我做了几个实验,找到了解决办法:将块参数设置为true,结果如下

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import  seaborn as sns
from sklearn.preprocessing import StandardScaler
df =pd.read_csv('mnist_train.csv')
y =df['label']
X =df.drop('label',axis=1)
standardized_X =StandardScaler().fit_transform(X)
data_1000 = standardized_X[0:1000, :]
labels_1000 = y[0:1000]
model =TSNE(n_components=2,random_state=1)
transformed =model.fit_transform(data_1000)
tsne =np.vstack((transformed.T,labels_1000)).T
tsne_df = pd.DataFrame(data = tsne,
     columns =("Dim_1", "Dim_2", "label"))
#print(tsne_df.head())
#fig = plt.figure()
sns.FacetGrid(tsne_df,hue='label',height=6).map(plt.scatter,'Dim_1','Dim_2')
plt.legend()
plt.show(block=True)

enter image description here

相关问题