gpt4 book ai didi

python - tf.data.Iterator.get_next() : How to advance in tf. while_loop?

转载 作者:行者123 更新时间:2023-11-28 17:07:17 25 4
gpt4 key购买 nike

目前,我尝试在 Tensorflow while 循环中实现所有训练,但我在使用 Tensorflow 数据集 API 的迭代器时遇到了问题。

通常,当调用 sess.run() 时,Iterator.get_next() 会前进到下一个元素。但是,我需要前进到一次运行中的下一个元素。我该怎么做?

下面的小例子说明了我的问题:

import tensorflow as tf
import numpy as np


def for_loop(condition, modifier, body_op, idx=0):
idx = tf.convert_to_tensor(idx)

def body(i):
with tf.control_dependencies([body_op(i)]):
return [modifier(i)]

# do the loop:
loop = tf.while_loop(condition, body, [idx])
return loop


x = np.arange(10)

data = tf.data.Dataset.from_tensor_slices(x)
data = data.repeat()

iterator = data.make_initializable_iterator()
smpl = iterator.get_next()

loop = for_loop(
condition=lambda i: tf.less(i, 5),
modifier=lambda i: tf.add(i, 1),
body_op=lambda i: tf.Print(smpl, [smpl], message="This is sample: ")
)

sess = tf.InteractiveSession()
sess.run(iterator.initializer)
sess.run(loop)

输出:

This is sample: [0]
This is sample: [0]
This is sample: [0]
This is sample: [0]
This is sample: [0]

我总是得到完全相同的元素。

最佳答案

每次要“在一次运行中迭代”时,都需要调用 iterator.get_next()

例如,在您的玩具示例中,只需将 body_op 替换为:

 body_op=lambda i: tf.Print(i, [iterator.get_next()], message="This is sample: ")
# This is sample: [0]
# This is sample: [1]
# This is sample: [2]
# This is sample: [3]
# This is sample: [4]

关于python - tf.data.Iterator.get_next() : How to advance in tf. while_loop?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50237486/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com