2 回答
TA贡献1884条经验 获得超4个赞
您不能将整列数据传递给 UDF,因为 Spark 引擎将计算和数据拆分到多个服务器/执行程序中。
如果您可以调整算法以处理列值的执行程序本地子集,则可以使用RDD.mapPartitions在完整的数据分区上执行单个函数。
或者,如果您知道您的列数据可以适合您的执行程序内存,您可以首先DataFrame.collect()列数据并使用SparkContext.broadcast()将其复制到所有执行程序并使用对UDF中广播变量的引用.
TA贡献1862条经验 获得超7个赞
设置:
>>> d = [{'id': 0, 'value': 1},{'id': 1, 'value': 3},{'id': 2, 'value': 9}]
>>> df0 = spark.createDataFrame(d)
>>> df0.show()
+---+-----+
| id|value|
+---+-----+
| 0| 1|
| 1| 3|
| 2| 9|
+---+-----+
第 1 步:使用collect_list()函数创建一个包含 column 中所有值的数组value,并将该数组作为列添加到初始数据帧中
>>> from pyspark.sql.functions import *
>>> arr = df0.agg(collect_list(df.value).alias('arr_column'))
>>> df1 = df0.crossJoin(arr)
>>> df1.show()
+---+-----+-------------+
| id|value| arr_column|
+---+-----+-------------+
| 0| 1| [1, 3, 9]|
| 1| 3| [1, 3, 9]|
| 2| 9| [1, 3, 9]|
+---+-----+-------------+
交叉连接本质上会将数组广播给所有执行程序,因此请注意要应用它的数据大小。(您可能还需要spark.sql.crossJoin.enabled=true在创建 Spark 上下文时显式设置,因为 Spark 不喜欢交叉连接正是出于这个原因。)
第 2 步:将您的fu函数注册为 Spark UDF
>>> from pyspark.sql.types import *
>>> fu_udf = udf(fu, ArrayType(IntegerType()))
Step3:使用这个UDF来增加数组元素
>>> df3 = df1.withColumn('sums_in_column',fu_udf(df1.value,df1.arr_column))
>>> df3.show()
+---+-----+-------------+--------------+
| id|value| arr_column|sums_in_column|
+---+-----+-------------+--------------+
| 0| 1| [1, 3, 9]| [2, 4, 10]|
| 1| 3| [1, 3, 9]| [4, 6, 12]|
| 2| 9| [1, 3, 9]| [10, 12, 18]|
+---+-----+-------------+--------------+
添加回答
举报