gpt4 book ai didi

python - TensorFlow py_function 嵌套输出类型?

转载 作者:行者123 更新时间:2023-12-01 06:22:36 24 4
gpt4 key购买 nike

是否可以为 TensorFlow 指定嵌套输出类型 py_function

作为一个具体案例,我想要py_function返回类型为 ((tf.float32, tf.float32), (tf.float32, tf.float32)) ,其中各个元素不一定具有相同的尺寸。有没有办法为py_function指定这个?

就像深入了解为什么这对我的情况有用一样,我有一个 tf.data.Dataset带有文件路径列表。 py_function获取这些文件路径之一并从文件中生成一个反例和正例以及相应的标签,从而生成 ((positive_data, positive_label), (negative_data, negative_label)) (注意,标签不一定是单个值,但它们的形状也与输入数据不同)。这个py_function可以映射到数据集,并且(使用上述结构)将一级展平以生成 (data, label) 的训练数据集结构化元素。虽然可以有一个解决方法,将数据和标签堆叠在 py_function 中。然后拆开(或者从 py_function 开始完全非结构化,然后才配对),它会导致困惑和困惑的设置。如果py_function可以直接输出 ((tf.float32, tf.float32), (tf.float32, tf.float32))类型,这将导致更清晰的设置。

最佳答案

tf.py_function的输出类型不能是嵌套序列。但是,当将 tf.py_functiontf.data API 结合使用时,您需要创建一个包装函数(下例中的 tf_foo) ,您可以将输出嵌套在该函数中。

import tensorflow as tf

# The python function.
def foo(x):
return x, x, x, x

# Wrap the python function to make it compatible with `tf.data.Dataset.map`.
def tf_foo(x):
a, b, c, d = tf.py_function(foo, [x], Tout=[tf.float32, tf.float32, tf.float32, tf.float32])
return (a, b), (c, d)

dset = tf.data.Dataset.from_tensor_slices([0, 1, 2, 3, 4])
dset.map(tf_foo)
# <MapDataset shapes: ((<unknown>, <unknown>), (<unknown>, <unknown>)),
# types: ((tf.float32, tf.float32), (tf.float32, tf.float32))>

这也在 TensorFlow tutorial 中得到了证明。 .

关于python - TensorFlow py_function 嵌套输出类型?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60290340/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com