gpt4 book ai didi

python - 使用 py_func 时,数据集 API 不会为其输出张量传递维度信息

转载 作者:太空狗 更新时间:2023-10-29 23:56:59 27 4
gpt4 key购买 nike

要重现我的问题,请先尝试这个(使用 py_func 进行映射):

import tensorflow as tf
import numpy as np
def image_parser(image_name):
a = np.array([1.0,2.0,3.0], dtype=np.float32)
return a

images = [[1,2,3],[4,5,6]]
im_dataset = tf.data.Dataset.from_tensor_slices(images)
im_dataset = im_dataset.map(lambda image:tuple(tf.py_func(image_parser, [image], [tf.float32])), num_parallel_calls = 2)
im_dataset = im_dataset.prefetch(4)
iterator = im_dataset.make_initializable_iterator()
print(im_dataset.output_shapes)

它会给你 (TensorShape(None),)

但是,如果您尝试这样做(使用直接 tensorflow 映射而不是 py_func):

import tensorflow as tf
import numpy as np

def image_parser(image_name)
return image_name

images = [[1,2,3],[4,5,6]]
im_dataset = tf.data.Dataset.from_tensor_slices(images)
im_dataset = im_dataset.map(image_parser)
im_dataset = im_dataset.prefetch(4)
iterator = im_dataset.make_initializable_iterator()
print(im_dataset.output_shapes)

它会给你准确的张量维度 (3,)

最佳答案

这是 tf.py_func 的一个普遍问题,这是有意为之的,因为 TensorFlow 无法自行推断输出形状,例如参见 this answer .

如果需要,您可以自己设置形状,方法是将 tf.py_func 移动到解析函数中:

def parser(x):
a = np.array([1.0,2.0,3.0])
y = tf.py_func(lambda: a, [], tf.float32)
y.set_shape((3,))
return y

dataset = tf.data.Dataset.range(10)
dataset = dataset.map(parser)
print(dataset.output_shapes) # will correctly print (3,)

关于python - 使用 py_func 时,数据集 API 不会为其输出张量传递维度信息,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48823848/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com