gpt4 book ai didi

matrix - (Julia) .+ 运算符似乎没有使用我的自定义广播函数,为什么?

转载 作者:行者123 更新时间:2023-12-02 07:14:11 24 4
gpt4 key购买 nike

我正在实现一个自定义矩阵,它只有一个非零值,无论您执行什么操作,这都是矩阵中唯一可以非零的单元格。我将其称为 SVMatrix(单值矩阵)。到目前为止我的代码是

struct SVMatrix{T} <: Base.AbstractMatrix{T}
value::T
index::Tuple{Int,Int}
size::Tuple{Int,Int}
end

function Base.broadcast(+, A::SVMatrix, B::AbstractArray)
SVMatrix(A.value+B[A.index...], A.index, A.size)
end

function Base.getindex(A::SVMatrix{T}, i::Int) where {T}
if i == A.index[1] + A.index[2]*A.size[1]
A.value
else
0
end
end


function Base.getindex(A::SVMatrix{T}, i::Vararg{Int,2}) where {T}
if i == A.index
return A.value
else
0
end
end

function Base.size(A::SVMatrix)
A.size
end

然后我通过以下方式将广播函数与 .+ 运算符一起计时

function time(n::Int)
A = SVMatrix(1.0, (3,4), (n, n))
B = rand(n,n)
@time broadcast(+, A, B)
@time A .+ B
end

time(1000)
println()
time(1000)

并得到结果

 0.000000 seconds
0.008207 seconds (2 allocations: 7.629 MiB, 47.51% gc time)

0.000000 seconds
0.008258 seconds (2 allocations: 7.629 MiB)

所以看起来 .+ 没有使用我的自定义广播函数,尽管它说 in the documentation那个

In fact, f.(args...) is equivalent to broadcast(f, args...), providing a convenient syntax to broadcast any function (dot syntax).

为什么我会得到这些结果?

最佳答案

这实际上是一个很好的例子,说明您不应该扩展广播。

julia> struct SVMatrix{T} <: Base.AbstractMatrix{T}
value::T
index::Tuple{Int,Int}
size::Tuple{Int,Int}
end

julia> @inline function Base.getindex(A::SVMatrix{T}, i::Vararg{Int,2}) where {T}
@boundscheck checkbounds(A, i...)
if i == A.index
return A.value
else
return zero(T)
end
end

julia> Base.size(A::SVMatrix) = A.size

julia> SVMatrix(1.0, (1,1), (2, 2)) .+ ones(2, 2)
2×2 Array{Float64,2}:
2.0 1.0
1.0 1.0

.+的结果不应该是[2 0; 0 0] !如果我们使用您的广播实现(更正为在 ::typeof(+) 上调度为 DNF noted ),当其他人使用它并期望它像所有其他 AbstractArray 一样运行时,您的数组会令人惊讶地崩溃。 s。

现在,您可以返回一个智能重新计算的操作 SVMatrix.* :

julia> SVMatrix(2.5, (1,1), (2, 2)) .* ones(2, 2)
2×2 Array{Float64,2}:
2.5 0.0
0.0 0.0

我们可以在 O(1) 空间和时间中完成此操作,但默认实现是循环所有值并返回密集的 Array 。这就是 Julia 的多重调度的闪光点:

julia> Base.broadcasted(::typeof(*), A::SVMatrix, B::AbstractArray) = SVMatrix(A.value*B[A.index...], A.index, A.size)

julia> SVMatrix(2.5, (1,1), (2, 2)) .* ones(2, 2)
2×2 SVMatrix{Float64}:
2.5 0.0
0.0 0.0

由于这是一个 O(1) 操作,并且是一个巨大的胜利,我们可以选择退出广播融合并立即重新计算一个新的 SVMatrix - 即使在“融合”表达式中。不过,您还没有完成!

  • 需要对兼容形状进行错误检查。
  • 需要允许广播类似 SVMatrix(2.5, (1,1), (2, 2)) .* rand(2) 的内容.
  • 理想情况下,您应该实现 BroadcastStyle 允许对“参数列表中至少有一个 SVMatrix”进行分派(dispatch)。然后你可以实现Base.broadcasted(::ArrayStyle{SVMatrix}, ::typeof(*), args...)这将允许 SVMatrix出现在 .* 的两侧,但这是一个高级主题。

关于matrix - (Julia) .+ 运算符似乎没有使用我的自定义广播函数,为什么?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58698562/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com