
scala - Spark UDAF : How to get value from input by column field name in UDAF (User-Defined Aggregation Function)?


I am trying to use a Spark UDAF to summarize two existing columns into one new column. Most tutorials on Spark UDAFs use an index to get the value from each column of the input row, like this:

input.getAs[String](1)

which is used in my update method (override def update(buffer: MutableAggregationBuffer, input: Row): Unit). It also works in my case. However, I would like to get the value by the column's field name, like this:

input.getAs[String](ColumnNames.BehaviorType)

where ColumnNames.BehaviorType is a String defined in an object:

/**
 * Column names in the original dataset
 */
object ColumnNames {
  val JobSeekerID = "JobSeekerID"
  val JobID = "JobID"
  val Date = "Date"
  val BehaviorType = "BehaviorType"
}

This time it does not work. I get the following exception:

java.lang.IllegalArgumentException: Field "BehaviorType" does not exist.
  at org.apache.spark.sql.types.StructType$$anonfun$fieldIndex$1.apply(StructType.scala:292)
  ...
  at org.apache.spark.sql.Row$class.getAs(Row.scala:333)
  at org.apache.spark.sql.catalyst.expressions.GenericRow.getAs(rows.scala:165)
  at com.recsys.UserBehaviorRecordsUDAF.update(UserBehaviorRecordsUDAF.scala:44)

The relevant code snippets:

This is part of my UDAF:

class UserBehaviorRecordsUDAF extends UserDefinedAggregateFunction {

  override def inputSchema: StructType = StructType(
    StructField("JobID", IntegerType) ::
    StructField("BehaviorType", StringType) :: Nil)

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    println("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX")
    println(input.schema.treeString)
    println
    println(input.mkString(","))
    println
    println(this.inputSchema.treeString)
    // println
    // println(bufferSchema.treeString)

    input.getAs[String](ColumnNames.BehaviorType) match { // works with index 1, not with ColumnNames.BehaviorType //TODO WHY??
      case BehaviourTypes.viewed_job =>
        buffer(0) =
          buffer.getAs[Seq[Int]](0) :+ // Array[Int] //TODO WHY??
          input.getAs[Int](0)          // ColumnNames.JobID
      case BehaviourTypes.bookmarked_job =>
        buffer(1) =
          buffer.getAs[Seq[Int]](1) :+ // Array[Int]
          input.getAs[Int](0)          // ColumnNames.JobID
      case BehaviourTypes.applied_job =>
        buffer(2) =
          buffer.getAs[Seq[Int]](2) :+ // Array[Int]
          input.getAs[Int](0)          // ColumnNames.JobID
    }
  }

Here is part of the code that calls the UDAF:

val ubrUDAF = new UserBehaviorRecordsUDAF

val userProfileDF = userBehaviorDS
  .groupBy(ColumnNames.JobSeekerID)
  .agg(
    ubrUDAF(
      userBehaviorDS.col(ColumnNames.JobID),
      userBehaviorDS.col(ColumnNames.BehaviorType)
    ).as("profile str"))

It seems the field names from the input rows' schema are not passed into the UDAF:

XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
root
|-- input0: integer (nullable = true)
|-- input1: string (nullable = true)


30917,viewed_job

root
|-- JobID: integer (nullable = true)
|-- BehaviorType: string (nullable = true)

What is wrong with my code?

Best Answer

I also wanted to use the field names from my inputSchema in my update method, to keep the code maintainable:

import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema

class MyUDAF extends UserDefinedAggregateFunction {
  def update(buffer: MutableAggregationBuffer, input: Row) = {
    // Re-wrap the incoming row with the UDAF's own inputSchema so that
    // getAs[T](fieldName) lookups work inside update.
    val inputWSchema = new GenericRowWithSchema(input.toSeq.toArray, inputSchema)
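
Plugged into the question's UDAF, the whole class might look roughly like the sketch below. This is only a minimal illustration, not the answerer's full code: the question does not show bufferSchema, dataType, initialize, merge or evaluate, so those members are filled in here as plausible placeholders (three Seq[Int] buffers of job IDs and a string result), and BehaviourTypes is assumed to be an object holding the three behaviour constants used in the question.

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class UserBehaviorRecordsUDAF extends UserDefinedAggregateFunction {

  override def inputSchema: StructType = StructType(
    StructField(ColumnNames.JobID, IntegerType) ::
    StructField(ColumnNames.BehaviorType, StringType) :: Nil)

  // Assumed buffer layout: one Seq[Int] of job IDs per behaviour type.
  override def bufferSchema: StructType = StructType(
    StructField("viewed", ArrayType(IntegerType)) ::
    StructField("bookmarked", ArrayType(IntegerType)) ::
    StructField("applied", ArrayType(IntegerType)) :: Nil)

  override def dataType: DataType = StringType // the "profile str" column
  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Seq.empty[Int]
    buffer(1) = Seq.empty[Int]
    buffer(2) = Seq.empty[Int]
  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // Attach this UDAF's inputSchema to the schemaless input row so that
    // values can be fetched by field name instead of by position.
    val row = new GenericRowWithSchema(input.toSeq.toArray, inputSchema)

    val jobId = row.getAs[Int](ColumnNames.JobID)
    row.getAs[String](ColumnNames.BehaviorType) match {
      case BehaviourTypes.viewed_job     => buffer(0) = buffer.getAs[Seq[Int]](0) :+ jobId
      case BehaviourTypes.bookmarked_job => buffer(1) = buffer.getAs[Seq[Int]](1) :+ jobId
      case BehaviourTypes.applied_job    => buffer(2) = buffer.getAs[Seq[Int]](2) :+ jobId
      case _                             => // ignore other behaviour types
    }
  }

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Seq[Int]](0) ++ buffer2.getAs[Seq[Int]](0)
    buffer1(1) = buffer1.getAs[Seq[Int]](1) ++ buffer2.getAs[Seq[Int]](1)
    buffer1(2) = buffer1.getAs[Seq[Int]](2) ++ buffer2.getAs[Seq[Int]](2)
  }

  override def evaluate(buffer: Row): Any =
    Seq(0, 1, 2).map(i => buffer.getAs[Seq[Int]](i).mkString(",")).mkString(";")
}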

I eventually switched to an Aggregator, which cut the running time in half.
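
For comparison, a typed Aggregator along these lines sidesteps the Row/schema issue entirely, because the input arrives as a case class and fields are accessed by name. This is only a rough sketch under assumed types: UserBehavior and Profile are hypothetical case classes standing in for the dataset's row type and the aggregation result, and BehaviourTypes is assumed as in the question.

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

// Hypothetical typed row and result shapes, for illustration only.
case class UserBehavior(JobSeekerID: Int, JobID: Int, Date: String, BehaviorType: String)
case class Profile(viewed: Seq[Int], bookmarked: Seq[Int], applied: Seq[Int])

object UserBehaviorAggregator extends Aggregator[UserBehavior, Profile, Profile] {
  def zero: Profile = Profile(Nil, Nil, Nil)

  def reduce(b: Profile, a: UserBehavior): Profile = a.BehaviorType match {
    case BehaviourTypes.viewed_job     => b.copy(viewed = b.viewed :+ a.JobID)
    case BehaviourTypes.bookmarked_job => b.copy(bookmarked = b.bookmarked :+ a.JobID)
    case BehaviourTypes.applied_job    => b.copy(applied = b.applied :+ a.JobID)
    case _                             => b
  }

  def merge(b1: Profile, b2: Profile): Profile =
    Profile(b1.viewed ++ b2.viewed, b1.bookmarked ++ b2.bookmarked, b1.applied ++ b2.applied)

  def finish(reduction: Profile): Profile = reduction
  def bufferEncoder: Encoder[Profile] = Encoders.product[Profile]
  def outputEncoder: Encoder[Profile] = Encoders.product[Profile]
}

// Usage on a typed Dataset[UserBehavior]:
// userBehaviorDS.groupByKey(_.JobSeekerID).agg(UserBehaviorAggregator.toColumn)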

A similar question on "scala - Spark UDAF: How to get value from input by column field name in UDAF (User-Defined Aggregation Function)?" can be found on Stack Overflow: https://stackoverflow.com/questions/48256822/
