gpt4 book ai didi

type-conversion - 在 Julia 中绘制 ForwardDiff 的输出

转载 作者:行者123 更新时间:2023-12-04 02:30:03 24 4
gpt4 key购买 nike

我只想使用 ForwardDiff.jl定义函数并绘制其梯度的功能(使用 ForwardDiff.gradient 评估)。它似乎不起作用,因为 ForwardDiff.gradient 的输出这很奇怪吗Dual类型的东西,它不容易转换为所需的类型(在我的例子中,是 Float32 的一维数组)。

using Plots
using ForwardDiff

my_func(x::Array{Float32,1}) = 1f0. / (1f0 .+ exp(3f0 .* x)) # doesn't matter what this is, just a sigmoid function here

grad_f(x::Array{Float32,1}) = ForwardDiff.gradient(my_func, x)

x_values = collect(Float32,0:0.01:10)

plot(x_values,my_func(x_values)); # this works fine

plot!(x_values,grad_f(x_values)); # this throws an error

这是我得到的错误:

ERROR: MethodError: no method matching Float64(::ForwardDiff.Dual{ForwardDiff.Tag{typeof(g),Float32},Float64,12})

当我检查 grad_f(x_values) 的类型时,我明白了:

Array{Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(g),Float32},Float32,12},1},1}

例如,为什么在 ForwardDiff 文档的示例中没有发生这种情况?看这里:https://github.com/JuliaDiff/ForwardDiff.jl

提前致谢。

编辑:在 Kristoffer Carlsson 的评论之后:我试过了,但它仍然不起作用。我不明白我在这里尝试的方法与他的建议有何不同:

function g(x::Float32)
return x / (1f0 + exp(10f0 * (x - 5f0)))
end

function ∂g∂x(x::Float32)
return ForwardDiff.derivative(g, x)
end

x_vals = collect(Float32,0:0.01:10)
plot(x_vals,g.(x_vals))
plot!(x_vals,∂g∂x.(x_vals))

现在的错误是:

no method matching g(::ForwardDiff.Dual{ForwardDiff.Tag{typeof(g),Float32},Float32,1})

这个错误发生在我调用 ∂g∂x(x) 时, 无论我是否使用广播版本 ∂g∂x.(x) .我想这与函数定义有关,但我看不出我定义它的方式与 Kristoffer 的版本有何不同,只是它不是在一行中定义的......这太令人困惑了。

这应该有效,因为根据 ForwardDiff的文档,您只需要输入的类型是 Real 的子类型- 和 Float32是 Real 的子类型。

编辑:我现在意识到,在阅读了其他人的评论后,您需要将您的函数限制为足够通用以接受任何抽象类型的输入 Real ,我并没有完全从文档中收集到。对造成的困惑表示歉意。

最佳答案

您在数组而不是标量上定义函数,并且还过多地限制了输入类型。此外,对于标量函数,您应该使用 ForwardDiff.derivative。尝试类似的东西:

using Plots
using ForwardDiff

my_func(x::Real) = 1f0 / (1f0 + exp(3f0 * x))
my_func_derivative(x::Real) = ForwardDiff.derivative(my_func, x)

plot(my_func, xlimits = (0, 10))
plot!(my_func_derivative)

给予:

enter image description here

关于type-conversion - 在 Julia 中绘制 ForwardDiff 的输出,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64820651/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com