gpt4 book ai didi

tensorflow - 为什么当我在 tensorflow 中更改测试批量大小时,结果不同

转载 作者:行者123 更新时间:2023-12-05 07:42:44 25 4
gpt4 key购买 nike

这是我的火车代码:

x = tf.placeholder(tf.float32, [None, 2, 3])
cell = tf.nn.rnn_cell.GRUCell(10)

_, state = tf.nn.dynamic_rnn(
cell = cell,
inputs = x,
dtype = tf.float32)
# train
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
x_ = np.ones([2,2,3],np.float32)
output = sess.run(state, feed_dict= {x: x_})
print output
saver = tf.train.Saver()
saver.save(sess,'./model')

结果是:

[[ 0.12851571 -0.23994535  0.23123585 -0.00047993 -0.02450397
-0.21048039 -0.18786618 0.04458345 -0.08603278 -0.08259721]
[ 0.12851571 -0.23994535 0.23123585 -0.00047993 -0.02450397
-0.21048039 -0.18786618 0.04458345 -0.08603278 -0.08259721]]

这是我的测试代码:

x = tf.placeholder(tf.float32, [None, 2, 3])
cell = tf.nn.rnn_cell.GRUCell(10)

_, state = tf.nn.dynamic_rnn(
cell = cell,
inputs = x,
dtype = tf.float32)
with tf.Session() as sess:
x_ = np.ones([1,2,3],np.float32)
saver = tf.train.Saver()
saver.restore(sess,'./model')
output = sess.run(state, feed_dict= {x: x_})
print output

然后我得到:

[[ 0.12851571 -0.23994535  0.2312358  -0.00047993 -0.02450397 
-0.21048039 -0.18786621 0.04458345 -0.08603278 -0.08259721]]

你看,结果略有变化。当我将测试批处理设置为 2 时,结果与训练结果相同。那怎么了?我的tf版本是0.12

最佳答案

更新(不是答案)

tf.nn.rnn_cell.GRUCelltf.nn.dynamic_rnn 都已弃用并替换为 tf.keras.layers.GRU.

使用已弃用的函数,您似乎甚至不需要保存和恢复模型,甚至不需要多次运行它。您所需要做的就是以奇数批处理大小运行它并使用 tf.float32 作为数据类型,最后的结果将略有偏差。

import tensorflow as tf
import numpy as np

x = tf.placeholder(tf.float32, [None, 2, 3])
cell = tf.nn.rnn_cell.GRUCell(10)

_, state = tf.nn.dynamic_rnn(
cell = cell,
inputs = x,
dtype = tf.float32)

sess = tf.Session()
sess.run(tf.global_variables_initializer())

x_ = np.ones([3,2,3],np.float32)
output = sess.run(state, feed_dict= {x: x_})
print(output)

返回这样的结果

[[ 0.03649516 -0.08052824 -0.0539998   0.2995336  -0.12542574 -0.04339318
0.3872745 0.08844283 -0.14555818 -0.4216033 ]
[ 0.03649516 -0.08052824 -0.0539998 0.2995336 -0.12542574 -0.04339318
0.3872745 0.08844283 -0.14555818 -0.4216033 ]
[ 0.03649516 -0.08052824 -0.05399981 0.2995336 -0.12542574 -0.04339318
0.38727456 0.08844285 -0.14555818 -0.4216033 ]]

异常似乎只出现在奇数批处理的最后一行。

另一种观点是,单个批处理是正确的,所有偶数大小的批处理都关闭,除了奇数大小的批处理的最后一行以外的所有内容都关闭。

对于 dtype=float64 或 dtype=float16 似乎没有发生,这两者看起来都很稳定。

此外,这个问题只是在隐藏状态下,似乎并没有出现在常规输出中。

关于tensorflow - 为什么当我在 tensorflow 中更改测试批量大小时,结果不同,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/44288469/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com