gpt4 book ai didi

scala - Spark : Performant way to find top n values

转载 作者:行者123 更新时间:2023-12-02 18:32:30 25 4
gpt4 key购买 nike

我有一个大数据集,我想找到具有 n 个最高值的行。

id, count
id1, 10
id2, 15
id3, 5
...

我能想到的唯一方法是使用row_number 不带分区

val window = Window.orderBy(desc("count"))

df.withColumn("row_number", row_number over window).filter(col("row_number") <= n)

但是当数据包含数百万或数十亿行时,这绝不是高效的,因为它将数据插入一个分区,我得到了 OOM。

有没有人想出一个高效的解决方案?

最佳答案

我看到了两种提高算法性能的方法。首先是使用 sort limit 检索前 n 行。二是制定你的习惯 Aggregator .

排序和限制方法

您对数据框进行排序,然后取第一个 n行:

val n: Int = ???

import org.apache.spark.functions.sql.desc

df.orderBy(desc("count")).limit(n)

Spark 通过首先对每个分区执行排序来优化这种转换序列,首先取 n每个分区上的行,在最终分区上检索它并重新执行排序并首先获得最后一个 n行。您可以通过执行 explain() 来检查这一点关于转型。你得到以下执行计划:

== Physical Plan ==
TakeOrderedAndProject(limit=3, orderBy=[count#8 DESC NULLS LAST], output=[id#7,count#8])
+- LocalTableScan [id#7, count#8]

通过查看如何 TakeOrderedAndProject步骤在 limit.scala 中执行在 Spark 的源代码中(案例类 TakeOrderedAndProjectExec ,方法 doExecute )。

自定义聚合器方法

对于自定义聚合器,您创建一个 Aggregator这将填充和更新顶部 n 的有序数组行。

import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.Encoder

import scala.collection.mutable.ArrayBuffer

case class Record(id: String, count: Int)

case class TopRecords(limit: Int) extends Aggregator[Record, ArrayBuffer[Record], Seq[Record]] {

def zero: ArrayBuffer[Record] = ArrayBuffer.empty[Record]

def reduce(topRecords: ArrayBuffer[Record], currentRecord: Record): ArrayBuffer[Record] = {
val insertIndex = topRecords.lastIndexWhere(p => p.count > currentRecord.count)
if (topRecords.length < limit) {
topRecords.insert(insertIndex + 1, currentRecord)
} else if (insertIndex < limit - 1) {
topRecords.insert(insertIndex + 1, currentRecord)
topRecords.remove(topRecords.length - 1)
}
topRecords
}

def merge(topRecords1: ArrayBuffer[Record], topRecords2: ArrayBuffer[Record]): ArrayBuffer[Record] = {
val merged = ArrayBuffer.empty[Record]
while (merged.length < limit && (topRecords1.nonEmpty || topRecords2.nonEmpty)) {
if (topRecords1.isEmpty) {
merged.append(topRecords2.remove(0))
} else if (topRecords2.isEmpty) {
merged.append(topRecords1.remove(0))
} else if (topRecords2.head.count < topRecords1.head.count) {
merged.append(topRecords1.remove(0))
} else {
merged.append(topRecords2.remove(0))
}
}
merged
}

def finish(reduction: ArrayBuffer[Record]): Seq[Record] = reduction

def bufferEncoder: Encoder[ArrayBuffer[Record]] = ExpressionEncoder[ArrayBuffer[Record]]

def outputEncoder: Encoder[Seq[Record]] = ExpressionEncoder[Seq[Record]]

}

然后您将此聚合器应用于您的数据框,并展平聚合结果:

val n: Int = ???

import sparkSession.implicits._

df.as[Record].select(TopRecords(n).toColumn).flatMap(record => record)

方法比较

为了比较这两种方法,假设我们要获取 top n分布在 p 上的数据框行分区,每个分区大约有 k记录。所以数据框的大小为 p·k .这给出了以下复杂性(可能会出错):

<表类="s-表"><头>方法操作总数内存消耗
(在执行器上)内存消耗
(在最终执行者上)<正文> 当前代码 O(p·k·log(p·k)) -- O(p·k) 排序和限制 O(p·k·log(k) + p·n·log(p·n)) O(k) O(p·n) 自定义聚合器 O(p·k) O(k) + O(n) O(p·n)

因此,就操作数量而言,Custom Aggregator 的性能最高。但是,此方法是迄今为止最复杂的方法并且意味着大量序列化/反序列化,因此在某些情况下它的性能可能不如排序和限制

结论

你有两种方法可以有效地取得 top n行、排序和限制 以及自定义聚合器。要选择使用哪一个,您应该使用您的真实数据框对这两种方法进行基准测试。如果在基准测试后 Sort and LimitCustom aggregator 慢一点,我会选择 Sort and Limit 因为它的代码更容易维护.

关于scala - Spark : Performant way to find top n values,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/69256388/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com