gpt4 book ai didi

tensorflow - 在 Tensorflow 的数据集 API 中使用 flat_map

转载 作者:行者123 更新时间:2023-12-04 01:14:14 25 4
gpt4 key购买 nike

我正在使用数据集 API,读取数据如下:

dataset = tf.data.TFRecordDataset(filename, compression_type="GZIP")
dataset = dataset.map(lambda str: tf.parse_single_example(str, feature_schema))

我现在想用 flat_map为了过滤掉一些,同时在训练时动态复制一些其他样本(这是导致我的模型的输入函数)。
flat_map 的 API需要返回 Dataset对象,但是我不知道如何创建它。这是我想要实现的伪代码实现:
def flat_map_impl(tf_example):
# Pseudo-code:
# if tf_example["a"] == 1:
# return []
# else:
# return [tf_example, tf_example]

dataset.flat_map(flat_map_impl)

我如何在 flat_map 中实现这一点功能?

注意:我想可以通过 py_func 来实现这一点。 ,但我宁愿避免这种情况。

最佳答案

也许是创建 tf.data.Dataset 的最常见方式从 Dataset.flat_map() 返回时是使用 Dataset.from_tensors() Dataset.from_tensor_slices() .在这种情况下,因为 tf_example是字典,用Dataset.from_tensors()的组合大概是最简单的和 Dataset.repeat(count) , 其中一个 conditional expression计算 count :

dataset = tf.data.TFRecordDataset(filename, compression_type="GZIP")
dataset = dataset.map(lambda str: tf.parse_single_example(str, feature_schema))

def flat_map_impl(tf_example):
count = tf.cond(tf.equal(tf_example["a"], 1)),
lambda: tf.constant(0, dtype=tf.int64),
lambda: tf.constant(2, dtype=tf.int64))

return tf.data.Dataset.from_tensors(tf_example).repeat(count)

dataset = dataset.flat_map(flat_map_impl)

关于tensorflow - 在 Tensorflow 的数据集 API 中使用 flat_map,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50530806/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com