gpt4 book ai didi

tensorflow - 从 tf.data.Dataset.map() 返回数据集会导致 'TensorSliceDataset' 对象没有属性 'get_shape' 错误

转载 作者:行者123 更新时间:2023-11-30 09:26:56 31 4
gpt4 key购买 nike

我正在使用数据集 API 创建输入管道。我以类似于以下的模式使用 tf.data.Dataset.map() 方法:

def mapped_fn(_):
X = tf.random_uniform([3,3])
y = tf.random_uniform([3,1])
dataset = tf.data.Dataset.from_tensor_slices((X,y))
return dataset

with tf.Session() as sess:
first = tf.random_uniform([1,2])
unimportant_dataset = tf.data.Dataset.from_tensors(first)
dataset = unimportant_dataset.map(mapped_fn)
sess.run(dataset)

我收到以下错误:AttributeError:“TensorSliceDataset”对象没有属性“get_shape”

总体背景是,mapped_fn 从 .tfrecords 文件反序列化示例 protobuf(在本例中由 unimportant_dataset 表示), reshape 特征向量 (X ),并且需要返回一个数据集,其中包含由新特征向量(在本例中为 (3,) 形状)中的切片定义的元素。返回 ZipDataset 时我遇到了类似的错误。提前致谢!

最佳答案

DomJack's answer关于 Dataset.map() 的签名绝对正确:它期望传递的 mapped_fn 的返回值是一个或多个张量(或稀疏张量)。

如果您确实有一个返回Dataset的函数,则可以使用Dataset.flat_map()将所有返回的数据集展平并连接成一个数据集,如下:

def mapped_fn(_):
X = tf.random_uniform([3,3])
y = tf.random_uniform([3,1])
dataset = tf.data.Dataset.from_tensor_slices((X,y))
return dataset

# Generate 100 dummy elements.
unimportant_dataset = tf.data.Dataset.range(100)

# Convert each dummy element into a dataset of 3 nested elements, and concatenate them.
dataset = unimportant_dataset.flat_map(mapped_fn)

关于tensorflow - 从 tf.data.Dataset.map() 返回数据集会导致 'TensorSliceDataset' 对象没有属性 'get_shape' 错误,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50809257/

31 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com