gpt4 book ai didi

matlab/octave - 广义矩阵乘法

转载 作者:太空宇宙 更新时间:2023-11-03 19:17:01 25 4
gpt4 key购买 nike

我想做一个函数来概括矩阵乘法。基本上,它应该能够执行标准矩阵乘法,但它应该允许通过任何其他函数更改两个二元运算符的积/和。

目标是尽可能提高 CPU 和内存的效率。当然,它总是比 A*B 效率低,但运算符(operator)的灵 active 是这里的重点。

以下是我在阅读 various 后可能想到的一些命令interesting threads :

A = randi(10, 2, 3);
B = randi(10, 3, 4);

% 1st method
C = sum(bsxfun(@mtimes, permute(A,[1 3 2]),permute(B,[3 2 1])), 3)
% Alternative: C = bsxfun(@(a,b) mtimes(a',b), A', permute(B, [1 3 2]))

% 2nd method
C = sum(bsxfun(@(a,b) a*b, permute(A,[1 3 2]),permute(B,[3 2 1])), 3)

% 3rd method (Octave-only)
C = sum(permute(A, [1 3 2]) .* permute(B, [3 2 1]), 3)

% 4th method (Octave-only): multiply nxm A with nx1xd B to create a nxmxd array
C = bsxfun(@(a, b) sum(times(a,b)), A', permute(B, [1 3 2]));
C = C2 = squeeze(C(1,:,:)); % sum and turn into mxd

方法 1-3 的问题在于它们会在使用 sum() 折叠它们之前生成 n 个矩阵。 4 更好,因为它在 bsxfun 内部执行了 sum(),但是 bsxfun 仍然生成 n 个矩阵(除了它们大部分是空的,只包含一个非零值向量作为和,其余填充 0 以匹配尺寸要求)。

我想要的是类似于第 4 种方法但没有无用的 0 来节省内存。

有什么想法吗?

最佳答案

这是您发布的解决方案的稍微完善的版本,有一些小的改进。

我们检查我们的行数是否多于列数,或者相反,然后通过选择将行与矩阵相乘或矩阵与列相乘来相应地进行乘法运算(从而使循环迭代次数最少)。

A*B

注意:这可能并不总是最好的策略(按行而不是按列),即使行数少于列数;事实上,MATLAB 数组存储在 column-major order 中在内存中使得按列切片更有效,因为元素是连续存储的。而访问行涉及通过 strides 遍历元素(缓存不友好——想一想 spatial locality)。

除此之外,代码应该处理双/单、实数/复数、全数/稀疏(以及不可能组合的错误)。它还尊重空矩阵和零维。

function C = my_mtimes(A, B, outFcn, inFcn)
% default arguments
if nargin < 4, inFcn = @times; end
if nargin < 3, outFcn = @sum; end

% check valid input
assert(ismatrix(A) && ismatrix(B), 'Inputs must be 2D matrices.');
assert(isequal(size(A,2),size(B,1)),'Inner matrix dimensions must agree.');
assert(isa(inFcn,'function_handle') && isa(outFcn,'function_handle'), ...
'Expecting function handles.')

% preallocate output matrix
M = size(A,1);
N = size(B,2);
if issparse(A)
args = {'like',A};
elseif issparse(B)
args = {'like',B};
else
args = {superiorfloat(A,B)};
end
C = zeros(M,N, args{:});

% compute matrix multiplication
% http://en.wikipedia.org/wiki/Matrix_multiplication#Inner_product
if M < N
% concatenation of products of row vectors with matrices
% A*B = [a_1*B ; a_2*B ; ... ; a_m*B]
for m=1:M
%C(m,:) = A(m,:) * B;
%C(m,:) = sum(bsxfun(@times, A(m,:)', B), 1);
C(m,:) = outFcn(bsxfun(inFcn, A(m,:)', B), 1);
end
else
% concatenation of products of matrices with column vectors
% A*B = [A*b_1 , A*b_2 , ... , A*b_n]
for n=1:N
%C(:,n) = A * B(:,n);
%C(:,n) = sum(bsxfun(@times, A, B(:,n)'), 2);
C(:,n) = outFcn(bsxfun(inFcn, A, B(:,n)'), 2);
end
end
end

比较

毫无疑问,该函数自始至终都比较慢,但对于较大的尺寸,它比内置矩阵乘法差几个数量级:

        (tic/toc times in seconds)
(tested in R2014a on Windows 8)

size mtimes my_mtimes
____ __________ _________
400 0.0026398 0.20282
600 0.012039 0.68471
800 0.014571 1.6922
1000 0.026645 3.5107
2000 0.20204 28.76
4000 1.5578 221.51

mtimes_vs_mymtimes

测试代码如下:

sz = [10:10:100 200:200:1000 2000 4000];
t = zeros(numel(sz),2);
for i=1:numel(sz)
n = sz(i); disp(n)
A = rand(n,n);
B = rand(n,n);

tic
C = A*B;
t(i,1) = toc;
tic
D = my_mtimes(A,B);
t(i,2) = toc;

assert(norm(C-D) < 1e-6)
clear A B C D
end

semilogy(sz, t*1000, '.-')
legend({'mtimes','my_mtimes'}, 'Interpreter','none', 'Location','NorthWest')
xlabel('Size N'), ylabel('Time [msec]'), title('Matrix Multiplication')
axis tight

额外

为了完整起见,下面是实现广义矩阵乘法的两种更简单的方法(如果您想比较性能,请将 my_mtimes 函数的最后一部分替换为其中任何一种)。我什至懒得张贴他们耗时:)

C = zeros(M,N, args{:});
for m=1:M
for n=1:N
%C(m,n) = A(m,:) * B(:,n);
%C(m,n) = sum(bsxfun(@times, A(m,:)', B(:,n)));
C(m,n) = outFcn(bsxfun(inFcn, A(m,:)', B(:,n)));
end
end

另一种方式(使用三重循环):

C = zeros(M,N, args{:});
P = size(A,2); % = size(B,1);
for m=1:M
for n=1:N
for p=1:P
%C(m,n) = C(m,n) + A(m,p)*B(p,n);
%C(m,n) = plus(C(m,n), times(A(m,p),B(p,n)));
C(m,n) = outFcn([C(m,n) inFcn(A(m,p),B(p,n))]);
end
end
end

接下来要尝试什么?

如果您想获得更多性能,您将不得不转向 C/C++ MEX 文件以减少解释 MATLAB 代码的开销。您仍然可以通过从 MEX 文件调用优化的 BLAS/LAPACK 例程来利用它们(有关示例,请参见 the second part of this post)。 MATLAB 附带 Intel MKL坦率地说,在英特尔处理器上进行线性代数计算时,您无法击败它。

其他人已经在 File Exchange 上提到了一些将通用矩阵例程实现为 MEX 文件的提交(参见 @natan 的回答)。如果您将它们链接到优化的 BLAS 库,它们会特别有效。

关于matlab/octave - 广义矩阵乘法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/24245225/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com