gpt4 book ai didi

python - 即使我将学习率设置得尽可能小, tensorflow 优化器也会输出 nan

转载 作者:行者123 更新时间:2023-12-01 09:17:52 24 4
gpt4 key购买 nike

下面是代码。

def create_train_model(hidden_nodes,num_iters):
tf.reset_default_graph()
X=tf.placeholder(shape=(120,4),dtype=tf.float64,name='X')
y=tf.placeholder(shape=(120,1),dtype=tf.float64,name='y')
W1=tf.Variable(np.random.rand(4,hidden_nodes),dtype=tf.float64)
W2=tf.Variable(np.random.rand(hidden_nodes,2),dtype=tf.float64)
A1=tf.sigmoid(tf.matmul(X,W1))
U_est=tf.sigmoid(tf.matmul(A1,W2))
loss=fuloss3(U_est,y)
optimizer=tf.train.AdagradOptimizer(4.9406564584124654e-324)
TRAIN=optimizer.minimize(loss)
init=tf.initialize_all_variables()
sess=tf.Session()
sess.run(init)
for i in range(num_iters):
pout=sess.run(loss,feed_dict={X: Xtrain,
y: ytrain})
sess.run(TRAIN,feed_dict={X: Xtrain,
y: ytrain})
loss_plot[hidden_nodes][i]=sess.run(loss,feed_dict={X: Xtrain,y:
ytrain})
print(pout)
weights1=sess.run(W1)
weights2=sess.run(W2)
print(weights1)
print(weights2)
print('loss (hidden nodes: %d, iterations: %d): %.2f'%(hidden_nodes,
num_iters,loss_plot[hidden_nodes][num_iters-1]))
sess.close()
return weights1, weights2

print(pout) 返回一个非 nan 数字。训练结束后,权重全部为 nan。即使我将学习率设置为尽可能最小。为什么会出现这种情况呢?由于学习率如此之小,您基本上不会移动变量。从 pout 可以明显看出,最初的损失运行给出了有效的结果,这意味着这不是我设置损失的问题。提前致谢。

最佳答案

我怀疑您的问题出在这里:

W1=tf.Variable(np.random.rand(4,hidden_nodes),dtype=tf.float64)
W2=tf.Variable(np.random.rand(hidden_nodes,2),dtype=tf.float64)

试试这个:

W1 = tf.get_variable("W1", shape=..., dtype=...,
initializer=tf.contrib.layers.xavier_initializer())
W2 = tf.get_variable("W2", shape=..., dtype=...,
initializer=tf.contrib.layers.xavier_initializer())

你的权重初始化在[0,1]范围内,这是相当大的权重。这将使网络开始出现剧烈的梯度波动,这可能会让您陷入 NaN 情况。

xavier 初始化程序将考虑节点的输入数量并初始化该值,以免节点饱和。通俗地说,它会根据您的架构智能地初始化权重。

请注意,此初始化程序也有一个约定版本。

或者,作为快速测试,您可以通过简单地将随机权重乘以一个小值(例如 1e-4)来减小权重初始化的大小。

如果这不能解决问题,请在此处发表评论。

关于python - 即使我将学习率设置得尽可能小, tensorflow 优化器也会输出 nan,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51071994/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com