gpt4 book ai didi

apache-spark - Spark 火车测试 split

转载 作者:行者123 更新时间:2023-12-03 23:48:52 29 4
gpt4 key购买 nike

我很好奇在最新的2.0.1版本中是否有类似于sklearn的http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.StratifiedShuffleSplit.html用于apache-spark。

到目前为止,我只能找到https://spark.apache.org/docs/latest/mllib-statistics.html#stratified-sampling,它似乎不太适合将严重失衡的数据集拆分为训练样本/测试样本。

最佳答案

假设我们有一个像这样的数据集:

+---+-----+
| id|label|
+---+-----+
| 0| 0.0|
| 1| 1.0|
| 2| 0.0|
| 3| 1.0|
| 4| 0.0|
| 5| 1.0|
| 6| 0.0|
| 7| 1.0|
| 8| 0.0|
| 9| 1.0|
+---+-----+


该数据集是完全平衡的,但是这种方法也适用于不平衡的数据。

现在,让我们使用其他信息来扩充此DataFrame,这些信息将有助于确定应该对哪些行进行训练。步骤如下:


给定 ratio,确定每个标签应包含多少个示例。
随机排列DataFrame的行。
使用窗口函数按 label对DataFrame进行分区和排序,然后使用 row_number()对每个标签的观察值进行排名。


我们最终得到以下数据框:

+---+-----+----------+
| id|label|row_number|
+---+-----+----------+
| 6| 0.0| 1|
| 2| 0.0| 2|
| 0| 0.0| 3|
| 4| 0.0| 4|
| 8| 0.0| 5|
| 9| 1.0| 1|
| 5| 1.0| 2|
| 3| 1.0| 3|
| 1| 1.0| 4|
| 7| 1.0| 5|
+---+-----+----------+


注意:行被随机排列(请参阅: id列中的随机顺序),按标签进行分区(请参阅: label列)并进行排名。

假设我们要进行80%分割。在这种情况下,我们希望四个 1.0标签和四个 0.0标签进入训练数据集,一个 1.0标签和一个 0.0标签进入测试数据集。我们在 row_number列中有此信息,因此现在我们可以在用户定义的函数中简单地使用它(如果 row_number小于或等于4,则该示例进入训练集)。

应用UDF之后,结果数据帧如下:

+---+-----+----------+----------+
| id|label|row_number|isTrainSet|
+---+-----+----------+----------+
| 6| 0.0| 1| true|
| 2| 0.0| 2| true|
| 0| 0.0| 3| true|
| 4| 0.0| 4| true|
| 8| 0.0| 5| false|
| 9| 1.0| 1| true|
| 5| 1.0| 2| true|
| 3| 1.0| 3| true|
| 1| 1.0| 4| true|
| 7| 1.0| 5| false|
+---+-----+----------+----------+


现在,要获得训练/测试数据,必须要做的是:

val train = df.where(col("isTrainSet") === true)
val test = df.where(col("isTrainSet") === false)


对于某些非常大的数据集,这些排序和分区步骤可能是禁止的,因此我建议首先尽可能地过滤数据集。物理计划如下:

== Physical Plan ==
*(3) Project [id#4, label#5, row_number#11, if (isnull(row_number#11)) null else UDF(label#5, row_number#11) AS isTrainSet#48]
+- Window [row_number() windowspecdefinition(label#5, label#5 ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS row_number#11], [label#5], [label#5 ASC NULLS FIRST]
+- *(2) Sort [label#5 ASC NULLS FIRST, label#5 ASC NULLS FIRST], false, 0
+- Exchange hashpartitioning(label#5, 200)
+- *(1) Project [id#4, label#5]
+- *(1) Sort [_nondeterministic#9 ASC NULLS FIRST], true, 0
+- Exchange rangepartitioning(_nondeterministic#9 ASC NULLS FIRST, 200)
+- LocalTableScan [id#4, label#5, _nondeterministic#9


这是完整的工作示例(已通过Spark 2.3.0和Scala 2.11.12测试):

import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.functions.{col, row_number, udf, rand}

class StratifiedTrainTestSplitter {

def getNumExamplesPerClass(ss: SparkSession, label: String, trainRatio: Double)(df: DataFrame): Map[Double, Long] = {
df.groupBy(label).count().createOrReplaceTempView("labelCounts")
val query = f"SELECT $label AS ratioLabel, count, cast(count * $trainRatio as long) AS trainExamples FROM labelCounts"
import ss.implicits._
ss.sql(query)
.select("ratioLabel", "trainExamples")
.map((r: Row) => r.getDouble(0) -> r.getLong(1))
.collect()
.toMap
}

def split(df: DataFrame, label: String, trainRatio: Double): DataFrame = {
val w = Window.partitionBy(col(label)).orderBy(col(label))

val rowNumPartitioner = row_number().over(w)

val dfRowNum = df.sort(rand).select(col("*"), rowNumPartitioner as "row_number")

dfRowNum.show()

val observationsPerLabel: Map[Double, Long] = getNumExamplesPerClass(df.sparkSession, label, trainRatio)(df)

val addIsTrainColumn = udf((label: Double, rowNumber: Int) => rowNumber <= observationsPerLabel(label))

dfRowNum.withColumn("isTrainSet", addIsTrainColumn(col(label), col("row_number")))
}


}

object StratifiedTrainTestSplitter {

def getDf(ss: SparkSession): DataFrame = {
val data = Seq(
(0, 0.0), (1, 1.0), (2, 0.0), (3, 1.0), (4, 0.0), (5, 1.0), (6, 0.0), (7, 1.0), (8, 0.0), (9, 1.0)
)
ss.createDataFrame(data).toDF("id", "label")
}

def main(args: Array[String]): Unit = {
val spark: SparkSession = SparkSession
.builder()
.config(new SparkConf().setMaster("local[1]"))
.getOrCreate()

val df = new StratifiedTrainTestSplitter().split(getDf(spark), "label", 0.8)

df.cache()

df.where(col("isTrainSet") === true).show()
df.where(col("isTrainSet") === false).show()
}
}


注意:在这种情况下,标签为 Double。如果标签是 String,则必须在此处和此处切换类型。

关于apache-spark - Spark 火车测试 split ,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/39994587/

29 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com