gpt4 book ai didi

java - 可以使用多线程来使这个矩阵 vector 乘法算法更加高效吗?

转载 作者:太空宇宙 更新时间:2023-11-04 09:22:56 25 4
gpt4 key购买 nike

对于我正在制作的程序,我想实现我自己的矩阵 vector 乘法算法。这是它的代码。

static <T extends List<Double>> List<Double> matrixVectorMulti(List<T> matrix, List<Double> vector) {
List<Double> output = new ArrayList<>();

for(int row = 0; row < matrix.size(); row++) {
double sum = 0;
for (int column = 0; column < vector.size(); column++) {
sum += matrix.get(row).get(column) * vector.get(column);
}

output.add(sum);
}

return output;
}

是否可以使用多线程来提高算法的性能,或者我的算法是否更有效。另外,如果您能为我提供多线程实现的示例代码,将会很有帮助。

注意:列表的列表用于对矩阵建模,列表用于对 vector 建模。

最佳答案

多线程是否会提高性能实际上取决于许多因素

  • 矩阵中有多少数据
  • (虚拟)机有多少个进程
  • 工作负载类型(CPU 密集型、IO 密集型)
  • 许多其他因素

您的用例是 CPU 密集型(IO 最少的乘法)。假设矩阵很大,您可以通过实现多线程获得一定程度的好处。

下面是多线程的一种版本,它可以利用并行处理(线程数量几乎等于可用于处理的处理器数量)。

请记住,还有其他方法可以提高性能...例如使用矩阵大小初始化输出 ArrayList 等。

注意:下面计算的性能统计数据不是科学方法......它只是一种非正式的计算方法。该代码未经过充分测试。但它可以给你这个想法。

package my.package;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

public class MatrixMultiplication {

public List<Double> matrixVectorMulti(List<List<Double>> matrix, List<Double> vector) {
List<Double> output = new ArrayList<>();

for(int row = 0; row < matrix.size(); row++) {
double sum = 0;
for (int column = 0; column < vector.size(); column++) {
sum += matrix.get(row).get(column) * vector.get(column);
}
output.add(sum);
}
return output;
}

static List<List<Double>> initializeMatrix(int matrix_size) {
List<List<Double>> matrix = new ArrayList<>();
int rangeMin = 100;
int rangeMax = 200;

for (int i=0; i<matrix_size; i++) {
List<Double> row = new ArrayList<>();
for (int j=0; j<matrix_size; j++) {
row.add(rangeMin + (rangeMax-rangeMin) * new Random().nextDouble());
}
matrix.add(row);
}

return matrix;
}

static List<Double> initializeVector(int matrix_size) {
List<Double> vector = new ArrayList<>();
int rangeMin = 100;
int rangeMax = 200;

for (int j=0; j<matrix_size; j++) {
vector.add(rangeMin + (rangeMax-rangeMin) * new Random().nextDouble());
}
return vector;
}

public List<Double> matrixVectorMultiParallel(List<List<Double>> matrix, List<Double> vector) {
int numOfThreads = Runtime.getRuntime().availableProcessors();
List<Double> result = new ArrayList<>();

//System.out.println(numOfThreads);
//System.exit(0);
int batchSize = matrix.size()/numOfThreads;
PartialVectorMulti[] partialVectorMultis = new PartialVectorMulti[numOfThreads];

int rangeStart = 0;
int rangeEnd = 0;
for (int i=0; i<numOfThreads; i++) {
rangeEnd = rangeStart + batchSize-1;

if (i == numOfThreads-1) {
partialVectorMultis[i] = new PartialVectorMulti(matrix, rangeStart, matrix.size()-1, vector);
} else {
partialVectorMultis[i] = new PartialVectorMulti(matrix, rangeStart, rangeEnd, vector);
}

partialVectorMultis[i].start();
rangeStart = rangeEnd + 1;
}

for (int i=0; i<numOfThreads; i++) {
try {
partialVectorMultis[i].join();
} catch (InterruptedException e) {
e.printStackTrace();
}
}

for (int i=0; i<numOfThreads; i++) {
result.addAll(partialVectorMultis[i].getPartialResult());
}

return result;
}

public static void main(String[] args) {
int matrix_size = 10_000;

List<List<Double>> matrix = initializeMatrix(matrix_size);
List<Double> vector = initializeVector(matrix_size);
List<Double> result;
List<Double> resultMulti;

MatrixMultiplication matrixMultiplication = new MatrixMultiplication();

long startTime = System.currentTimeMillis();
result = matrixMultiplication.matrixVectorMulti(matrix, vector);
long endTime = System.currentTimeMillis();
System.out.println("matrixVectorMulti: " + (endTime-startTime) + "milli seconds");

startTime = System.currentTimeMillis();
resultMulti = matrixMultiplication.matrixVectorMultiParallel(matrix, vector);
endTime = System.currentTimeMillis();
System.out.println("matrixVectorMultiParallel: " + (endTime-startTime) + "milli seconds");

}

class PartialVectorMulti extends Thread {
List<List<Double>> matrix;
List<Double> vector;
int rowStart;
int rowEnd;
List<Double> partialResult = new ArrayList<>();

public PartialVectorMulti(List<List<Double>> matrix, int rowStart, int rowEnd, List<Double> vector) {
this.matrix = matrix;
this.rowStart = rowStart;
this.rowEnd = rowEnd;
this.vector = vector;
}

public List<Double> getPartialResult() {
return this.partialResult;
}

@Override
public void run() {
for (int i=rowStart; i<=rowEnd; i++) {
double sum = 0;
for (int j=0; j<vector.size(); j++) {
sum += matrix.get(i).get(j) * vector.get(j);
}
partialResult.add(sum);
}
}
}
}

关于java - 可以使用多线程来使这个矩阵 vector 乘法算法更加高效吗?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58089171/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com