gpt4 book ai didi

Java实现矩阵乘法以及优化的方法实例

转载 作者:qq735679552 更新时间:2022-09-29 22:32:09 27 4
gpt4 key购买 nike

CFSDN坚持开源创造价值,我们致力于搭建一个资源共享平台,让每一个IT人在这里找到属于你的精彩世界.

这篇CFSDN的博客文章Java实现矩阵乘法以及优化的方法实例由作者收集整理,如果你对这篇文章有兴趣,记得点赞哟.

传统的矩阵乘法实现 。

  首先,两个矩阵能够相乘,必须满足一个前提:前一个矩阵的行数等于后一个矩阵的列数.

  第一个矩阵的第m行和第二个矩阵的第n列的乘积和即为乘积矩阵第m行第n列的值,可用如下图像表示这个过程.

Java实现矩阵乘法以及优化的方法实例

矩阵乘法过程展示 。

c[1][1] = a[1][0] * b[0][1] + a[1][1] * b[1][1] + a[1][2] * b[2][1] + a[1][3] * b[3][1] + a[1][4] * b[4][1] 。

  而用java实现该过程的传统方法就是按照该规则实现一个三重循环,把各项乘积累加:

?
1
2
3
4
5
6
7
8
9
10
11
12
public int [][] multiply( int [][] mat1, int [][] mat2){
     int m = mat1.length, n = mat2[ 0 ].length;
     int [][] mat = new int [m][n];
     for ( int i = 0 ; i < m; i++){
         for ( int j = 0 ; j < n; j++){
             for ( int k = 0 ; k < mat1[ 0 ].length; k++){
                 mat[i][j] += mat1[i][k] * mat2[k][j];
             }
         }
     }
     return mat;
}

  可以看出该方法的时间复杂度为o(n3),当矩阵维数比较大的时候程序就很容易超时.

优化方法(strassen算法) 。

  strassen算法是由volker strassen在1966年提出的第一个时间复杂度低于o(n³)的矩阵乘法算法,其主要思想是通过分治来实现矩阵乘法的快速运算,计算过程如图所示:

Java实现矩阵乘法以及优化的方法实例

将一次矩阵乘法拆分成多个乘法与加法的结合 。

  为什么这个方法会更快呢,我们知道,按照传统的矩阵乘法:

