matplotlib Python中的PLS-DA加载图

qqrboqgw  于 2022-11-15  发布在  Python
关注(0)|答案(1)|浏览(114)

如何使用Matplotlib制作PLS-DA图的载荷图,就像PCA的载荷图一样?
此答案解释了如何使用PCA完成此操作:绘制PCA载荷和sklearn中双标图的载荷(类似于R的autoplot)
然而,这两种方法之间存在一些显著的差异,这使得实现方式也有所不同。(此处解释了一些相关差异https://learnche.org/pid/latent-variable-modelling/projection-to-latent-structures/interpreting-pls-scores-and-loadings
要绘制PLS-DA图,我使用以下代码:

from sklearn.preprocessing import StandardScaler
from sklearn.cross_decomposition import PLSRegression
import numpy as np
import pandas as pd

targets = [0, 1]

x_vals = StandardScaler().fit_transform(df.values)

y = [g == targets[0] for g in sample_description]
y = np.array(y, dtype=int)

plsr = PLSRegression(n_components=2, scale=False)
plsr.fit(x_vals, y)

colormap = {
    targets[0]: '#ff0000',  # Red
    targets[1]: '#0000ff',  # Blue
}

colorlist = [colormap[c] for c in sample_description]

scores = pd.DataFrame(plsr.x_scores_)
scores.index = x.index

x_loadings = plsr.x_loadings_
y_loadings = plsr.y_loadings_

fig1, ax = get_default_fig_ax('Scores on LV 1', 'Scores on LV 2', title)
ax = scores.plot(x=0, y=1, kind='scatter', s=50, alpha=0.7,
                 c=colorlist, ax=ax)
lf5gs5x2

lf5gs5x21#

我对你的代码进行了改进。双标图是通过简单地叠加分数和加载图来获得的。其他更严格的图可以根据https://blogs.sas.com/content/iml/2019/11/06/what-are-biplots.html#:~:text= A%20biplot%20is%20an%20overlay,them%20on%20a%20single%20plot用真正共享的轴来制作。
下面的代码为包含约200个要素的数据集生成此图像(因此显示了约200个红色箭头):

from sklearn.cross_decomposition import PLSRegression
pls2 = PLSRegression(n_components=2)
pls2.fit(X_train, Y_train)

x_loadings = pls2.x_loadings_
y_loadings = pls2.y_loadings_

fig, ax = plt.subplots(constrained_layout=True)

scores = pd.DataFrame(pls2.x_scores_)
scores.plot(x=0, y=1, kind='scatter', s=50, alpha=0.7,
                 c=Y_train.values[:,0], ax = ax)

newax = fig.add_axes(ax.get_position(), frameon=False)
feature_n=x_loadings.shape[0]
print(x_loadings.shape)
for feature_i in range(feature_n):
    comp_1_idx=0
    comp_2_idx=1
    newax.arrow(0, 0, x_loadings[feature_i,comp_1_idx], x_loadings[feature_i,comp_2_idx],color = 'r',alpha = 0.5)
newax.get_xaxis().set_visible(False)
newax.get_yaxis().set_visible(False)

plt.show()

相关问题