python - Randomly interleave one tf.Dataset with another tf.Dataset

Reposted. Author: 行者123. Updated: 2023-12-01 08:28:55

I have two datasets:

main_ds = tf.data.Dataset.from_tensor_slices(list(range(1000, 1100)))
background_ds = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4])

I want to randomly interleave the elements of main_ds and background_ds within each batch. For example, a batch of size 10 should look like this:

[3, 1017, 1039, 3, 2, 1024, 4, 1, 1053, 4]

I tried the following:

def interlace_background(image, background):
    return tf.cond(tf.random_uniform([]) < .5, lambda: image, lambda: background)

background_ds = background_ds.shuffle(10).repeat(-1)
background_it = background_ds.make_initializable_iterator()
background_next = background_it.get_next()

main_ds = main_ds.shuffle(10)\
    .repeat(-1)\
    .map(lambda x: interlace_background(x, background_next))\
    .batch(10)
main_it = main_ds.make_initializable_iterator()
main_next = main_it.get_next()

But every batch gets the same, fixed background:

batch 0: [   3 1006    3 1001    3 1005 1015 1000    3    3]
batch 1: [1007    3 1012 1018 1013    3 1008 1019    3    3]
batch 2: [1016    3 1025    3    3    3 1021    3    3 1035]
batch 3: [1038    3    3 1023 1020    3    3 1046 1034 1047]
batch 4: [   3    3 1039    3    3    3    3    3 1053    3]

Why is the background fixed (as shown above, it is always 3), and how can I fix this?

Here is the fully reproducible code:

import tensorflow as tf
import numpy as np

def interlace_background(image, background):
    return tf.cond(tf.random_uniform([]) < .5, lambda: image, lambda: background)

main_ds = tf.data.Dataset.from_tensor_slices(list(range(1000, 1100)))
background_ds = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4])

background_ds = background_ds.shuffle(10).repeat(-1)
background_it = background_ds.make_initializable_iterator()
background_next = background_it.get_next()

main_ds = main_ds.shuffle(10)\
    .repeat(-1)\
    .map(lambda x: interlace_background(x, background_next))\
    .batch(10)
main_it = main_ds.make_initializable_iterator()
main_next = main_it.get_next()

with tf.Session() as sess:
    sess.run(background_it.initializer)
    sess.run(main_it.initializer)
    for i in range(5):
        print('batch %i' % i, sess.run(main_next))
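The failure can be mimicked in plain Python, without TensorFlow. This is only an analogy, not tf.data itself, and it assumes (as is the case for a tensor captured by a tf.data map function) that background_next is evaluated once, when the iterator is initialized, rather than once per element. The hypothetical `buggy_pipeline` draws its background once up front; the hypothetical `per_element_pipeline` draws a fresh one for every element.

```python
import random


def buggy_pipeline(main_items, background_items, rng):
    # Analogy for capturing background_next in map(): the background is
    # drawn ONCE, before any element is produced, and then reused.
    background = rng.choice(background_items)
    for image in main_items:
        # Per-element coin flip, like tf.cond(tf.random_uniform([]) < .5, ...).
        yield image if rng.random() < 0.5 else background


def per_element_pipeline(main_items, background_items, rng):
    for image in main_items:
        # A fresh background for every element, which is the desired behaviour.
        background = rng.choice(background_items)
        yield image if rng.random() < 0.5 else background


main_items = list(range(1000, 1100))
background_items = [1, 2, 3, 4]

buggy = list(buggy_pipeline(main_items, background_items, random.Random(0)))
fixed = list(per_element_pipeline(main_items, background_items, random.Random(0)))

# Values below 1000 are backgrounds.
print(sorted({x for x in buggy if x < 1000}))  # a single background value
print(sorted({x for x in fixed if x < 1000}))  # several background values
```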

Best answer

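You can get the behaviour you want by combining Dataset.zip() and Dataset.map(). In your version the background is fixed because background_next is created outside the function passed to map(): tf.data captures it as an input of the map transformation, and that input is evaluated only once, when main_it.initializer runs, so every element sees the same background value. Zipping the two datasets instead pairs each main element with its own background element.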

Here is the code:

import tensorflow as tf

def interlace_background(image, background):
    return tf.cond(tf.random_uniform([]) < .5, lambda: image, lambda: background)


main_ds = tf.data.Dataset.from_tensor_slices(list(range(1000, 1100))).shuffle(100)
background_ds = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4]).shuffle(4)

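# zip() pairs each main element with a background element;
# interlace_background then randomly keeps one of the two.
new_ds = tf.data.Dataset \
    .zip((main_ds, background_ds)) \
    .repeat(-1) \
    .map(lambda x, y: interlace_background(x, y)) \
    .batch(10)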

iterator = new_ds.make_initializable_iterator()
next_item = iterator.get_next()

with tf.Session() as sess:
    sess.run(iterator.initializer)
    for i in range(5):
        print('batch %i' % i, sess.run(next_item))

Output:

batch 0 [1065    2    4    1    2    4    1 1036 1072 1020]
batch 1 [   4    3    2 1057    1    4    2 1077    3    1]
batch 2 [   3 1044 1042 1049 1029    1    3 1069 1018    3]
batch 3 [   2    4 1089 1094    2 1022 1041 1006    1    3]
batch 4 [1079    2    1    3 1023 1042    4 1018 1054    4]
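A plain-Python model of the pipeline above makes one side effect visible. Dataset.zip(), like Python's built-in zip(), stops at its shortest input, so each pass of .repeat(-1) yields only 4 elements: only the first 4 shuffled elements of main_ds are used before both datasets are reshuffled and restarted. This is a sketch that stands in for tf.data with the `random` module; `one_pass` and `interleaved_batches` are hypothetical names.

```python
import random
from itertools import islice


def one_pass(main_items, background_items, rng):
    """One pass over zip((main_ds, background_ds)), both sides shuffled."""
    main_items = list(main_items)
    background_items = list(background_items)
    rng.shuffle(main_items)        # stands in for main_ds.shuffle(100)
    rng.shuffle(background_items)  # stands in for background_ds.shuffle(4)
    # zip() stops at the shorter input: 4 pairs here, not 100.
    return [image if rng.random() < 0.5 else background  # interlace_background
            for image, background in zip(main_items, background_items)]


def interleaved_batches(main_items, background_items, batch_size, seed=0):
    rng = random.Random(seed)

    def stream():
        while True:  # .repeat(-1): reshuffle and restart both inputs
            yield from one_pass(main_items, background_items, rng)

    elements = stream()
    while True:  # .batch(batch_size)
        yield list(islice(elements, batch_size))


main_items = list(range(1000, 1100))
background_items = [1, 2, 3, 4]

print(len(one_pass(main_items, background_items, random.Random(0))))  # 4
for i, batch in enumerate(islice(interleaved_batches(main_items, background_items, 10), 5)):
    print('batch %i' % i, batch)
```

If every main element should be visited on each epoch, one option (not part of the original answer) is to repeat only the short dataset before zipping, e.g. tf.data.Dataset.zip((main_ds, background_ds.repeat())), so the zip is bounded by main_ds rather than by background_ds.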

Regarding "python - Randomly interleave one tf.Dataset with another tf.Dataset", we found a similar question on Stack Overflow: https://stackoverflow.com/questions/54025069/
