gpt4 book ai didi

python - Pyspark - 带过滤器的 groupby - 优化速度

转载 作者:行者123 更新时间:2023-12-01 06:57:59 26 4
gpt4 key购买 nike

我有数十亿行需要使用 Pyspark 处理。

数据框如下所示:

category    value    flag
A 10 1
A 12 0
B 15 0
and so on...

我需要运行两个 groupby 操作:一个针对 flag==1 的行,另一个针对所有行。目前我正在这样做:

frame_1 = df.filter(df.flag==1).groupBy('category').agg(F.sum('value').alias('foo1'))
frame_2 = df.groupBy('category').agg(F.sum('value').alias(foo2))
final_frame = frame1.join(frame2,on='category',how='left')

到目前为止,这段代码正在运行,但我的问题是它非常慢。有没有办法在速度方面提高此代码,或者这是限制,因为我了解 PySpark 的延迟评估确实需要一些时间,但此代码是执行此操作的最佳方法吗?

最佳答案

IIUC,您可以避免昂贵的连接并使用一个 groupBy 来实现这一点。

final_frame_2 = df.groupBy("category").agg(
F.sum(F.col("value")*F.col("flag")).alias("foo1"),
F.sum(F.col("value")).alias("foo2"),
)
final_frame_2.show()
#+--------+----+----+
#|category|foo1|foo2|
#+--------+----+----+
#| B| 0.0|15.0|
#| A|10.0|22.0|
#+--------+----+----+

现在比较执行计划:

首先是你的方法:

final_frame.explain()
#== Physical Plan ==
#*(5) Project [category#0, foo1#68, foo2#75]
#+- SortMergeJoin [category#0], [category#78], LeftOuter
# :- *(2) Sort [category#0 ASC NULLS FIRST], false, 0
# : +- *(2) HashAggregate(keys=[category#0], functions=[sum(cast(value#1 as double))])
# : +- Exchange hashpartitioning(category#0, 200)
# : +- *(1) HashAggregate(keys=[category#0], functions=[partial_sum(cast(value#1 as double))])
# : +- *(1) Project [category#0, value#1]
# : +- *(1) Filter (isnotnull(flag#2) && (cast(flag#2 as int) = 1))
# : +- Scan ExistingRDD[category#0,value#1,flag#2]
# +- *(4) Sort [category#78 ASC NULLS FIRST], false, 0
# +- *(4) HashAggregate(keys=[category#78], functions=[sum(cast(value#79 as double))])
# +- Exchange hashpartitioning(category#78, 200)
# +- *(3) HashAggregate(keys=[category#78], functions=[partial_sum(cast(value#79 as double))])
# +- *(3) Project [category#78, value#79]
# +- Scan ExistingRDD[category#78,value#79,flag#80]

现在 final_frame_2 也一样:

final_frame_2.explain()
#== Physical Plan ==
#*(2) HashAggregate(keys=[category#0], functions=[sum((cast(value#1 as double) * cast(flag#2 as double))), sum(cast(value#1 as double))])
#+- Exchange hashpartitioning(category#0, 200)
# +- *(1) HashAggregate(keys=[category#0], functions=[partial_sum((cast(value#1 as double) * cast(flag#2 as double))), partial_sum(cast(value#1 as double))])
# +- Scan ExistingRDD[category#0,value#1,flag#2]

注意:严格来说,这与您给出的示例完全相同的输出(如下所示),因为您的内部联接将消除所有不存在的类别带有 flag = 1 的行。

final_frame.show()
#+--------+----+----+
#|category|foo1|foo2|
#+--------+----+----+
#| A|10.0|22.0|
#+--------+----+----+

您可以向总和标志添加聚合,并根据需要过滤总和为零的聚合,而对性能的影响很小。

final_frame_3 = df.groupBy("category").agg(
F.sum(F.col("value")*F.col("flag")).alias("foo1"),
F.sum(F.col("value")).alias("foo2"),
F.sum(F.col("flag")).alias("foo3")
).where(F.col("foo3")!=0).drop("foo3")

final_frame_3.show()
#+--------+----+----+
#|category|foo1|foo2|
#+--------+----+----+
#| A|10.0|22.0|
#+--------+----+----+

关于python - Pyspark - 带过滤器的 groupby - 优化速度,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58725489/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com