
python - How to find the mean of arrays (an array column) along axis 0 in a PySpark dataframe?


I have a PySpark dataframe -

df = spark.createDataFrame([
    ("u1", [[1., 2., 3.], [1., 2., 0.], [1., 0., 0.]]),
    ("u2", [[1., 10., 0.]]),
    ("u3", [[1., 0., 3.], [10., 0., 0.]]),
], ['user_id', 'features'])

df.printSchema()
df.show(truncate=False)

Output -

root
|-- user_id: string (nullable = true)
|-- features: array (nullable = true)
| |-- element: array (containsNull = true)
| | |-- element: double (containsNull = true)

+-------+---------------------------------------------------+
|user_id|features                                           |
+-------+---------------------------------------------------+
|u1     |[[1.0, 2.0, 3.0], [1.0, 2.0, 0.0], [1.0, 0.0, 0.0]]|
|u2     |[[1.0, 10.0, 0.0]]                                 |
|u3     |[[1.0, 0.0, 3.0], [10.0, 0.0, 0.0]]                |
+-------+---------------------------------------------------+

I want to compute the element-wise mean of these arrays along axis 0 for each user (for u1, for example, the second element would be (2.0 + 2.0 + 0.0) / 3 ≈ 1.33). The desired output looks like -

+-------+---------------------------------------------------+----------------+
|user_id|features                                           |avg_features    |
+-------+---------------------------------------------------+----------------+
|u1     |[[1.0, 2.0, 3.0], [1.0, 2.0, 0.0], [1.0, 0.0, 0.0]]|[1.0, 1.33, 1.0]|
|u2     |[[1.0, 10.0, 0.0]]                                 |[1.0, 10.0, 0.0]|
|u3     |[[1.0, 0.0, 3.0], [10.0, 0.0, 0.0]]                |[5.5, 0.0, 1.5] |
+-------+---------------------------------------------------+----------------+

How can I achieve this?

Best Answer

Edit: a more scalable solution, which does not hard-code the length of the inner arrays:

import pyspark.sql.functions as F

df2 = df.withColumn(
    # explode: one row per inner array
    'exploded_features', F.explode('features')
).select(
    # posexplode: one row per (position, value) pair within each inner array
    'user_id', 'features', F.posexplode('exploded_features')
).groupBy(
    'user_id', 'features', 'pos'
).agg(
    # average the values at each position across the inner arrays
    F.mean('col')
).groupBy(
    'user_id', 'features'
).agg(
    # collect [pos, avg] pairs and sort them by position
    F.array_sort(
        F.collect_list(
            F.array('pos', 'avg(col)')
        )
    ).alias('avg_features')
).withColumn(
    # keep only the average, dropping the position element
    'avg_features',
    F.expr('transform(avg_features, x -> x[1])')
)

df2.show(truncate=False)
+-------+---------------------------------------------------+------------------------------+
|user_id|features                                           |avg_features                  |
+-------+---------------------------------------------------+------------------------------+
|u1     |[[1.0, 2.0, 3.0], [1.0, 2.0, 0.0], [1.0, 0.0, 0.0]]|[1.0, 1.3333333333333333, 1.0]|
|u2     |[[1.0, 10.0, 0.0]]                                 |[1.0, 10.0, 0.0]              |
|u3     |[[1.0, 0.0, 3.0], [10.0, 0.0, 0.0]]                |[5.5, 0.0, 1.5]               |
+-------+---------------------------------------------------+------------------------------+
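For completeness (this is not part of the original answer), the same per-row, axis-0 mean can also be computed with a pandas UDF and NumPy. A minimal sketch, assuming Spark 3.x with pyarrow installed; mean_axis0 is an illustrative name:

import numpy as np
import pandas as pd
import pyspark.sql.functions as F
from pyspark.sql.types import ArrayType, DoubleType

# Sketch only: vectorized UDF that averages each row's inner arrays along axis 0.
@F.pandas_udf(ArrayType(DoubleType()))
def mean_axis0(features: pd.Series) -> pd.Series:
    # Each element of `features` is a sequence of equal-length arrays; stack and average them.
    return features.apply(lambda arrs: np.mean(np.stack(arrs), axis=0).tolist())

df3 = df.withColumn('avg_features', mean_axis0('features'))
df3.show(truncate=False)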

Using the aggregate and transform higher-order functions to operate on the arrays:

df2 = df.selectExpr(
    'user_id',
    'features',
    '''array(
        aggregate(transform(features, x -> x[0]), cast(0 as double), (x, y) -> (x + y)) / size(features),
        aggregate(transform(features, x -> x[1]), cast(0 as double), (x, y) -> (x + y)) / size(features),
        aggregate(transform(features, x -> x[2]), cast(0 as double), (x, y) -> (x + y)) / size(features)
    ) as avg'''
)

df2.show(truncate=False)
+-------+---------------------------------------------------+------------------------------+
|user_id|features                                           |avg                           |
+-------+---------------------------------------------------+------------------------------+
|u1     |[[1.0, 2.0, 3.0], [1.0, 2.0, 0.0], [1.0, 0.0, 0.0]]|[1.0, 1.3333333333333333, 1.0]|
|u2     |[[1.0, 10.0, 0.0]]                                 |[1.0, 10.0, 0.0]              |
|u3     |[[1.0, 0.0, 3.0], [10.0, 0.0, 0.0]]                |[5.5, 0.0, 1.5]               |
+-------+---------------------------------------------------+------------------------------+
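As an aside (also not in the original answer), the hard-coded indices x[0], x[1], x[2] above can be avoided by iterating over a generated index sequence. A minimal sketch, assuming every inner array in a row has the same length as the first one:

# Sketch only: generalizes the aggregate/transform approach to inner arrays of any width.
df2 = df.selectExpr(
    'user_id',
    'features',
    '''transform(
        sequence(0, size(features[0]) - 1),
        i -> aggregate(transform(features, x -> x[i]), cast(0 as double), (acc, v) -> acc + v) / size(features)
    ) as avg'''
)
df2.show(truncate=False)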

Regarding "python - How to find the mean of arrays (an array column) along axis 0 in a PySpark dataframe?", we found a similar question on Stack Overflow: https://stackoverflow.com/questions/65404484/
