gpt4 book ai didi

arrays - 在 Scala 中表示欧几里得距离的最简单方法

转载 作者:行者123 更新时间:2023-12-04 23:20:21 27 4
gpt4 key购买 nike

我正在 Scala 中编写数据挖掘算法,我想为给定的测试和几个火车实例编写欧几里德距离函数。我有一个 Array[Array[Double]]带有测试和训练实例。我有一种方法可以针对所有训练实例循环遍历每个测试实例并计算两者之间的距离(每次迭代选择一个测试和训练实例)并返回 Double .

比如说,我有以下数据点:

testInstance = Array(Array(3.2, 2.1, 4.3, 2.8))
trainPoints = Array(Array(3.9, 4.1, 6.2, 7.3), Array(4.5, 6.1, 8.3, 3.8), Array(5.2, 4.6, 7.4, 9.8), Array(5.1, 7.1, 4.4, 6.9))

我有一个方法 stub (突出显示距离函数),它返回给定测试实例周围的邻居:
def predictClass(testPoints: Array[Array[Double]], trainPoints: Array[Array[Double]], k: Int): Array[Double] = {

for(testInstance <- testPoints)
{
for(trainInstance <- trainPoints)
{
for(i <- 0 to k)
{
distance = euclideanDistanceBetween(testInstance, trainInstance) //need help in defining this function
}
}
}
return distance
}

我知道如何将通用的欧几里得距离公式写为:
math.sqrt(math.pow((x1 - y1), 2) + math.pow((x2 - y2), 2))

关于我希望该方法对函数的基本定义做什么,我有一些伪步骤:
def distanceBetween(testInstance: Array[Double], trainInstance: Array[Double]): Double = {
// subtract each element of trainInstance with testInstance
// for example,
// iteration 1 will do [Array(3.9, 4.1, 6.2, 7.3) - Array(3.2, 2.1, 4.3, 2.8)]
// i.e. sqrt(3.9-3.2)^2+(4.1-2.1)^2+(6.2-4.3)^2+(7.3-2.8)^2
// return result
// iteration 2 will do [Array(4.5, 6.1, 8.3, 3.8) - Array(3.2, 2.1, 4.3, 2.8)]
// i.e. sqrt(4.5-3.2)^2+(6.1-2.1)^2+(8.3-4.3)^2+(3.8-2.8)^2
// return result, and so on......
}

我怎样才能在代码中写这个?

最佳答案

所以你输入的公式只适用于二维向量。你有四个维度,但你可能应该编写你的函数来灵活处理这个问题。所以退房this formula .

所以你真正想说的是:

for each position i:
subtract the ith element of Y from the ith element of X
square it
add all of those up
square root the whole thing

为了使这种更具功能性的编程风格更像是:
square root the:
sum of:
zip X and Y into pairs
for each pair, square the difference

所以这看起来像:
import math._

def distance(xs: Array[Double], ys: Array[Double]) = {
sqrt((xs zip ys).map { case (x,y) => pow(y - x, 2) }.sum)
}

val testInstances = Array(Array(5.0, 4.8, 7.5, 10.0), Array(3.2, 2.1, 4.3, 2.8))
val trainPoints = Array(Array(3.9, 4.1, 6.2, 7.3), Array(4.5, 6.1, 8.3, 3.8), Array(5.2, 4.6, 7.4, 9.8), Array(5.1, 7.1, 4.4, 6.9))

distance(testInstances.head, trainPoints.head)
// 3.2680269276736382

至于预测类(class),您也可以使其更具功能性,但尚不清楚您打算返回的 Double 是什么。似乎您想预测每个测试实例的类?也许选择类(class) c对应最近的训练点?
def findNearestClasses(testPoints: Array[Array[Double]], trainPoints: Array[Array[Double]]): Array[Int] = {
testPoints.map { testInstance =>
trainPoints.zipWithIndex.map { case (trainInstance, c) =>
c -> distance(testInstance, trainInstance)
}.minBy(_._2)._1
}
}

findNearestClasses(testInstances, trainPoints)
// Array(2, 0)

或者您可能想要 k -最近的邻居:
def findKNearestClasses(testPoints: Array[Array[Double]], trainPoints: Array[Array[Double]], k: Int): Array[Int] = {
testPoints.map { testInstance =>
val distances =
trainPoints.zipWithIndex.map { case (trainInstance, c) =>
c -> distance(testInstance, trainInstance)
}
val classes = distances.sortBy(_._2).take(k).map(_._1)
val classCounts = classes.groupBy(identity).mapValues(_.size)
classCounts.maxBy(_._2)._1
}
}

findKNearestClasses(testInstances, trainPoints)
// Array(2, 1)

关于arrays - 在 Scala 中表示欧几里得距离的最简单方法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/28949591/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com