gpt4 book ai didi

scala - Spark Collect_list 并限制结果列表

转载 作者:行者123 更新时间:2023-12-02 18:16:56 25 4
gpt4 key购买 nike

我有以下格式的数据框:

name          merged
key1 (internalKey1, value1)
key1 (internalKey2, value2)
...
key2 (internalKey3, value3)
...

我想要做的是按名称对数据框进行分组,收集列表并限制列表的大小。

这就是我按名称分组并收集列表的方式:

val res = df.groupBy("name")
.agg(collect_list(col("merged")).as("final"))

结果数据框类似于:

 key1   [(internalKey1, value1), (internalKey2, value2),...] // Limit the size of this list 
key2 [(internalKey3, value3),...]

我想要做的是限制每个键生成的列表的大小。我尝试了多种方法来做到这一点,但没有成功。我已经看到一些建议第三方解决方案的帖子,但我想避免这种情况。有办法吗?

最佳答案

因此,虽然 UDF 可以满足您的需要,但如果您正在寻找一种性能更高且对内存敏感的方法,则可以编写 UDAF。不幸的是,UDAF API 实际上不如 Spark 附带的聚合函数那么可扩展。但是,您可以使用其内部 API 来构建内部函数来完成您需要的操作。

这是一个 collect_list_limit 的实现,它主要是 Spark 内部 CollectList AggregateFunction 的复制。我只想扩展它,但它是一个案例类。实际上,所需要的只是重写更新和合并方法以尊重传入的限制:

case class CollectListLimit(
child: Expression,
limitExp: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0) extends Collect[mutable.ArrayBuffer[Any]] {

val limit = limitExp.eval( null ).asInstanceOf[Int]

def this(child: Expression, limit: Expression) = this(child, limit, 0, 0)

override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)

override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
copy(inputAggBufferOffset = newInputAggBufferOffset)

override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty

override def update(buffer: mutable.ArrayBuffer[Any], input: InternalRow): mutable.ArrayBuffer[Any] = {
if( buffer.size < limit ) super.update(buffer, input)
else buffer
}

override def merge(buffer: mutable.ArrayBuffer[Any], other: mutable.ArrayBuffer[Any]): mutable.ArrayBuffer[Any] = {
if( buffer.size >= limit ) buffer
else if( other.size >= limit ) other
else ( buffer ++= other ).take( limit )
}

override def prettyName: String = "collect_list_limit"
}

要实际注册它,我们可以通过 Spark 的内部 FunctionRegistry 来完成,它接收名称和构建器,该构建器实际上是一个使用以下函数创建 CollectListLimit 的函数:提供的表达式:

val collectListBuilder = (args: Seq[Expression]) => CollectListLimit( args( 0 ), args( 1 ) )
FunctionRegistry.builtin.registerFunction( "collect_list_limit", collectListBuilder )

编辑:

事实证明,仅当您尚未创建 SparkContext 时才将其添加到内置中才有效,因为它会在启动时创建不可变的克隆。如果您有现有的上下文,那么这应该可以通过反射添加它:

val field = classOf[SessionCatalog].getFields.find( _.getName.endsWith( "functionRegistry" ) ).get
field.setAccessible( true )
val inUseRegistry = field.get( SparkSession.builder.getOrCreate.sessionState.catalog ).asInstanceOf[FunctionRegistry]
inUseRegistry.registerFunction( "collect_list_limit", collectListBuilder )

关于scala - Spark Collect_list 并限制结果列表,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52467555/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com