gpt4 book ai didi

cuda - CUBLAS 矩阵乘法

转载 作者:行者123 更新时间:2023-12-01 16:13:56 30 4
gpt4 key购买 nike

使用 CUDA 实现矩阵乘法后。我尝试用CUBLAS实现它(感谢论坛中一些人的建议)。

我可以对方阵做乘法,但是(是的,又来了……)在处理非方阵时遇到了困难。唯一能得到正确结果的非方阵情形,是只改变矩阵 A 的宽度(A*B=C)时。

我没有收到任何错误,但结果矩阵返回错误的值。这是我的代码(它基本上是 simpleCUBLAS SDK 示例的改编):

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "cublas.h"
/* Matrix dimensions: H* is the height (row count), W* the width (column count). */
#define HA 2
#define WA 9
#define WB 2
/* B must have as many rows as A has columns for the product A*B to be defined. */
#define HB WA
/* C = A*B takes its column count from B and its row count from A. */
#define WC WB
#define HC HA
/* Linear offset of element (row i, column j) in column-major storage whose
   leading dimension (distance between consecutive columns) is ld. */
#define index(i,j,ld) (((j)*(ld))+(i))

/*
 * Print a matrix to stdout, one matrix row per output line.
 *
 * P   - matrix storage; element (row, col) is read from P[index(row, col, uHP)],
 *       i.e. column-major with the row count as leading dimension
 * uWP - width of the matrix (number of columns)
 * uHP - height of the matrix (number of rows)
 *
 * A newline is written before each row; nothing is written after the last one.
 */
void printMat(float*P,int uWP,int uHP){
    int row = 0;

    while (row < uHP) {
        int col;

        printf("\n");
        for (col = 0; col < uWP; ++col) {
            printf("%f ", P[index(row, col, uHP)]);
        }
        ++row;
    }
}




int main (int argc, char** argv) {
cublasStatus status;
int i,j;
cublasInit();

float *A = (float*)malloc(HA*WA*sizeof(float));
float *B = (float*)malloc(HB*WB*sizeof(float));
float *C = (float*)malloc(HC*WC*sizeof(float));
if (A == 0) {
fprintf (stderr, "!!!! host memory allocation error (A)\n");
return EXIT_FAILURE;
}
if (B == 0) {
fprintf (stderr, "!!!! host memory allocation error (A)\n");
return EXIT_FAILURE;
}
if (C == 0) {
fprintf (stderr, "!!!! host memory allocation error (A)\n");
return EXIT_FAILURE;
}


for (i=0;i<HA;i++)
for (j=0;j<WA;j++)
A[index(i,j,HA)] = (float) index(i,j,HA);
for (i=0;i<HB;i++)
for (j=0;j<WB;j++)
B[index(i,j,HB)] = (float) index(i,j,HB);
/*
for (i=0;i<HA*WA;i++)
A[i]=(float) i;
for (i=0;i<HB*WB;i++)
B[i]=(float) i; */


float* AA; float* BB; float* CC;

/*ALLOCATE ON THE DEVICE*/
status=cublasAlloc(HA*WA,sizeof(float),(void**)&AA);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf (stderr, "!!!! device memory allocation error (A)\n");
return EXIT_FAILURE;
}

status=cublasAlloc(HB*WB,sizeof(float),(void**)&BB);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf (stderr, "!!!! device memory allocation error (A)\n");
return EXIT_FAILURE;
}

status=cublasAlloc(HC*WC,sizeof(float),(void**)&CC);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf (stderr, "!!!! device memory allocation error (A)\n");
return EXIT_FAILURE;
}

/*SET MATRIX*/
status=cublasSetMatrix(HA,WA,sizeof(float),A,HA,AA,HA);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf (stderr, "!!!! device memory allocation error (A)\n");
return EXIT_FAILURE;
}

status=cublasSetMatrix(HB,WB,sizeof(float),B,HB,BB,HB);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf (stderr, "!!!! device memory allocation error (A)\n");
return EXIT_FAILURE;
}

/*KERNEL*/
cublasSgemm('n','n',HA,WB,WA,1,AA,HA,BB,HB,0,CC,HC);

status = cublasGetError();
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf (stderr, "!!!! kernel execution error.\n");
return EXIT_FAILURE;
}
cublasGetMatrix(HC,WC,sizeof(float),CC,HC,C,HC);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf (stderr, "!!!! device read error (A)\n");
return EXIT_FAILURE;
}


/* PERFORMANCE OUTPUT*/

printf("\nMatriz A:\n");
printMat(A,WA,HA);
printf("\nMatriz B:\n");
printMat(B,WB,HB);
printf("\nMatriz C:\n");
printMat(C,WC,HC);

free( A ); free( B ); free ( C );
status = cublasFree(AA);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf (stderr, "!!!! memory free error (A)\n");
return EXIT_FAILURE;
}
status = cublasFree(BB);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf (stderr, "!!!! memory free error (B)\n");
return EXIT_FAILURE;
}
status = cublasFree(CC);
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf (stderr, "!!!! memory free error (C)\n");
return EXIT_FAILURE;
}

/* Shutdown */
status = cublasShutdown();
if (status != CUBLAS_STATUS_SUCCESS) {
fprintf (stderr, "!!!! shutdown error (A)\n");
return EXIT_FAILURE;
}

if (argc > 1) {
if (!strcmp(argv[1], "-noprompt") ||!strcmp(argv[1], "-qatest") )
{
return EXIT_SUCCESS;
}
}
else
{
printf("\nPress ENTER to exit...\n");
getchar();
}

return EXIT_SUCCESS;


}

有什么想法吗?另外,有没有人在 CUBLAS 中有一个可以工作的矩阵乘法实现,所以我可以比较?提前致谢。

最佳答案

我不明白为什么您认为您发布的代码不起作用。当我编译并运行它时,生成的可执行文件产生的输出与我在 matlab 中输入相同的矩阵并计算它们的乘积时得到的输出相同。

CUBLAS 沿用 FORTRAN BLAS 的约定,要求输入矩阵按列主序(column-major)存储(您的代码正是按列主序编写的)。如果结果与您的预期不符,那么您一定是在某处混淆了列主序和行主序。

关于cuda - CUBLAS 矩阵乘法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/5551020/

30 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com