gpt4 book ai didi

apache-spark - Pyspark:计算连续观察的条纹

转载 作者:行者123 更新时间:2023-12-02 03:22:10 25 4
gpt4 key购买 nike

我有一个 Spark (2.4.0) 数据框,其中一列只有两个值(01)。我需要计算此数据中连续的 01 的条纹,如果值发生变化,则将条纹重置为零。

一个例子:

from pyspark.sql import (SparkSession, Window)
from pyspark.sql.functions import (to_date, row_number, lead, col)

spark = SparkSession.builder.appName('test').getOrCreate()

# Create dataframe
df = spark.createDataFrame([
('2018-01-01', 'John', 0, 0),
('2018-01-01', 'Paul', 1, 0),
('2018-01-08', 'Paul', 3, 1),
('2018-01-08', 'Pete', 4, 0),
('2018-01-08', 'John', 3, 0),
('2018-01-15', 'Mary', 6, 0),
('2018-01-15', 'Pete', 6, 0),
('2018-01-15', 'John', 6, 1),
('2018-01-15', 'Paul', 6, 1),
], ['str_date', 'name', 'value', 'flag'])

df.orderBy('name', 'str_date').show()
## +----------+----+-----+----+
## | str_date|name|value|flag|
## +----------+----+-----+----+
## |2018-01-01|John| 0| 0|
## |2018-01-08|John| 3| 0|
## |2018-01-15|John| 6| 1|
## |2018-01-15|Mary| 6| 0|
## |2018-01-01|Paul| 1| 0|
## |2018-01-08|Paul| 3| 1|
## |2018-01-15|Paul| 6| 1|
## |2018-01-08|Pete| 4| 0|
## |2018-01-15|Pete| 6| 0|
## +----------+----+-----+----+

有了这些数据,我想计算连续的零和一的条纹,按日期排序并按名称“加窗”:

# Expected result:
## +----------+----+-----+----+--------+--------+
## | str_date|name|value|flag|streak_0|streak_1|
## +----------+----+-----+----+--------+--------+
## |2018-01-01|John| 0| 0| 1| 0|
## |2018-01-08|John| 3| 0| 2| 0|
## |2018-01-15|John| 6| 1| 0| 1|
## |2018-01-15|Mary| 6| 0| 1| 0|
## |2018-01-01|Paul| 1| 0| 1| 0|
## |2018-01-08|Paul| 3| 1| 0| 1|
## |2018-01-15|Paul| 6| 1| 0| 2|
## |2018-01-08|Pete| 4| 0| 1| 0|
## |2018-01-15|Pete| 6| 0| 2| 0|
## +----------+----+-----+----+--------+--------+

当然,如果“标志”发生变化,我需要连胜将自身重置为零。

有没有办法做到这一点?

最佳答案

这需要采用行号差异方法,首先对具有相同值的连续行进行分组,然后在这些组中使用排名方法。

from pyspark.sql import Window 
from pyspark.sql import functions as f
#Windows definition
w1 = Window.partitionBy(df.name).orderBy(df.date)
w2 = Window.partitionBy(df.name,df.flag).orderBy(df.date)

res = df.withColumn('grp',f.row_number().over(w1)-f.row_number().over(w2))
#Window definition for streak
w3 = Window.partitionBy(res.name,res.flag,res.grp).orderBy(res.date)
streak_res = res.withColumn('streak_0',f.when(res.flag == 1,0).otherwise(f.row_number().over(w3))) \
.withColumn('streak_1',f.when(res.flag == 0,0).otherwise(f.row_number().over(w3)))
streak_res.show()

关于apache-spark - Pyspark:计算连续观察的条纹,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54445961/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com