gpt4 book ai didi

julia - Flux.jl 中的交叉熵损失 - Julia

转载 作者:行者123 更新时间:2023-12-05 02:41:08 25 4
gpt4 key购买 nike

我想设置我的模型以在 Flux.jl 中使用交叉熵损失。我该怎么做,我应该在哪里传递损失函数本身?

最佳答案

Flux.jl 通过 Flux.Losses 模块提供了一个内置模块,其中包含许多常见的损失函数,您可以通过 using Flux.Losses 访问该模块。模块中内置了交叉熵损失函数,使用方法如下:

julia> y_label = Flux.onehotbatch([0, 1, 2, 1, 0], 0:2)
3×5 Flux.OneHotArray{3,2,Vector{UInt32}}:
1 0 0 0 1
0 1 0 1 0
0 0 1 0 0

julia> y_model = softmax(reshape(-7:7, 3, 5) .* 1f0)
3×5 Matrix{Float32}:
0.0900306 0.0900306 0.0900306 0.0900306 0.0900306
0.244728 0.244728 0.244728 0.244728 0.244728
0.665241 0.665241 0.665241 0.665241 0.665241

julia> sum(y_model; dims=1)
1×5 Matrix{Float32}:
1.0 1.0 1.0 1.0 1.0

julia> Flux.crossentropy(y_model, y_label)
1.6076053f0

您可以在此处找到完整的 Flux.crossentropy 函数定义:https://fluxml.ai/Flux.jl/stable/models/losses/#Flux.Losses.crossentropy

定义损失函数后,您可以将其传递给内置的训练函数:Flux.train!(loss, params(model), data, opt) 或在您的自定义中使用它训练循环。

关于julia - Flux.jl 中的交叉熵损失 - Julia,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/68239673/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com