gpt4 book ai didi

c++ - Armadillo:矩阵乘法占用大量内存

转载 作者:行者123 更新时间:2023-11-28 01:52:22 29 4
gpt4 key购买 nike

我正在尝试使用 armadillo 进行线性回归,如下面的函数所示:

void compute_weights()
{
printf("transpose\n");
const mat &xt(X.t());
printf("inverse\n");
mat xd;
printf("mul\n");
xd = (xt * X);
printf("inv\n");
xd = xd.i();
printf("mul2\n");
xd = xd * xt;
printf("mul3\n");
W = xd * Y;
}

我已将其拆分,以便我可以看到程序变得如此庞大时发生了什么。矩阵 X 有 64 列和超过 2300 万行。转置并不算太糟糕,但是第一次乘法会导致内存占用完全爆炸。现在,据我了解,如果我乘以 X.t() * X,矩阵乘积的每个元素将是 X 的一列和 X.t() 的一行的点积,结果应该是一个 64x64 矩阵。

当然要花很长时间,但是为什么内存会突然爆到将近30G呢?

然后它似乎卡在那个内存上,然后当它进行第二次乘法时,它太多了,操作系统会因为变得如此之大而杀死它。

有没有一种方法可以在不占用太多内存的情况下计算乘积?那段内存能回收吗?有没有更好的方法来表示这些计算?

最佳答案

除非您使用大型工作站,否则您不可能一次性完成整个乘法运算。正如 hbrerkere 所说,您的初始消耗量约为 22 GB。因此,您要么为此做好准备,要么另辟蹊径。

如果您没有这样的工作站,另一种方法是自己进行乘法运算,然后将其并行化。方法如下:

  1. 不要将整个矩阵加载到内存中,而是加载其中的一部分。
  2. 加载大约一百万行 X,并将其存储在某个地方。
  3. 加载 Y 的一百万列
  4. 使用std::transform使用二元运算符 std::multiplies乘以你加载的部分(这将利用你的处理器的矢量化,并使其快速),并填写你计算的部分结果。
  5. 加载矩阵的下一部分,然后重复

这不会那么有效,但它会起作用。另一种选择是在将矩阵分解为更小的矩阵后考虑使用 Armadillo ,其乘法将产生子结果。

这两种方法都比完全乘法慢得多,原因有两个:

  1. 从内存中加载和删除数据的开销
  2. 矩阵乘法已经是一个复杂度为 O(N^3) 的问题...现在拆分你的乘法是复杂度为 O(N^2),所以它会变成复杂度为 O(N^6)...

祝你好运!

关于c++ - Armadillo:矩阵乘法占用大量内存,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42501116/

29 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com