
julia - Fast computation of sum_i f(i) x(i) x(i)'?

Reposted. Author: 行者123. Updated: 2023-12-04 04:17:09

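I am trying to compute the sum of f(i) * x(i) * x(i)',
where x(i) is a column vector, x(i)' is its transpose, and f(i) is a scalar. In other words, this is a weighted sum of outer products.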
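Stacking the vectors x(i)' as the rows of an N×d matrix X (which is how x = randn(N, d) is laid out in the code below) turns the sum into a single matrix product. The names X and F = diag(f) are notation introduced here for illustration; f ⊙ X means "row i of X scaled by f(i)", which is exactly what bsxfun(@times, f, x) builds.

```latex
H = \sum_{i=1}^{N} f(i)\, x(i)\, x(i)^{\mathsf T}
  = X^{\mathsf T} F X
  = X^{\mathsf T} \bigl( f \odot X \bigr),
\qquad F = \operatorname{diag}\bigl( f(1), \dots, f(N) \bigr),
\qquad H_{jk} = \sum_{i=1}^{N} f(i)\, X_{ij}\, X_{ik}
```

The entrywise form on the right is the quantity the explicit triple loop (hess3) accumulates, with j = k1 and k = k2.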

In MATLAB this can be done very quickly with bsxfun.
The following code runs in 260 ms on a laptop (MacBook Air 2010):

N = 1e5;
d = 100;
f = randn(N, 1);
x = randn(N, d);
% H = zeros(d, d);

tic;
H = x' * bsxfun(@times, f, x);
toc
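For readers on a current Julia, a minimal sketch of the same bsxfun line in Julia 1.x syntax follows, checked against the literal sum of outer products. The smaller N and the fixed seed are our own choices for a quick correctness check, not a benchmark; `f .* x` is Julia's dot-broadcast spelling of the scaling that bsxfun(@times, f, x) performs.

```julia
using Random

Random.seed!(0)          # fixed seed so the comparison is reproducible
N, d = 2_000, 100        # smaller N than the question: a correctness check only
f = randn(N)             # weights f(i)
x = randn(N, d)          # row i of x holds x(i)'

# Julia 1.x analogue of MATLAB's  H = x' * bsxfun(@times, f, x):
# `f .* x` broadcasts the length-N vector f across the d columns of x,
# scaling row i of x by f(i).
H = x' * (f .* x)

# Reference: the weighted sum of outer products, written out literally.
Href = zeros(d, d)
for i in 1:N
    xi = x[i, :]                   # in Julia 1.x this is a length-d vector, i.e. x(i)
    Href .+= f[i] .* (xi * xi')    # add f(i) * x(i) * x(i)' in place
end

println(H ≈ Href)                  # prints true (the two differ only by rounding)
```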

I have been trying to get Julia to do the same thing, but I cannot make it any faster:

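N = int(1e5);   # Julia 0.3 syntax; Int(1e5) or 100_000 in Julia >= 1.0
d = 100;
f = randn(N);
x = randn(N, d);

function hess1(x, f)
    N, d = size(x);
    temp = zeros(N, d);
    # scale row kk of x by f[kk] (the bsxfun step)
    @simd for kk = 1:N
        @inbounds temp[kk, :] = f[kk] * x[kk, :];
    end
    H = x' * temp;   # last expression is the return value
end

function hess2(x, f)
    N, d = size(x);
    H2 = zeros(d,d);
    # one rank-1 update per row; in Julia 0.3, x[k, :] is a 1×d matrix,
    # so x[k, :]' * x[k, :] is d×d. Every iteration allocates two new d×d
    # matrices (the product and the sum), about 16 GB in total for N = 1e5.
    @simd for k = 1:N
        @inbounds H2 += f[k] * x[k, :]' * x[k, :];
    end
    return H2
end

function hess3(x, f)
    N, d = size(x);
    H3 = zeros(d,d);
    # explicit loops, no temporaries: H3[k1, k2] = sum_k f[k] * x[k, k1] * x[k, k2]
    for k = 1:N
        for k1 = 1:d
            @simd for k2 = 1:d
                @inbounds H3[k1, k2] += x[k, k1] * x[k, k2] * f[k];
            end
        end
    end
    return H3
end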

The results are:
@time H1 = hess1(x, f);
@time H2 = hess2(x, f);
@time H3 = hess3(x, f);
elapsed time: 0.776116469 seconds (262480224 bytes allocated, 26.49% gc time)
elapsed time: 30.496472345 seconds (16385442496 bytes allocated, 56.07% gc time)
elapsed time: 2.769934563 seconds (80128 bytes allocated)
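
hess1 mirrors MATLAB's bsxfun approach but is slower, and hess3 uses no temporary memory but is considerably slower still. My best Julia code (hess1, 0.78 s) is about 3 times slower than MATLAB (0.26 s).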
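One plausible reason hess3 is slow despite allocating nothing: Julia stores arrays in column-major order, and hess3's innermost loop runs over k2, so it reads x[k, k2] with a stride of N elements and updates H3[k1, k2] with a stride of d. A minimal sketch of a loop version that puts the sum over rows innermost (so every read of x is contiguous) and fills only one triangle of the symmetric H is shown here. The name hess_colmajor is hypothetical, it assumes Julia 1.x, and it has not been benchmarked here; a tuned BLAS product such as x' * (f .* x) may well remain faster.

```julia
using Random

# Hypothetical loop-based variant (not from the question), Julia 1.x syntax.
# Julia arrays are column-major, so column x[:, j] is contiguous in memory.
# Making the sum over k (rows) the innermost loop turns every read of x
# into a unit-stride access, unlike hess3, whose inner loop walks along a row.
function hess_colmajor(x::AbstractMatrix{T}, f::AbstractVector{T}) where {T<:AbstractFloat}
    N, d = size(x)
    length(f) == N || throw(DimensionMismatch("length(f) must equal size(x, 1)"))
    H = zeros(T, d, d)
    @inbounds for j in 1:d
        for i in 1:j                            # H is symmetric: compute i <= j only
            s = zero(T)
            @simd for k in 1:N
                s += f[k] * x[k, i] * x[k, j]   # H[i, j] = sum_k f[k] x[k, i] x[k, j]
            end
            H[i, j] = s
            H[j, i] = s                         # mirror into the other triangle
        end
    end
    return H
end

Random.seed!(1)
x = randn(1_000, 50)
f = randn(1_000)
println(hess_colmajor(x, f) ≈ x' * (f .* x))   # prints true
```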

How can I make this Julia code faster?

IJulia gist: http://nbviewer.ipython.org/gist/memming/669fb8e78af3338ebf6f

Julia version: 0.3.0-rc1

Edit:
I tested on a more powerful computer (3.5 GHz Intel i7, 4 cores, 256 kB L2, 8 MB L3):
  • MATLAB R2014a without -singleCompThread: 0.053 s
  • MATLAB R2014a with -singleCompThread: 0.080 s (suggested by @tholy)
  • Julia 0.3.0-rc1
      • hess1: elapsed time 0.215406904 seconds (262498648 bytes allocated, 32.74% gc time)
      • hess2: elapsed time 10.722578699 seconds (16384080176 bytes allocated, 62.20% gc time)
      • hess3: elapsed time 1.065504355 seconds (80176 bytes allocated)
      • bsxfunstyle: elapsed time 0.063540168 seconds (80081072 bytes allocated, 25.04% gc time) (@IainDunning's solution)

Indeed, using broadcast is faster and on par with MATLAB's bsxfun.
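Since the MATLAB numbers differ between single- and multi-threaded runs, a like-for-like Julia comparison could limit the BLAS thread count. A minimal sketch, assuming Julia 1.6 or later (where BLAS.get_num_threads exists); only the matrix product here is multithreaded by BLAS, while the broadcast runs on one thread, so this is only loosely analogous to -singleCompThread. The helper name bsx is hypothetical, and no timings are claimed.

```julia
using LinearAlgebra, Random   # BLAS.get_num_threads requires Julia >= 1.6

Random.seed!(0)
N, d = 100_000, 100
f = randn(N)
x = randn(N, d)

# Same computation as the answer's bsxfunstyle, in Julia 1.x syntax.
bsx(x, f) = x' * (f .* x)

bsx(x, f)                                 # warm-up run (triggers compilation)
nthreads_default = BLAS.get_num_threads()

BLAS.set_num_threads(1)                   # single-threaded BLAS
t1 = @elapsed bsx(x, f)

BLAS.set_num_threads(nthreads_default)    # restore the previous BLAS thread count
tn = @elapsed bsx(x, f)

println("1 BLAS thread: ", t1, " s; ", nthreads_default, " BLAS threads: ", tn, " s")
```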

    Best answer

    You are looking for the broadcast function. Here is the relevant issue discussing the functionality and naming.
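To make the semantics concrete, a small sketch in Julia 1.x syntax: broadcast(*, f, x) expands the column f across the columns of x, just as bsxfun(@times, f, x) does in MATLAB. The toy values are our own.

```julia
f = [1.0, 2.0, 3.0]          # length-3 vector, broadcast as a 3×1 column
x = [1.0 10.0;
     1.0 10.0;
     1.0 10.0]               # 3×2 matrix

# broadcast(*, f, x) expands f along the columns of x, like bsxfun(@times, f, x)
B = broadcast(*, f, x)
println(B)                                        # prints [1.0 10.0; 2.0 20.0; 3.0 30.0]

# `f .* x` is the dot-syntax spelling of the same broadcast in Julia 1.x
println(B == f .* x)                              # prints true

# An N×1 matrix, as in the answer's f = randn(N, 1), broadcasts the same way
println(B == broadcast(*, reshape(f, 3, 1), x))   # prints true
```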

    I implemented your version as well as a broadcast version, and here is what I found:

    srand(1988)          # Julia 0.3 syntax; Random.seed!(1988) in Julia >= 1.0
    N = 100_000
    d = 100
    f = randn(N, 1)
    x = randn(N, d)

    function hess1(x, f)
        N, d = size(x);
        temp = zeros(N, d);
        @simd for kk = 1:N
            @inbounds temp[kk, :] = f[kk] * x[kk, :];
        end
        H = x' * temp;
    end

    function bsxfunstyle(x, f)
        # broadcast expands the N×1 f across the d columns of x, like bsxfun
        x' * broadcast(*,f,x)
    end

    # Warmup (the first call compiles each function)
    hess1(x,f)
    bsxfunstyle(x, f)

    # For real
    println("Hess1")
    @time H1 = hess1(x, f)
    println("Broadcast")
    @time H2 = bsxfunstyle(x, f)

    # Check solutions are identical
    println(sum(abs(H1-H2)))   # Julia 0.3; sum(abs.(H1 - H2)) in Julia >= 1.0

    with output:
    Hess1
    elapsed time: 0.324256216 seconds (262498648 bytes allocated, 33.95% gc time)
    Broadcast
    elapsed time: 0.126647594 seconds (80080696 bytes allocated, 20.22% gc time)
    0.0
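The broadcast version still allocates an N×d temporary on every call (80080696 bytes in the run above, essentially N·d·8 bytes). If the product must be computed repeatedly, the buffers can be reused. A minimal sketch using in-place broadcasting and LinearAlgebra.mul!, assuming Julia 1.x; the name bsxfunstyle! and its argument order are hypothetical, and it has not been timed against the numbers above.

```julia
using LinearAlgebra, Random

# Hypothetical helper (not from the answer): reuses a caller-provided
# buffer `tmp` (N×d) and output `H` (d×d), so repeated calls avoid the
# N×d temporary that `x' * broadcast(*, f, x)` allocates.
function bsxfunstyle!(H, tmp, x, f)
    tmp .= f .* x          # fused in-place broadcast: tmp[k, j] = f[k] * x[k, j]
    mul!(H, x', tmp)       # H = x' * tmp, written into H
    return H
end

Random.seed!(1988)
N, d = 100_000, 100
f = randn(N)
x = randn(N, d)
tmp = similar(x)
H = zeros(d, d)

bsxfunstyle!(H, tmp, x, f)          # warm-up (compilation)
@time bsxfunstyle!(H, tmp, x, f)
println(H ≈ x' * (f .* x))          # prints true
```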

    Related Stack Overflow question: https://stackoverflow.com/questions/25223073/
