gpt4 book ai didi

python - 有没有办法在 Tensorflow 的另一个数据集中使用 tf.data.Dataset?

转载 作者:太空宇宙 更新时间:2023-11-04 04:48:09 26 4
gpt4 key购买 nike

我正在做分割。每个训练样本都有多个带有分割掩码的图像。我正在尝试编写 input_fn 以将所有掩码图像合并为每个训练样本的一个图像。我计划使用两个 Datasets,一个遍历样本文件夹,另一个将所有掩码作为一个大批量读取,然后将它们合并为一个张量。

调用嵌套的 make_one_shot_iterator 时出现错误。我知道这种方法有点牵强,而且很可能数据集不是为这种用途而设计的。但是我应该如何解决这个问题以避免使用 tf.py_func?

这是数据集的简化版本:

def read_sample(sample_path):
masks_ds = (tf.data.Dataset.
list_files(sample_path+"/masks/*.png")
.map(tf.read_file)
.map(lambda x: tf.image.decode_image(x, channels=1))
.batch(1024)) # maximum number of objects
masks = masks_ds.make_one_shot_iterator().get_next()

return tf.reduce_max(masks, axis=0)

ds = tf.data.Dataset.from_tensor_slices(tf.glob("../input/stage1_train/*"))
ds.map(read_sample)
# ...
sample = ds.make_one_shot_iterator().get_next()
# ...

最佳答案

如果嵌套数据集只有一个元素,可以使用tf.contrib.data.get_single_element()在嵌套数据集上而不是创建迭代器:

def read_sample(sample_path):
masks_ds = (tf.data.Dataset.list_files(sample_path+"/masks/*.png")
.map(tf.read_file)
.map(lambda x: tf.image.decode_image(x, channels=1))
.batch(1024)) # maximum number of objects
masks = tf.contrib.data.get_single_element(masks_ds)
return tf.reduce_max(masks, axis=0)

ds = tf.data.Dataset.from_tensor_slices(tf.glob("../input/stage1_train/*"))
ds = ds.map(read_sample)
sample = ds.make_one_shot_iterator().get_next()

此外,您可以使用 tf.data.Dataset.flat_map() , tf.data.Dataset.interleave() , 或 tf.contrib.data.parallel_interleave() transformationw 在函数内部执行嵌套的 Dataset 计算,并将结果展平为单个 Dataset。例如,要获取单个 Dataset 中的所有样本:

def read_all_samples(sample_path):
return (tf.data.Dataset.list_files(sample_path+"/masks/*.png")
.map(tf.read_file)
.map(lambda x: tf.image.decode_image(x, channels=1))
.batch(1024)) # maximum number of objects

ds = tf.data.Dataset.from_tensor_slices(tf.glob("../input/stage1_train/*"))
ds = ds.flat_map(read_all_samples)
sample = ds.make_one_shot_iterator().get_next()

关于python - 有没有办法在 Tensorflow 的另一个数据集中使用 tf.data.Dataset?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49019898/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com