gpt4 book ai didi

Scala/Spark Apriori 实现速度极慢

转载 作者:行者123 更新时间:2023-12-03 23:42:40 26 4
gpt4 key购买 nike

我们正在尝试实现 Apriori algorithm在 Scala 中使用 Spark(您不需要知道回答这个问题的算法)。
计算Apriori算法项集的函数是freq() .代码是正确的,但每次迭代后都会变慢 whilefreq()函数,直到花费几秒钟对具有 1 行与自身的表执行交叉联接。

import System.{exit, nanoTime}
import scala.collection.mutable.WrappedArray
import org.apache.spark.sql.{Column, SparkSession, DataFrame}
import org.apache.spark.sql.functions._
import spark.implicits._

object Main extends Serializable {
val s = 0.03

def loadFakeData() : DataFrame = {
var data = Seq("1 ",
"1 2 ",
"1 2",
"3",
"1 2 3 ",
"1 2 ")
.toDF("baskets_str")
.withColumn("baskets", split('baskets_str, " ").cast("array<int>"))
data
}

def combo(a1: WrappedArray[Int], a2: WrappedArray[Int]): Array[Array[Int]] = {
var a = a1.toSet
var b = a2.toSet
var res = a.diff(b).map(b+_) ++ b.diff(a).map(a+_)
return res.map(_.toArray.sortWith(_ < _)).toArray
}
val comboUDF = udf[Array[Array[Int]], WrappedArray[Int], WrappedArray[Int]](combo)

def getCombinations(df: DataFrame): DataFrame = {
df.crossJoin(df.withColumnRenamed("itemsets", "itemsets_2"))
.withColumn("combinations", comboUDF(col("itemsets"), col("itemsets_2")))
.select("combinations")
.withColumnRenamed("combinations", "itemsets")
.withColumn("itemsets", explode(col("itemsets")))
.dropDuplicates()
}

def countCombinations(data : DataFrame, combinations: DataFrame) : DataFrame = {
data.crossJoin(combinations)
.where(size(array_intersect('baskets, 'itemsets)) === size('itemsets))
.groupBy("itemsets")
.count
}

def freq() {
val spark = SparkSession.builder.appName("FreqItemsets")
.master("local[*]")
.getOrCreate()

// data is a dataframe where each row contains an array of integer values
var data = loadFakeData()
val basket_count = data.count

// Itemset is a dataframe containing all possible sets of 1 element
var itemset : DataFrame = data
.select(explode('baskets))
.na.drop
.dropDuplicates()
.withColumnRenamed("col", "itemsets")
.withColumn("itemsets", array('itemsets))
var itemset_count : DataFrame = countCombinations(data, itemset).filter('count > s*basket_count)
var itemset_counts = List(itemset_count)

// We iterate creating each time itemsets of length k+1 from itemsets of length k
// pruning those that do not have enough support
var stop = (itemset_count.count == 0)
while(!stop) {
itemset = getCombinations(itemset_count.select("itemsets"))
itemset_count = countCombinations(data, itemset).filter('count > s*basket_count)
stop = (itemset_count.count == 0)
if (!stop) {
itemset_counts = itemset_counts :+ itemset_count
}
}

spark.stop()
}
}

最佳答案

由于 Spark 保留随时重新生成数据集的权利,这可能是正在发生的事情,在这种情况下,缓存昂贵的转换结果可以显着提高性能。
在这种情况下,乍一看就像 itemset是重量级选手,所以

itemset = getCombinations(itemset_count.select("itemsets")).cache
可以分红。
还应该注意的是,通过在循环中追加来构建列表通常比通过前置来构建要慢得多( O(n^2) )。如果正确性不受 itemset_counts 的顺序影响, 然后:
itemset_counts = itemset_count :: itemset_counts
将至少产生边际加速。

关于Scala/Spark Apriori 实现速度极慢,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64835206/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com