gpt4 book ai didi

python - 我的 learning_rate 真的在 theano 中改变了吗?

转载 作者:太空宇宙 更新时间:2023-11-04 01:01:21 26 4
gpt4 key购买 nike

我正在尝试调整梯度下降算法的学习率。我希望能够确认我对 learning_rate 的更改是否真的对我的 theano 训练功能产生影响。

示例代码:

#set up the updates
for param in params:
updates.append((param, param-learning_rate*T.grad(cost, param)))
#set up the training function
train = theano.function(inputs=[index], outputs=[cost], updates=updates, givens={x:self.X[index:index+mini_batch_size,:]})

#run through the minibatches
for epoch in range(n_epochs):
for row in range(0,self.m, mini_batch_size):
cost = train(row)
#occasionally adjust the learning rate
learning_rate = learning_rate/2.0

这会如我所愿地工作吗?如何确认?

编辑:

根据这个小测试,这似乎行不通:

x = th.tensor.dscalar()
rate=5.0
f = th.function(inputs=[x], outputs=2*x*rate)
print(f(10))
>> 100.0
rate=0.0
print(f(10))
>> 100.0

解决这个问题的正确方法是什么?

最佳答案

问题是您的代码将学习率作为常量编译到计算图中。如果你想改变速率,你需要使用 Theano 变量在计算图中表示它,然后在函数执行时提供一个值。这可以通过两种方式完成:

  1. 每次执行函数时传递速率,方法是将其视为输入值并在计算图中将其表示为标量张量。

  2. 将比率存储在 Theano 共享变量中。在执行函数之前手动更改变量。

第二种方法有两种变体。首先,您在执行前手动调整速率值。在第二个中,您指定一个符号表达式来解释每次执行时应如何更新速率。

此示例代码基于问题的编辑部分演示了这三种方法。

import theano as th
import theano.tensor

# Original version (changing rate doesn't affect theano function output)
x = th.tensor.dscalar()
rate=5.0
f = th.function(inputs=[x], outputs=2*x*rate)
print(f(10))
rate=0.0
print(f(10))


# New version using an input value
x = th.tensor.dscalar()
rate=th.tensor.scalar()
f = th.function(inputs=[x, rate], outputs=2*x*rate)
print(f(10, 5.0))
print(f(10, 0.0))


# New version using a shared variable with manual update
x = th.tensor.dscalar()
rate=th.shared(5.0)
f = th.function(inputs=[x], outputs=2*x*rate)
print(f(10))
rate.set_value(0.0)
print(f(10))


# New version using a shared variable with automatic update
x = th.tensor.dscalar()
rate=th.shared(5.0)
updates=[(rate, rate / 2.0)]
f = th.function(inputs=[x], outputs=2*x*rate, updates=updates)
print(f(10))
print(f(10))
print(f(10))
print(f(10))

关于python - 我的 learning_rate 真的在 theano 中改变了吗?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/32717528/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com