gpt4 book ai didi

python - 如何重用 iterator.get_next() 中的数据批处理

转载 作者:行者123 更新时间:2023-11-28 18:14:05 27 4
gpt4 key购买 nike

我正在实现一种涉及交替优化的算法。也就是说,在每次迭代中,算法都会获取一个数据批处理,并使用该数据批处理顺序优化两个损失。我当前使用 tf.data.Dataasettf.data.Iterator 的实现是这样的(这确实不正确,详见下文):

data_batch = iterator.get_next()
train_op_1 = get_train_op(data_batch)
train_op_2 = get_train_op(data_batch)

for _ in range(num_steps):
sess.run(train_op_1)
sess.run(train_op_2)

请注意,以上是不正确的,因为每次调用 sess.run 都会推进迭代器以获取下一个数据批处理。所以 train_op_1train_op_2 确实使用了不同的数据批处理。

我也不能做 sess.run([train_op_1, train_op_2]) 之类的事情,因为两个优化步骤需要是连续的(即,第二个优化步骤取决于最新的变量值第一个优化步骤。)

我想知道是否有任何方法可以以某种方式“卡住”迭代器,使其不会在 sess.run 调用中前进?

最佳答案

我正在做类似的事情,所以这是我的代码的一部分,从一些不必要的东西中剥离出来。它有更多的功能,因为它有训练和验证迭代器,但您应该了解使用 is_keep_previous 标志的想法。基本上作为 True 传递,它会强制重用迭代器的先前值,在 False 的情况下,它将获得新值。

iterator_t = ds_t.make_initializable_iterator()
iterator_v = ds_v.make_initializable_iterator()

iterator_handle = tf.placeholder(tf.string, shape=[], name="iterator_handle")
iterator = tf.data.Iterator.from_string_handle(iterator_handle,
iterator_t.output_types,
iterator_t.output_shapes)

def get_next_item():
# sometimes items need casting
next_elem = iterator.get_next(name="next_element")
x, y = tf.cast(next_elem[0], tf.float32), next_elem[1]
return x, y

def old_data():
# just forward the existing batch
return inputs, target

is_keep_previous = tf.placeholder_with_default(tf.constant(False),shape=[], name="keep_previous_flag")

inputs, target = tf.cond(is_keep_previous, old_data, new_data)

with tf.Session() as sess:
sess.run([tf.global_variables_initializer(),tf.local_variables_initializer()])
handle_t = sess.run(iterator_t.string_handle())
handle_v = sess.run(iterator_v.string_handle())
# Run data iterator initialisation
sess.run(iterator_t.initializer)
sess.run(iterator_v.initializer)
while True:
try:
inputs_, target_ = sess.run([inputs, target], feed_dict={iterator_handle: handle_t, is_keep_previous:False})
print(inputs_, target_)
inputs_, target_ = sess.run([inputs, target], feed_dict={iterator_handle: handle_t, is_keep_previous:True})
print(inputs_, target_)
inputs_, target_ = sess.run([inputs, target], feed_dict={iterator_handle: handle_v})
print(inputs_, target_)
except tf.errors.OutOfRangeError:
# now we know we run out of elements in the validationiterator
break

关于python - 如何重用 iterator.get_next() 中的数据批处理,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49584489/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com