gpt4 book ai didi

python - 将代码从 Python 转换为 Julia 后代码无法正常工作

转载 作者:太空宇宙 更新时间:2023-11-03 12:43:30 24 4
gpt4 key购买 nike

我是 Julia 编程语言的新手,并且仍在通过编写我已经用 Python 编写的代码(或者至少尝试过使用 Python)来学习它。

有一篇文章解释了如何制作一个非常简单的神经网络:https://medium.com/technology-invention-and-more/how-to-build-a-simple-neural-network-in-9-lines-of-python-code-cc8f23647ca1 .

我在 Python 中尝试了本文中的代码,它工作正常。但是,我以前没有在 Python 中使用过线性代数的东西(比如点)。现在我正在尝试将这段代码翻译成 Julia,但有些事情我无法理解。这是我的 Julia 代码:

using LinearAlgebra

synaptic_weights = [-0.16595599, 0.44064899, -0.99977125]::Vector{Float64}

sigmoid(x) = 1 / (1 + exp(-x))
sigmoid_derivative(x) = x * (1 -x)

function train(training_set_inputs, training_set_outputs, number_of_training_iterations)
global synaptic_weights
for (iteration) in 1:number_of_training_iterations
output = think(training_set_inputs)

error = training_set_outputs .- output

adjustment = dot(transpose(training_set_inputs), error * sigmoid_derivative(output))

synaptic_weights = synaptic_weights .+ adjustment
end
end

think(inputs) = sigmoid(dot(inputs, synaptic_weights))

println("Random starting synaptic weights:")
println(synaptic_weights)

training_set_inputs = [0 0 1 ; 1 1 1 ; 1 0 1 ; 0 1 1]::Matrix{Int64}
training_set_outputs = [0, 1, 1, 0]::Vector{Int64}
train(training_set_inputs, training_set_outputs, 10000)

println("New synaptic weights after training:")
println(synaptic_weights)

println("Considering new situation [1, 0, 0] -> ?:")
println(think([1 0 0]))

我已经尝试将向量(如 synaptic_weights)初始化为:

synaptic_weights = [-0.16595599 ; 0.44064899 ; -0.99977125]

但是,代码不起作用。更确切地说,有 3 件事我不清楚:

  1. 我是否以正确的方式初始化向量和矩阵(是否与原作者在 Python 中所做的相同)?
  2. 在 Python 中,原作者使用 + 和 - 运算符,其中一个操作数是向量,另一个是标量。我不确定这是否意味着 Python 中的逐元素加法或减法。例如,Python 中的 (vector+scalar) 是否等于 Julia 中的 (vector.+scalar)?
  3. 当我尝试运行上面的 Julia 代码时,出现以下错误:

    ERROR: LoadError: DimensionMismatch("first array has length 12 which does not match the length of the second, 3.")
    Stacktrace:
    [1] dot(::Array{Int64,2}, ::Array{Float64,1}) at C:\Users\julia\AppData\Local\Julia-1.0.3\share\julia\stdlib\v1.0\LinearAlgebra\src\generic.jl:702
    [2] think(::Array{Int64,2}) at C:\Users\Viktória\Documents\julia.jl:21
    [3] train(::Array{Int64,2}, ::Array{Int64,1}, ::Int64) at C:\Users\Viktória\Documents\julia.jl:11
    [4] top-level scope at none:0
    in expression starting at C:\Users\Viktória\Documents\julia.jl:28

    当函数 think(inputs) 尝试计算输入和 synaptic_weights 的点积时,会出现此错误。在这种情况下,输入是一个 4x3 矩阵,突触权重是一个 3x1 矩阵(向量)。我知道它们可以相乘,结果会变成一个 4x1 的矩阵(向量)。这是否意味着可以计算它们的点积?

    无论如何,可以使用 numpy 包在 Python 中计算该点积,所以我想它也可以通过某种特定方式在 Julia 中计算。

对于点积,我还尝试创建一个函数,以 a 和 b 为参数,并尝试计算它们的点积:首先计算 a 和 b 的乘积,然后返回结果之和。我不确定这是否是一个好的解决方案,但是当我使用该函数时 Julia 代码没有产生预期的结果,所以我删除了它。

你能帮我看看这段代码吗?

最佳答案

这是针对 Julia 调整的代码:

sigmoid(x) = 1 / (1 + exp(-x))
sigmoid_derivative(x) = x * (1 -x)
think(synaptic_weights, inputs) = sigmoid.(inputs * synaptic_weights)

function train!(synaptic_weights, training_set_inputs, training_set_outputs,
number_of_training_iterations)
for iteration in 1:number_of_training_iterations
output = think(synaptic_weights, training_set_inputs)
error = training_set_outputs .- output
adjustment = transpose(training_set_inputs) * (error .* sigmoid_derivative.(output))
synaptic_weights .+= adjustment
end
end

synaptic_weights = [-0.16595599, 0.44064899, -0.99977125]
println("Random starting synaptic weights:")
println(synaptic_weights)

training_set_inputs = Float64[0 0 1 ; 1 1 1 ; 1 0 1 ; 0 1 1]
training_set_outputs = Float64[0, 1, 1, 0]
train!(synaptic_weights, training_set_inputs, training_set_outputs, 10000)

println("New synaptic weights after training:")
println(synaptic_weights)

println("Considering new situation [1, 0, 0] -> ?:")
println(think(synaptic_weights, Float64[1 0 0]))

有多项更改,所以如果您不清楚其中的一些更改,请询问,我会详细说明。

我改变的最重要的事情:

  • 不要使用全局变量,因为它们会显着降低性能
  • 使所有数组都具有Float64元素类型
  • 在几个地方你需要用 进行广播。(例如 sigmoidsigmoid_derivative 函数以他们期望的方式定义得到一个数字作为参数,因此当我们调用它们时 在它们的名字后面添加以触发广播)
  • 使用标准矩阵乘法 * 而不是 dot

代码的运行速度比 Python 中的原始实现快大约 30 倍。我没有挤出这段代码的最大性能(现在它做了很多可以避免的分配),因为它需要稍微重写它的逻辑,我猜你想要直接重新实现。

关于python - 将代码从 Python 转换为 Julia 后代码无法正常工作,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/55664062/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com