python 删除具有缺失值的大型Pyspark Dataframe 的高度相关列

egdjgwm8  于 2023-03-07  发布在  Python
关注(0)|答案(1)|浏览(147)

我有一个包含600万行和2k列的庞大 Dataframe 。我想删除高度相关的列,其中许多列是超级稀疏的(90%以上的缺失值)。不幸的是,Pyspark Correlation不处理缺失值,AFAIK。这就是为什么我必须循环遍历列并计算相关性。
下面是重现它的小代码:

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("test").getOrCreate()
l = [
    (7, -5, -8, None, 1, 456, 8),
    (2, 9, 7, 4, None, 9, -1),
    (-3, 3, None, 6, 0, 11, 9),
    (4, -1, 6, 7, 82, 99, 54),
]
names = ["colA", "colB", "colC", "colD", "colE", "colF", "colG"]
db = spark.createDataFrame(l, names)
db.show()

#+----+----+----+----+----+----+----+
#|colA|colB|colC|colD|colE|colF|colG|
#+----+----+----+----+----+----+----+
#|   7|  -5|  -8|null|   1| 456|   8|
#|   2|   9|   7|   4|null|   9|  -1|
#|  -3|   3|null|   6|   0|  11|   9|
#|   4|  -1|   6|   7|  82|  99|  54|
#+----+----+----+----+----+----+----+
from pyspark.ml.feature import VectorAssembler

newdb = (
    VectorAssembler(inputCols=db.columns, outputCol="features")
    .setHandleInvalid("keep")
    .transform(db)
)
newdb.show()

#+----+----+----+----+----+----+----+--------------------+
#|colA|colB|colC|colD|colE|colF|colG|            features|
#+----+----+----+----+----+----+----+--------------------+
#|   7|  -5|  -8|null|   1| 456|   8|[7.0,-5.0,-8.0,Na...|
#|   2|   9|   7|   4|null|   9|  -1|[2.0,9.0,7.0,4.0,...|
#|  -3|   3|null|   6|   0|  11|   9|[-3.0,3.0,NaN,6.0...|
#|   4|  -1|   6|   7|  82|  99|  54|[4.0,-1.0,6.0,7.0...|
#+----+----+----+----+----+----+----+--------------------+

相关函数无法处理缺失值。

from pyspark.ml.stat import Correlation

Correlation.corr(
    dataset=newdb.select("features"), column="features", method="pearson"
).collect()[0]["pearson(features)"].values

# array([ 1.        , -0.59756161,         nan,         nan,         nan,
#        0.79751788,  0.21792969, -0.59756161,  1.        ,         nan,
#               nan,         nan, -0.82202347, -0.40825556,         nan,
#               nan,  1.        ,         nan,         nan,         nan,
#               nan,         nan,         nan,         nan,  1.        ,
#               nan,         nan,         nan,         nan,         nan,
#               nan,         nan,  1.        ,         nan,         nan,
#        0.79751788, -0.82202347,         nan,         nan,         nan,
#        1.        , -0.06207047,  0.21792969, -0.40825556,         nan,
#               nan,         nan, -0.06207047,  1.        ])

我使用了一个for循环,但这个循环不适用于我的大数据:

import numpy as np
from pyspark.mllib.linalg import Vectors
from pyspark.mllib.stat import Statistics

df_vector = newdb
num_cols = 7
res = np.ones((num_cols, num_cols), dtype=np.float32)
for i in range(1, num_cols):
    for j in range(i):
        feature_pair_df = df_vector.select("features").rdd.map(
            lambda x: Vectors.dense([x[0][i], x[0][j]])
        )
        feature_pair_df = feature_pair_df.filter(
            lambda x: not np.isnan(x[0]) and not np.isnan(x[1])
        )
        corr_matrix = Statistics.corr(feature_pair_df, method="spearman")
        corr = corr_matrix[0, 1]
        res[i, j], res[j, i] = corr, corr
res

#array([[ 1. , -0.8, -1. ,  0.5,  0.5,  0.8,  0. ],
#       [-0.8,  1. ,  1. , -1. , -0.5, -1. , -0.4],
#       [-1. ,  1. ,  1. , -1. ,  1. , -1. , -0.5],
#       [ 0.5, -1. , -1. ,  1. ,  1. ,  1. ,  1. ],
#       [ 0.5, -0.5,  1. ,  1. ,  1. ,  0.5,  0.5],
#       [ 0.8, -1. , -1. ,  1. ,  0.5,  1. ,  0.4],
#       [ 0. , -0.4, -0.5,  1. ,  0.5,  0.4,  1. ]], dtype=float32)

我怎样写它,以便我可以找到一个大数据集的相关矩阵?Map而不是循环或任何类似的想法。

6vl6ewon

6vl6ewon1#

Python3

from pyspark.sql.functions import col

# Compute the percentage of missing values for each column
missing_values = newdb.select([
    (1 - (count(c) / count("*"))).alias(c+"_missing") 
    for c in newdb.columns
])

# Get the list of columns with more than 90% missing values
sparse_columns = missing_values.columns[1:] # exclude the "features" column
sparse_columns = [c for c in sparse_columns if missing_values.select(c).collect()[0][0] >= 0.9]
# Drop the columns with more than 90% missing values
newdb_filtered = newdb.drop(*sparse_columns)

这段代码计算每列缺失值的百分比,并创建一个新的 Dataframe missing_values,其中每个原始列都有一个新列,并附加了_missing后缀。然后,它选择缺失值超过90%的列,并将其放入名为sparse_columns的列表中。一旦获得稀疏列列表,就可以使用drop方法将其从 Dataframe 中删除。

相关问题