c11 = a11 * b11 + a12 * b21 c12 = a11 * b12 + a12 * b22 c21 = a21 * b11 + a22 * b21 c22 = a21 * b12 + a22 * b22 。

  我们需要8次矩阵乘法和4次矩阵加法,正是这8次乘法最耗时;而strassen方法只需要7次矩阵乘法,尽管代价是矩阵加法次数变为18次,但是基于数量级考虑,18次加法仍然快于1次乘法.

  当然,strassen算法的代码实现也比传统算法复杂许多,这里附上另一个大神写的java实现(原文链接:):

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
public class matrix {
     private final matrix[] _matrixarray;
     private final int n;
     private int element;
     public matrix( int n) {
         this .n = n;
         if (n != 1 ) {
             this ._matrixarray = new matrix[ 4 ];
             for ( int i = 0 ; i < 4 ; i++) {
                 this ._matrixarray[i] = new matrix(n / 2 );
             }
         } else {
             this ._matrixarray = null ;
         }
     }
     private matrix( int n, boolean needinit) {
         this .n = n;
         if (n != 1 ) {
             this ._matrixarray = new matrix[ 4 ];
         } else {
             this ._matrixarray = null ;
         }
     }
     public void set( int i, int j, int a) {
         if (n == 1 ) {
             element = a;
         } else {
             int size = n / 2 ;
             this ._matrixarray[(i / size) * 2 + (j / size)].set(i % size, j % size, a);
         }
     }
     public matrix multi(matrix m) {
         matrix result = null ;
         if (n == 1 ) {
             result = new matrix( 1 );
             result.set( 0 , 0 , (element * m.element));
         } else {
             result = new matrix(n, false );
             result._matrixarray[ 0 ] = p5(m).add(p4(m)).minus(p2(m)).add(p6(m));
             result._matrixarray[ 1 ] = p1(m).add(p2(m));
             result._matrixarray[ 2 ] = p3(m).add(p4(m));
             result._matrixarray[ 3 ] = p5(m).add(p1(m)).minus(p3(m)).minus(p7(m));
         }
         return result;
     }
     public matrix add(matrix m) {
         matrix result = null ;
         if (n == 1 ) {
             result = new matrix( 1 );
             result.set( 0 , 0 , (element + m.element));
         } else {
             result = new matrix(n, false );
             result._matrixarray[ 0 ] = this ._matrixarray[ 0 ].add(m._matrixarray[ 0 ]);
             result._matrixarray[ 1 ] = this ._matrixarray[ 1 ].add(m._matrixarray[ 1 ]);
             result._matrixarray[ 2 ] = this ._matrixarray[ 2 ].add(m._matrixarray[ 2 ]);
             result._matrixarray[ 3 ] = this ._matrixarray[ 3 ].add(m._matrixarray[ 3 ]);;
         }
         return result;
     }
     public matrix minus(matrix m) {
         matrix result = null ;
         if (n == 1 ) {
             result = new matrix( 1 );
             result.set( 0 , 0 , (element - m.element));
         } else {
             result = new matrix(n, false );
             result._matrixarray[ 0 ] = this ._matrixarray[ 0 ].minus(m._matrixarray[ 0 ]);
             result._matrixarray[ 1 ] = this ._matrixarray[ 1 ].minus(m._matrixarray[ 1 ]);
             result._matrixarray[ 2 ] = this ._matrixarray[ 2 ].minus(m._matrixarray[ 2 ]);
             result._matrixarray[ 3 ] = this ._matrixarray[ 3 ].minus(m._matrixarray[ 3 ]);;
         }
         return result;
     }
     protected matrix p1(matrix m) {
         return _matrixarray[ 0 ].multi(m._matrixarray[ 1 ]).minus(_matrixarray[ 0 ].multi(m._matrixarray[ 3 ]));
     }
     protected matrix p2(matrix m) {
         return _matrixarray[ 0 ].multi(m._matrixarray[ 3 ]).add(_matrixarray[ 1 ].multi(m._matrixarray[ 3 ]));
     }
     protected matrix p3(matrix m) {
         return _matrixarray[ 2 ].multi(m._matrixarray[ 0 ]).add(_matrixarray[ 3 ].multi(m._matrixarray[ 0 ]));
     }
     protected matrix p4(matrix m) {
         return _matrixarray[ 3 ].multi(m._matrixarray[ 2 ]).minus(_matrixarray[ 3 ].multi(m._matrixarray[ 0 ]));
     }
     protected matrix p5(matrix m) {
         return (_matrixarray[ 0 ].add(_matrixarray[ 3 ])).multi(m._matrixarray[ 0 ].add(m._matrixarray[ 3 ]));
     }
     protected matrix p6(matrix m) {
         return (_matrixarray[ 1 ].minus(_matrixarray[ 3 ])).multi(m._matrixarray[ 2 ].add(m._matrixarray[ 3 ]));
     }
     protected matrix p7(matrix m) {
         return (_matrixarray[ 0 ].minus(_matrixarray[ 2 ])).multi(m._matrixarray[ 0 ].add(m._matrixarray[ 1 ]));
     }
     public int get( int i, int j) {
         if (n == 1 ) {
             return element;
         } else {
             int size = n / 2 ;
             return this ._matrixarray[(i / size) * 2 + (j / size)].get(i % size, j % size);
         }
     }
     public void display() {
         for ( int i = 0 ; i < n; i++) {
             for ( int j = 0 ; j < n; j++) {
                 system.out.print(get(i, j));
                 system.out.print( " " );
             }
             system.out.println();
         }
     }
    
     public static void main(string[] args) {
         matrix m = new matrix( 2 );
         matrix n = new matrix( 2 );
         m.set( 0 , 0 , 1 );
         m.set( 0 , 1 , 3 );
         m.set( 1 , 0 , 5 );
         m.set( 1 , 1 , 7 );
         n.set( 0 , 0 , 8 );
         n.set( 0 , 1 , 4 );
         n.set( 1 , 0 , 6 );
         n.set( 1 , 1 , 2 );
         matrix res = m.multi(n);
         res.display();
     }
}

总结 。

到此这篇关于java实现矩阵乘法以及优化的文章就介绍到这了,更多相关java矩阵乘法及优化内容请搜索我以前的文章或继续浏览下面的相关文章希望大家以后多多支持我! 。

原文链接:https://blog.csdn.net/GGG_Yu/article/details/109693318 。

最后此篇关于Java实现矩阵乘法以及优化的方法实例的文章就讲到这里了,如果你想了解更多关于Java实现矩阵乘法以及优化的方法实例的内容请搜索CFSDN的文章或继续浏览相关文章,希望大家以后支持我的博客! 。

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com