作者热门文章
- html - 出于某种原因,IE8 对我的 Sass 文件中继承的 html5 CSS 不友好?
- JMeter 在响应断言中使用 span 标签的问题
- html - 在 :hover and :active? 上具有不同效果的 CSS 动画
- html - 相对于居中的 html 内容固定的 CSS 重复背景?
我尝试使用 cblas.h(来自 openblas 库)来计算两个矩阵的乘积。更具体地,I具有维度为n*d的双数组A、维度为m*d的数组B和维度为n*m的数组C。我想计算乘积 A 'time' B 转置。
我的代码是
#include <cblas.h>
#include <stdio.h>
#include <stdlib.h>
void random_matrix(double *X, int rows, int cols);
void print_matrix(double *X, int rows, int cols);
int main(int argc, char** argv)
{
int n = atoi(argv[1]),
m = atoi(argv[2]),
d = atoi(argv[3]);
double *A, *B, *C;
A = malloc(n*d*sizeof(double));
B = malloc(m*d*sizeof(double));
C = malloc(n*m*sizeof(double));
random_matrix(A,n,d);
print_matrix(A,n,d);
random_matrix(B,m,d);
print_matrix(B,m,d);
cblas_dgemm(CblasRowMajor,
CblasNoTrans, CblasTrans, n,m,d,
1.0, A, n, B, m,
0.0, C, n
);
print_matrix(C,n,m);
return 0;
}
void random_matrix(double *X, int rows, int cols){
for(int i = 0; i < rows; i++)
for(int j = 0; j < cols; j++)
X[i*cols+j] = (double)rand() / RAND_MAX + (double)(rand()%10);
}
void print_matrix(double *X, int rows, int cols){
for(int i = 0; i < rows; i++) {
for(int j = 0; j < cols; j++) {
printf("%g ", X[i*cols+j]);
}printf(";\n");
}printf("\n\n");
}
当我以 n = 6、m = 5 和 d = 2 运行程序时,输出是:
9.00001 8.75561 ;
2.53277 8.04704 ;
9.6793 5.3835 ;
2.83097 3.05346 ;
9.67115 2.38342 ;
9.41749 7.58898 ;
3.84617 8.09196 ;
5.416 6.91032 ;
7.26245 3.73608 ;
9.63264 1.99104 ;
2.24704 6.72266 ;
105.466 117.964 4.58702e-309 4.70574e-309 0 ;
-1.83255e-06 35.5969 39.9896 1.59969e-309 1.4802e-309 ;
0 -1.83255e-06 0 0 0 ;
0 0 -1.83255e-06 0 0 ;
0 0 0 -1.83255e-06 0 ;
0 0 0 0 -1.83255e-06 ;
这是错误的,因为当我在 Octave 音阶上尝试时,我得到:
octave:52> A = [9.00001 8.75561 ;
> 2.53277 8.04704 ;
> 9.6793 5.3835 ;
> 2.83097 3.05346 ;
> 9.67115 2.38342 ;
> 9.41749 7.58898 ;];
octave:53> B = [3.84617 8.09196 ;
> 5.416 6.91032 ;
> 7.26245 3.73608 ;
> 9.63264 1.99104 ;
> 2.24704 6.72266 ;];
octave:54> A*B'
ans =
105.466 109.248 98.074 104.127 79.084
74.858 69.325 48.459 40.419 59.789
80.791 89.625 90.409 103.956 57.941
35.597 36.433 31.968 33.349 26.889
56.483 68.849 79.141 97.904 37.754
97.631 103.447 96.747 105.825 72.180
最佳答案
基本上,它与主维度 lda
、ldb
和 ldc
有关。矩阵 A 存储为数组 A = malloc(n*d*sizeof(double))
。当您使用列优先访问时,元素 A_ij 为 A[j*lda + i]
。相反,当您使用行优先时,A_ij 是 A[i*lda + j]
因此,在行优先中,lda
是 A
的列数。分别针对 B
和 C
。
总而言之,您必须写:
cblas_dgemm(CblasRowMajor,
CblasNoTrans, CblasTrans, n,m,d,
1.0, A, d, B, d,
0.0, C, m
);
关于不明白 CBLAS 是如何工作的,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58902333/
我是一名优秀的程序员,十分优秀!