
java - How do I write a user-defined aggregate function?


I am trying to understand the Spark documentation for Java. There is a section called Untyped User-Defined Aggregate Functions that contains some sample code I cannot understand. Here is the code:

package org.apache.spark.examples.sql;

// $example on:untyped_custom_aggregation$
import java.util.ArrayList;
import java.util.List;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
// $example off:untyped_custom_aggregation$

public class JavaUserDefinedUntypedAggregation {

  // $example on:untyped_custom_aggregation$
  public static class MyAverage extends UserDefinedAggregateFunction {

    private StructType inputSchema;
    private StructType bufferSchema;

    public MyAverage() {
      List<StructField> inputFields = new ArrayList<>();
      inputFields.add(DataTypes.createStructField("inputColumn", DataTypes.LongType, true));
      inputSchema = DataTypes.createStructType(inputFields);

      List<StructField> bufferFields = new ArrayList<>();
      bufferFields.add(DataTypes.createStructField("sum", DataTypes.LongType, true));
      bufferFields.add(DataTypes.createStructField("count", DataTypes.LongType, true));
      bufferSchema = DataTypes.createStructType(bufferFields);
    }
    // Data types of input arguments of this aggregate function
    public StructType inputSchema() {
      return inputSchema;
    }
    // Data types of values in the aggregation buffer
    public StructType bufferSchema() {
      return bufferSchema;
    }
    // The data type of the returned value
    public DataType dataType() {
      return DataTypes.DoubleType;
    }
    // Whether this function always returns the same output on the identical input
    public boolean deterministic() {
      return true;
    }
    // Initializes the given aggregation buffer. The buffer itself is a `Row` that in addition to
    // standard methods like retrieving a value at an index (e.g., get(), getBoolean()), provides
    // the opportunity to update its values. Note that arrays and maps inside the buffer are still
    // immutable.
    public void initialize(MutableAggregationBuffer buffer) {
      buffer.update(0, 0L);
      buffer.update(1, 0L);
    }
    // Updates the given aggregation buffer `buffer` with new input data from `input`
    public void update(MutableAggregationBuffer buffer, Row input) {
      if (!input.isNullAt(0)) {
        long updatedSum = buffer.getLong(0) + input.getLong(0);
        long updatedCount = buffer.getLong(1) + 1;
        buffer.update(0, updatedSum);
        buffer.update(1, updatedCount);
      }
    }
    // Merges two aggregation buffers and stores the updated buffer values back to `buffer1`
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
      long mergedSum = buffer1.getLong(0) + buffer2.getLong(0);
      long mergedCount = buffer1.getLong(1) + buffer2.getLong(1);
      buffer1.update(0, mergedSum);
      buffer1.update(1, mergedCount);
    }
    // Calculates the final result
    public Double evaluate(Row buffer) {
      return ((double) buffer.getLong(0)) / buffer.getLong(1);
    }
  }
  // $example off:untyped_custom_aggregation$

  public static void main(String[] args) {
    SparkSession spark = SparkSession
      .builder()
      .appName("Java Spark SQL user-defined DataFrames aggregation example")
      .getOrCreate();

    // $example on:untyped_custom_aggregation$
    // Register the function to access it
    spark.udf().register("myAverage", new MyAverage());

    Dataset<Row> df = spark.read().json("examples/src/main/resources/employees.json");
    df.createOrReplaceTempView("employees");
    df.show();
    // +-------+------+
    // |   name|salary|
    // +-------+------+
    // |Michael|  3000|
    // |   Andy|  4500|
    // | Justin|  3500|
    // |  Berta|  4000|
    // +-------+------+

    Dataset<Row> result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees");
    result.show();
    // +--------------+
    // |average_salary|
    // +--------------+
    // |        3750.0|
    // +--------------+
    // $example off:untyped_custom_aggregation$

    spark.stop();
  }
}

My questions about the above code are:

  • Should I always have the functions initialize, update, and merge whenever I want to create a UDF?
  • What is the significance of the variables inputSchema and bufferSchema? I am surprised they exist, because they are never used to create any DataFrame. Should they be present in every UDF? If so, should they have exactly these names?
  • Why are the getters for inputSchema and bufferSchema not named getInputSchema() and getBufferSchema()? Why are there no setters for these variables?
  • What is the significance of the function called deterministic() here? Please give a scenario where calling this function would be useful.

Overall, I want to know how to write a user-defined aggregate function in Spark.

Best Answer

Whenever I want to create a UDF, should I have the functions initialize, update and merge

UDF stands for user-defined function, while the methods initialize, update, and merge are for user-defined aggregate functions (a.k.a. UDAFs).

A UDF is a function that works on a single row to (usually) produce one row (e.g. the upper function).

A UDAF is a function that works on zero or more rows to produce one row (e.g. the count aggregate function); the sketch below contrasts the two.
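To make the contrast concrete, here is a small sketch that assumes the employees temporary view registered in the question's code (upper and count are built-ins, but user-defined ones are invoked the same way from SQL):

spark.sql("SELECT upper(name) FROM employees").show();  // UDF: one output row per input row
spark.sql("SELECT count(*) FROM employees").show();     // UDAF: one row summarizing all input rows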

You certainly do not have to (and will not be able to) have the functions initialize, update, and merge for a user-defined function (UDF).

Use any of the udf functions to define and register a UDF, e.g. in Scala:

val myUpper = udf { (s: String) => s.toUpperCase }
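Since the question is about Java, a minimal sketch of the equivalent registration through the Java API might look like this (assuming an existing SparkSession named spark; the name myUpper is arbitrary):

import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.types.DataTypes;

// Register a one-argument UDF; afterwards it is callable from SQL as myUpper(...)
spark.udf().register(
    "myUpper",
    (UDF1<String, String>) s -> s == null ? null : s.toUpperCase(),
    DataTypes.StringType);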

How to write a user-defined aggregate function in Spark

What is the significance of the variables inputSchema and bufferSchema?

(Shameless plug: I have been describing UDAFs in UserDefinedAggregateFunction — Contract for User-Defined Untyped Aggregate Functions (UDAFs) in the Mastering Spark SQL book.)

Quoting Untyped User-Defined Aggregate Functions:

// Data types of input arguments of this aggregate function
def inputSchema: StructType = StructType(StructField("inputColumn", LongType) :: Nil)

// Data types of values in the aggregation buffer
def bufferSchema: StructType = {
  StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)
}

In other words, inputSchema is what you expect from the input, while bufferSchema is what you keep temporarily while performing the aggregation; the trace below shows the buffer in action.
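Here is a plain-Java sketch of that lifecycle for the salaries from the question's example. This is not the Spark API: the long[] merely stands in for the (sum, count) buffer that bufferSchema describes, which Spark exposes as a MutableAggregationBuffer it manages itself.

// Traces MyAverage's buffer through initialize, update, and evaluate.
public class BufferLifecycleSketch {
  public static void main(String[] args) {
    long[] buffer = {0L, 0L};                        // initialize: sum = 0, count = 0
    for (long salary : new long[]{3000, 4500, 3500, 4000}) {
      buffer[0] += salary;                           // update: add the input value to sum
      buffer[1] += 1;                                // update: increment count
    }
    // merge would combine such buffers from different partitions the same way
    double result = (double) buffer[0] / buffer[1];  // evaluate: 15000 / 4
    System.out.println(result);                      // prints 3750.0
  }
}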

Why are there no setters of these variables?

They are extension points that are managed by Spark; Spark calls them itself, so there is nothing for user code to set.

What is the significance of the function called deterministic() here?

Quoting Untyped User-Defined Aggregate Functions:

// Whether this function always returns the same output on the identical input
def deterministic: Boolean = true

Please give a scenario when it would be useful to call this function.

That is something I am still working on myself, so I cannot answer it today.

Regarding "java - How do I write a user-defined aggregate function?", we found a similar question on Stack Overflow: https://stackoverflow.com/questions/44936394/
