gpt4 book ai didi

tensorflow - 在 tensorflow 中使用 dropout 时出错

转载 作者:行者123 更新时间:2023-12-01 19:55:07 33 4
gpt4 key购买 nike

我正在尝试在 tensorflow 中使用 dropout 功能:

sess=tf.InteractiveSession()
initial = tf.truncated_normal([1,4], stddev=0.1)
x = tf.Variable(initial)
keep_prob = tf.placeholder("float")
dx = tf.nn.dropout(x, keep_prob)
sess.run(tf.initialize_all_variables())
sess.run(dx, feed_dict={keep_prob: 0.5})
sess.close()

此示例与 the tutorial 中的完成方式非常相似;但是,我最终遇到以下错误:

RuntimeError: min: Conversion function <function constant at 0x7efcc6e1ec80> for type <type 'object'> returned incompatible dtype: requested = float32_ref, actual = float32

我在理解数据类型float32_ref时遇到了一些困难,这似乎是问题的背景。我还尝试指定 dtype=tf.float32,但这并不能解决任何问题。

我也尝试了这个示例,它与 float32 配合得很好:

sess=tf.Session()
x=tf.Variable(np.array([1.0,2.0,3.0,4.0]))
sess.run(x.initializer)
x=tf.cast(x,tf.float32)
prob=tf.Variable(np.array([0.5]))
sess.run(prob.initializer)
prob=tf.cast(prob,tf.float32)
dx=tf.nn.dropout(x,prob)
sess.run(dx)
sess.close()

但是,如果我转换 float64 而不是 float32,我会得到相同的错误:

RuntimeError: min: Conversion function <function constant at 0x7efcc6e1ec80> for type <type 'object'> returned incompatible dtype: requested = float64_ref, actual = float64

编辑:

似乎只有在变量上直接使用 dropout 时才会出现此问题,适用于占位符以及变量和占位符的乘积,示例:

sess=tf.InteractiveSession()
x = tf.placeholder(tf.float64)

sess=tf.InteractiveSession()
initial = tf.truncated_normal([1,5], stddev=0.1,dtype=tf.float64)
y = tf.Variable(initial)

keep_prob = tf.placeholder(tf.float64)
dx = tf.nn.dropout(x*y, keep_prob)
sess.run(tf.initialize_all_variables())
sess.run(dx, feed_dict={x : np.array([1.0, 2.0, 3.0, 4.0, 5.0]),keep_prob: 0.5})
sess.close()

最佳答案

这是 tf.nn.dropout 实现中的一个错误,已在最近的提交中修复,并将包含在 TensorFlow 的下一版本中。目前,为了避免此问题,可以 build TensorFlow from source ,或按如下方式修改您的程序:

#dx = tf.nn.dropout(x, keep_prob)
dx = tf.nn.dropout(tf.identity(x), keep_prob)

关于tensorflow - 在 tensorflow 中使用 dropout 时出错,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/34160692/

33 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com