gpt4 book ai didi

scala - 关于优化简单的 Scala FoldLeft 多个值的建议?

转载 作者:行者123 更新时间:2023-12-03 05:11:56 25 4
gpt4 key购买 nike

我正在从 Java 到 Scala 重新实现一些代码(一个简单的贝叶斯推理算法,但这并不重要)。我希望以尽可能高效的方式实现它,同时通过尽可能避免可变性来保持代码的整洁和功能。

这是 Java 代码片段:

    // initialize
double lP = Math.log(prior);
double lPC = Math.log(1-prior);

// accumulate probabilities from each annotation object into lP and lPC
for (Annotation annotation : annotations) {
float prob = annotation.getProbability();
if (isValidProbability(prob)) {
lP += logProb(prob);
lPC += logProb(1 - prob);
}
}

很简单,对吧?所以我决定第一次尝试使用 Scala FoldLeft 和 map 方法。由于我有两个要累加的值,因此累加器是一个元组:

    val initial  = (math.log(prior), math.log(1-prior))
val probs = annotations map (_.getProbability)
val (lP,lPC) = probs.foldLeft(initial) ((r,p) => {
if(isValidProbability(p)) (r._1 + logProb(p), r._2 + logProb(1-p)) else r
})

不幸的是,这段代码的执行速度比 Java 慢大约 5 倍(使用简单且不精确的指标;只是在循环中调用该代码 10000 次)。有一个缺陷是非常明显的:我们遍历列表两次,一次是在调用map时,另一次是在foldLeft中。所以这是一个遍历列表一次的版本。

    val (lP,lPC) = annotations.foldLeft(initial) ((r,annotation) => {
val p = annotation.getProbability
if(isValidProbability(p)) (r._1 + logProb(p), r._2 + logProb(1-p)) else r
})

这样更好!它的性能比 Java 代码差大约 3 倍。我的下一个预感是,在折叠的每个步骤中创建所有新元组可能会涉及一些成本。所以我决定尝试一个遍历列表两次但不创建元组的版本。

    val lP = annotations.foldLeft(math.log(prior)) ((r,annotation) => {
val p = annotation.getProbability
if(isValidProbability(p)) r + logProb(p) else r
})
val lPC = annotations.foldLeft(math.log(1-prior)) ((r,annotation) => {
val p = annotation.getProbability
if(isValidProbability(p)) r + logProb(1-p) else r
})

其性能与之前的版本大致相同(比 Java 版本慢 3 倍)。并不奇怪,但我充满希望。

所以我的问题是,是否有一种更快的方法在 Scala 中实现这个 Java 代码片段,同时保持 Scala 代码干净、避免不必要的可变性并遵循 Scala 习惯用法?我确实希望最终在并发环境中使用此代码,因此保持不变性的值(value)可能会超过单个线程中较慢的性能。

最佳答案

首先,您的一些损失可能来自您使用的集合类型。但其中大部分可能是对象创建,实际上您无法通过运行循环两次来避免对象创建,因为数字必须装箱。

相反,您可以创建一个可变类来为您累积值:

class LogOdds(var lp: Double = 0, var lpc: Double = 0) {
def *=(p: Double) = {
if (isValidProbability(p)) {
lp += logProb(p)
lpc += logProb(1-p)
}
this // Pass self on so we can fold over the operation
}
def toTuple = (lp, lpc)
}

现在,虽然您可以不安全地使用它,但您不必这样做。事实上,您可以将其折叠起来。

annotations.foldLeft(new LogOdds()) { (r,ann) => r *= ann.getProbability } toTuple

如果您使用此模式,所有可变的不安全因素都会隐藏在折叠内;它永远不会逃脱。

现在,您无法进行平行折叠,但您可以进行聚合,这就像折叠加上额外的操作来组合各个部分。所以你添加方法

def **(lo: LogOdds) = new LogOdds(lp + lo.lp, lpc + lo.lpc)

LogOdds,然后

annotations.aggregate(new LogOdds())(
(r,ann) => r *= ann.getProbability,
(l,r) => l**r
).toTuple

你就可以开始了。

(为此可以随意使用非数学符号,但由于您基本上是在乘以概率,因此乘法符号似乎比合并概率或类似的符号更能直观地了解正在发生的情况。)

关于scala - 关于优化简单的 Scala FoldLeft 多个值的建议?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/9116506/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com