
java - Batched use of cublasSgemm with JCuda

Reposted. Author: 太空宇宙. Updated: 2023-11-04 08:10:21

I have been trying to use the cublasSgemmBatched() function in JCuda for matrix multiplication, but I am not sure how to correctly handle the pointer passing and the arrays of batched matrices. I would greatly appreciate it if anyone could tell me how to modify my code to handle this correctly. In the example below, the C array remains unchanged after cublasGetVector.

public static void SsmmBatchJCublas(int m, int n, int k, float A[], float B[]) {

    // Create a CUBLAS handle
    cublasHandle handle = new cublasHandle();
    cublasCreate(handle);

    // Allocate memory on the device
    Pointer d_A = new Pointer();
    Pointer d_B = new Pointer();
    Pointer d_C = new Pointer();

    cudaMalloc(d_A, m * k * Sizeof.FLOAT);
    cudaMalloc(d_B, n * k * Sizeof.FLOAT);
    cudaMalloc(d_C, m * n * Sizeof.FLOAT);

    float[] C = new float[m * n];

    // Copy the memory from the host to the device
    cublasSetVector(m * k, Sizeof.FLOAT, Pointer.to(A), 1, d_A, 1);
    cublasSetVector(n * k, Sizeof.FLOAT, Pointer.to(B), 1, d_B, 1);
    cublasSetVector(m * n, Sizeof.FLOAT, Pointer.to(C), 1, d_C, 1);

    Pointer[] Aarray = new Pointer[]{ d_A };
    Pointer AarrayPtr = Pointer.to(Aarray);
    Pointer[] Barray = new Pointer[]{ d_B };
    Pointer BarrayPtr = Pointer.to(Barray);
    Pointer[] Carray = new Pointer[]{ d_C };
    Pointer CarrayPtr = Pointer.to(Carray);

    // Execute sgemm
    Pointer pAlpha = Pointer.to(new float[]{ 1 });
    Pointer pBeta = Pointer.to(new float[]{ 0 });

    cublasSgemmBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
        pAlpha, AarrayPtr, Aarray.length,
        BarrayPtr, Barray.length,
        pBeta, CarrayPtr, Carray.length,
        Aarray.length);

    // Copy the result from the device to the host
    cublasGetVector(m * n, Sizeof.FLOAT, d_C, 1, Pointer.to(C), 1);

    // Clean up
    cudaFree(d_A);
    cudaFree(d_B);
    cudaFree(d_C);
    cublasDestroy(handle);
}

Best answer

I asked on the official jcuda forum and quickly got an answer here.
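For readers who cannot follow the link: the likely cause of the unchanged C array is that cublasSgemmBatched expects the arrays of matrix pointers to reside in device memory, whereas the code above passes host-side pointer arrays (AarrayPtr etc.). In addition, the lda/ldb/ldc arguments are leading dimensions of the matrices, not batch sizes, and the final argument is the batch count. A minimal, untested sketch of the corrected pattern, assuming the standard JCuda runtime and JCublas2 API (cudaMalloc, cudaMemcpy, Sizeof.POINTER):

    // Copy the host array of device pointers into device memory.
    Pointer[] Aarray = new Pointer[]{ d_A };
    Pointer d_Aarray = new Pointer();
    cudaMalloc(d_Aarray, Aarray.length * Sizeof.POINTER);
    cudaMemcpy(d_Aarray, Pointer.to(Aarray),
        Aarray.length * Sizeof.POINTER, cudaMemcpyHostToDevice);
    // ... build d_Barray and d_Carray the same way ...

    // Leading dimensions for column-major m x k, k x n, m x n matrices,
    // and the batch count as the last argument.
    cublasSgemmBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
        pAlpha, d_Aarray, m,    // lda = m, not Aarray.length
        d_Barray, k,            // ldb = k
        pBeta, d_Carray, m,     // ldc = m
        Aarray.length);         // batchCount

The device-side pointer arrays (d_Aarray and friends) also need to be freed with cudaFree when done.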

Regarding "java - Batched use of cublasSgemm with JCuda", we found a similar question on Stack Overflow: https://stackoverflow.com/questions/11332327/
