
matlab - Vectorization and nested matrix multiplication

Reposted. Author: 太空宇宙. Updated: 2023-11-03 20:13:48

The original code is as follows:

K = zeros(N*N);
for a = 1:N
    for i = 1:I
        for j = 1:J
            M = kron(X(:,:,a).', Y(:,:,a,i,j));

            % A function that essentially adds M to K.
        end
    end
end

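The goal is to vectorize the Kronecker product (kron) calls. My intuition is to think of X and Y as containers of matrices (for reference, the slices of X and Y fed to kron are square 7x7 matrices). Under this container scheme, X appears as a 1-D container and Y as a 3-D container. My next guess is to reshape Y into a 2-D container, or better yet a 1-D container, and then do element-wise multiplication of X and Y. The questions are: how can this reshaping be done in a way that preserves M, and can MATLAB even handle this container idea, or do the containers need to be reshaped further to expose the inner matrix elements?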

Best answer

Approach #1: Matrix multiplication with 6D permute

% Get sizes (I and J are the sizes of the 4th and 5th dims of Y)
[m1,m2,~] = size(X);
[n1,n2,N,I,J] = size(Y);

% Lose the third dim from X and Y with matrix-multiplication
parte1 = reshape(permute(Y,[1,2,4,5,3]),[],N)*reshape(X,[],N).';

% Rearrange the leftover dims to bring kron format
parte2 = reshape(parte1,[n1,n2,I,J,m1,m2]);

% Lose dims corresponding to the last two dims coming in from Y, i.e. the
% iterative summation over i and j as suggested in the question. The output
% is (m2*n1)-by-(m1*n2), which equals (m1*n1)-by-(m2*n2) for square X slices
out = reshape(permute(sum(sum(parte2,3),4),[1,6,2,5,3,4]),m2*n1,m1*n2)
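A possible variant of this approach, not part of the original answer: since kron is linear in its second argument, kron(A,B1) + kron(A,B2) equals kron(A,B1+B2), so the sums over i and j can be taken on Y before the multiplication. That keeps the intermediate array small. Below is a minimal sketch under that assumption; the sizes, the names Ysum, P and max_dev, and the non-square slices are illustrative choices of mine.

```matlab
% Sketch: collapse Y over its 4th and 5th dims first, then multiply.
% Sizes are illustrative assumptions, not taken from the original post.
X = rand(3,4,5);        % m1 x m2 x N (deliberately non-square slices)
Y = rand(2,3,5,4,6);    % n1 x n2 x N x I x J
[m1,m2,N] = size(X);
[n1,n2,~,I,J] = size(Y);

% kron(A,B1) + kron(A,B2) == kron(A,B1+B2), so sum Y over i and j first
Ysum = sum(sum(Y,4),5);                          % n1 x n2 x N

% Contract the shared third dim with one matrix multiplication
P = reshape(Ysum,[],N) * reshape(X,[],N).';      % (n1*n2) x (m1*m2)

% Interleave dims into the kron(X.',Y) layout: rows (n1,m2), cols (n2,m1)
out = reshape(permute(reshape(P,[n1,n2,m1,m2]),[1,4,2,3]), m2*n1, m1*n2);

% Reference: the triple loop from the question
K = zeros(m2*n1, m1*n2);
for a = 1:N
    for i = 1:I
        for j = 1:J
            K = K + kron(X(:,:,a).', Y(:,:,a,i,j));
        end
    end
end
max_dev = max(abs(K(:) - out(:)));   % expected to be at round-off level
```

Whether this is faster than the 6D version will depend on the sizes of I and J; I have not timed it.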

Approach #2: Simple 7D permute

% Get sizes
[m1,m2,~] = size(X);
[n1,n2,N,n4,n5] = size(Y);

% Perform kron format elementwise multiplication between the first two dims
% of X and Y, keeping the third dim aligned and "pushing out" leftover dims
% from Y to the back
mults = bsxfun(@times,permute(X,[4,2,5,1,3]),permute(Y,[1,6,2,7,3,4,5]));

% Lose the remaining dims with summation reduction for final output. The
% output is (m2*n1)-by-(m1*n2), same as (m1*n1)-by-(m2*n2) for square X slices
out = sum(reshape(mults,m2*n1,m1*n2,[]),3);
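To see why the permute pattern above produces the kron layout, it may help to look at a single pair of matrices. The sketch below is my own illustration (the names A, B, A4, B4, P, M and same are hypothetical, and the values are chosen to be small integers so the comparison is exact). It places the dims of A in slots 4 and 2 and the dims of B in slots 1 and 3, just as the answer does for X and Y.

```matlab
% Small integer-valued matrices so the comparison is exact (illustrative)
A = reshape(1:6, 2, 3);     % m1 x m2 = 2 x 3
B = [1 2; 3 4; 5 6];        % n1 x n2 = 3 x 2
[m1,m2] = size(A);
[n1,n2] = size(B);

% Put A's dims in slots 4 and 2, B's dims in slots 1 and 3
A4 = permute(A, [3,2,4,1]);     % 1 x m2 x 1 x m1
B4 = permute(B, [1,3,2]);       % n1 x 1 x n2

% Broadcasted product has size n1 x m2 x n2 x m1
P = bsxfun(@times, A4, B4);

% Merging dims (1,2) and (3,4) gives the Kronecker block layout
M = reshape(P, m2*n1, m1*n2);

% Compare with the built-in Kronecker product of A.' and B
same = isequal(M, kron(A.', B));
```

In MATLAB R2016b and later, implicit expansion should allow `A4 .* B4` in place of the bsxfun call.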

Verification

Here is a setup that runs the original and the proposed approaches:

% Setup inputs
X = rand(10,10,10);
Y = rand(10,10,10,10,10);

% Original approach
[n1,n2,N,I,J] = size(Y);
K = zeros(100);
for a = 1:N
    for i = 1:I
        for j = 1:J
            M = kron(X(:,:,a).', Y(:,:,a,i,j));
            K = K + M;
        end
    end
end

% out1: Approach #2 above (bsxfun with 7D permute)
[m1,m2,~] = size(X);
[n1,n2,N,I,J] = size(Y);
mults = bsxfun(@times,permute(X,[4,2,5,1,3]),permute(Y,[1,6,2,7,3,4,5]));
out1 = sum(reshape(mults,m2*n1,m1*n2,[]),3);

% out2: Approach #1 above (matrix multiplication with 6D permute)
[m1,m2,~] = size(X);
[n1,n2,N,I,J] = size(Y);
parte1 = reshape(permute(Y,[1,2,4,5,3]),[],N)*reshape(X,[],N).';
parte2 = reshape(parte1,[n1,n2,I,J,m1,m2]);
out2 = reshape(permute(sum(sum(parte2,3),4),[1,6,2,5,3,4]),m2*n1,m1*n2);

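After running, we see the maximum absolute deviation of each proposed approach from the original (out1 is the bsxfun version, out2 the matrix-multiplication version):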

>> error_app1 = max(abs(K(:)-out1(:)))
error_app1 =
1.1369e-12
>> error_app2 = max(abs(K(:)-out2(:)))
error_app2 =
1.1937e-12

The values look good to me!


Benchmarking

Timing the three approaches on the same large dataset used for verification gives these results:

----------------------------- With Loop
Elapsed time is 1.541443 seconds.
----------------------------- With BSXFUN
Elapsed time is 1.283935 seconds.
----------------------------- With MATRIX-MULTIPLICATION
Elapsed time is 0.164312 seconds.

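For datasets of this size, matrix multiplication does quite well: about 9x faster than the loop (0.164 s vs 1.541 s), whereas bsxfun is only about 1.2x faster.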
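The post does not show the timing code itself. The "Elapsed time is ..." lines look like tic/toc output, so a harness along these lines could reproduce the comparison. This is only a sketch under that assumption: the sizes are smaller than in the post, and the names t_loop, t_bsx, t_mm, out_bsx and out_mm are mine.

```matlab
% Timing sketch with tic/toc (sizes are illustrative, smaller than the post's)
X = rand(6,6,5);
Y = rand(6,6,5,4,3);
[m1,m2,N] = size(X);
[n1,n2,~,I,J] = size(Y);

% Loop version
tic
K = zeros(m2*n1, m1*n2);
for a = 1:N
    for i = 1:I
        for j = 1:J
            K = K + kron(X(:,:,a).', Y(:,:,a,i,j));
        end
    end
end
t_loop = toc;

% bsxfun version
tic
mults = bsxfun(@times, permute(X,[4,2,5,1,3]), permute(Y,[1,6,2,7,3,4,5]));
out_bsx = sum(reshape(mults, m2*n1, m1*n2, []), 3);
t_bsx = toc;

% Matrix-multiplication version
tic
parte1 = reshape(permute(Y,[1,2,4,5,3]),[],N) * reshape(X,[],N).';
parte2 = reshape(parte1, [n1,n2,I,J,m1,m2]);
out_mm = reshape(permute(sum(sum(parte2,3),4),[1,6,2,5,3,4]), m2*n1, m1*n2);
t_mm = toc;

fprintf('loop %.6f s, bsxfun %.6f s, matmul %.6f s\n', t_loop, t_bsx, t_mm);
```

A single tic/toc run is noisy, so repeating each block several times (or using timeit in MATLAB) would give steadier numbers. The ranking may also change with the array sizes.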

Regarding matlab - Vectorization and nested matrix multiplication, we found a similar question on Stack Overflow: https://stackoverflow.com/questions/38640734/
