gpt4 book ai didi

java - Java 6 中的并行矩阵乘法

转载 作者:搜寻专家 更新时间:2023-10-31 20:06:27 24 4
gpt4 key购买 nike

昨天我问了一个关于 Java 7 中使用 fork/join 框架的并行矩阵乘法的问题 here .在 axtavt 的帮助下,我的示例程序开始运行了。现在我正在实现一个仅使用 Java 6 功能的等效程序。我遇到了与昨天相同的问题,尽管应用了 axtavt 给我的反馈(我认为)。我忽略了什么吗?代码:

package algorithms;

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

public class Java6MatrixMultiply implements Algorithm {

private static final int SIZE = 1024;
private static final int THRESHOLD = 64;
private static final int MAX_THREADS = Runtime.getRuntime().availableProcessors();

private final ExecutorService executor = Executors.newFixedThreadPool(MAX_THREADS);

private float[][] a = new float[SIZE][SIZE];
private float[][] b = new float[SIZE][SIZE];
private float[][] c = new float[SIZE][SIZE];

@Override
public void initialize() {
init(a, b, SIZE);
}

@Override
public void execute() {
MatrixMultiplyTask task = new MatrixMultiplyTask(a, 0, 0, b, 0, 0, c, 0, 0, SIZE);
task.split();

executor.shutdown();
try {
executor.awaitTermination(Integer.MAX_VALUE, TimeUnit.DAYS);
} catch (InterruptedException e) {
System.out.println("Error: " + e.getMessage());
}
}

@Override
public void printResult() {
check(c, SIZE);

for (int i = 0; i < SIZE && i <= 10; i++) {
for (int j = 0; j < SIZE && j <= 10; j++) {
if(j == 10) {
System.out.print("...");
}
else {
System.out.print(c[i][j] + " ");
}
}

if(i == 10) {
System.out.println();
for(int k = 0; k < 10; k++) System.out.print(" ... ");
}

System.out.println();
}

System.out.println();
}

// To simplify checking, fill with all 1's. Answer should be all n's.
static void init(float[][] a, float[][] b, int n) {
for (int i = 0; i < n; ++i) {
for (int j = 0; j < n; ++j) {
a[i][j] = 1.0F;
b[i][j] = 1.0F;
}
}
}

static void check(float[][] c, int n) {
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
if (c[i][j] != n) {
throw new Error("Check Failed at [" + i + "][" + j + "]: " + c[i][j]);
//System.out.println("Check Failed at [" + i + "][" + j + "]: " + c[i][j]);
}
}
}
}

public class Seq implements Runnable {

private final MatrixMultiplyTask a;
private final MatrixMultiplyTask b;

public Seq(MatrixMultiplyTask a, MatrixMultiplyTask b, int size) {
this.a = a;
this.b = b;

if (size <= THRESHOLD) {
executor.submit(this);
} else {
a.split();
b.split();
}
}

public void run() {
a.multiplyStride2();
b.multiplyStride2();
}
}

private class MatrixMultiplyTask {
private final float[][] A; // Matrix A
private final int aRow; // first row of current quadrant of A
private final int aCol; // first column of current quadrant of A

private final float[][] B; // Similarly for B
private final int bRow;
private final int bCol;

private final float[][] C; // Similarly for result matrix C
private final int cRow;
private final int cCol;

private final int size;

MatrixMultiplyTask(float[][] A, int aRow, int aCol, float[][] B,
int bRow, int bCol, float[][] C, int cRow, int cCol, int size) {

this.A = A;
this.aRow = aRow;
this.aCol = aCol;
this.B = B;
this.bRow = bRow;
this.bCol = bCol;
this.C = C;
this.cRow = cRow;
this.cCol = cCol;
this.size = size;
}

public void split() {
int h = size / 2;

new Seq(new MatrixMultiplyTask(A,
aRow, aCol, // A11
B, bRow, bCol, // B11
C, cRow, cCol, // C11
h),

new MatrixMultiplyTask(A, aRow, aCol + h, // A12
B, bRow + h, bCol, // B21
C, cRow, cCol, // C11
h), h);

new Seq(new MatrixMultiplyTask(A,
aRow, aCol, // A11
B, bRow, bCol + h, // B12
C, cRow, cCol + h, // C12
h),

new MatrixMultiplyTask(A, aRow, aCol + h, // A12
B, bRow + h, bCol + h, // B22
C, cRow, cCol + h, // C12
h), h);

new Seq(new MatrixMultiplyTask(A, aRow
+ h, aCol, // A21
B, bRow, bCol, // B11
C, cRow + h, cCol, // C21
h),

new MatrixMultiplyTask(A, aRow + h, aCol + h, // A22
B, bRow + h, bCol, // B21
C, cRow + h, cCol, // C21
h), h);

new Seq(new MatrixMultiplyTask(A, aRow
+ h, aCol, // A21
B, bRow, bCol + h, // B12
C, cRow + h, cCol + h, // C22
h),

new MatrixMultiplyTask(A, aRow + h, aCol + h, // A22
B, bRow + h, bCol + h, // B22
C, cRow + h, cCol + h, // C22
h), h);
}

public void multiplyStride2() {
for (int j = 0; j < size; j += 2) {
for (int i = 0; i < size; i += 2) {

float[] a0 = A[aRow + i];
float[] a1 = A[aRow + i + 1];

float s00 = 0.0F;
float s01 = 0.0F;
float s10 = 0.0F;
float s11 = 0.0F;

for (int k = 0; k < size; k += 2) {

float[] b0 = B[bRow + k];

s00 += a0[aCol + k] * b0[bCol + j];
s10 += a1[aCol + k] * b0[bCol + j];
s01 += a0[aCol + k] * b0[bCol + j + 1];
s11 += a1[aCol + k] * b0[bCol + j + 1];

float[] b1 = B[bRow + k + 1];

s00 += a0[aCol + k + 1] * b1[bCol + j];
s10 += a1[aCol + k + 1] * b1[bCol + j];
s01 += a0[aCol + k + 1] * b1[bCol + j + 1];
s11 += a1[aCol + k + 1] * b1[bCol + j + 1];
}

C[cRow + i][cCol + j] += s00;
C[cRow + i][cCol + j + 1] += s01;
C[cRow + i + 1][cCol + j] += s10;
C[cRow + i + 1][cCol + j + 1] += s11;
}
}
}
}
}

最佳答案

我尝试按照我的建议添加同步,这解决了问题。 ;)

我试过了

  • 每行同步 299 毫秒。
  • 交换 mutliplyStride 中的循环,使其按列而不是按行。 253 毫秒
  • 假定为每对行锁定一个锁(即我为两次更新锁定了一行。216 毫秒
  • 禁用偏向锁定-XX:-UseBiasedLocking 207 毫秒
  • 使用 2 倍于线程的处理器数量。 199 毫秒。
  • 相同,除了使用 double 而不是 float 237 毫秒。
  • 根本没有同步。 174 毫秒。

如您所见,第五个选项比没有同步慢不到 10%。如果您想获得更多 yield ,我建议您改变访问数据的方式,使它们对缓存更友好。

总结一下我建议

private final ExecutorService executor = Executors.newFixedThreadPool(MAX_THREADS*2);

public void multiplyStride2() {
for (int i = 0; i < size; i += 2) {
for (int j = 0; j < size; j += 2) {

// code as is......

synchronized (C[cRow + i]) {
C[cRow + i][cCol + j] += s00;
C[cRow + i][cCol + j + 1] += s01;

C[cRow + i + 1][cCol + j] += s10;
C[cRow + i + 1][cCol + j + 1] += s11;
}

有趣的是,如果我计算一个 2x4 block 而不是 2x2,平均时间会下降到 172 毫秒。 (比之前没有同步的结果更快);)

关于java - Java 6 中的并行矩阵乘法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/5484204/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com