假设我有一个Py spark Dataframe ,如下所示:
b1 b2 b3 b4 b5 b6
1 1 1 0 1 1
test_df = spark.createDataFrame([
(1,1,1,0,1,1)
], ("b1", "b2","b3","b4","b5","b6"))
这里,最长的连续1的长度是3,我有一个1M行的大数据集,对于每一行,我想这样计算。
所以,我最初的想法是创建一个新的列,并将列中的每个值连接起来。所以,我遵循了这样的方法。首先,我将所有列的值连接成一个新值:
test_df = test_df.withColumn('Ind', F.concat(*cols))
我得到了这样的 Dataframe :
b1 b2 b3 b4 b5 b6 ind
1 1 1 0 1 1 '111011'
然后创建一个单独的UDF:
def findMaxConsecutiveOnes(X) -> int:
nums = [int(j) for a,j in enumerate(X)]
count = 0
maxCount = 0
for idx, num in enumerate(nums):
if num == 1:
count += 1
if num == 0 or idx == len(nums) - 1:
maxCount = max(maxCount, count)
count = 0
return maxCount
然后创建了一个UDF:
maxcon_udf = udf(lambda x: findMaxConsecutiveOnes(x))
最后,
test_df = test_df.withColumn('final', maxcon_udf('ind'))
但是,这显示错误。有人能帮我解决这个问题吗?
2条答案
按热度按时间bqf10yzr1#
在添加
cols=test_df.columns
并相应地缩进findMaxConsecutiveOnes
函数之后,您的代码可以正常工作。然而,我建议尽可能避免在pyspark中使用UDF,因为在spark中执行python代码的开销相当大。
您可以通过以下方式解决Spark函数的问题:
1.将
ind
拆分到非1的任意值上1.计算每个结果子字符串的长度
1.得到最大值。
hi3rlvi22#
你的代码工作得很好,也许你忘了用test_df.columns替换cols: