gpt4 book ai didi

python - 如何使用嵌套形状的 tf.data.Dataset.padded_batch?

转载 作者:太空宇宙 更新时间:2023-11-03 15:50:32 25 4
gpt4 key购买 nike

我正在为每个元素构建一个数据集,其中包含两个形状为 [batch,width,heigh,3] 和 [batch,class] 的张量。为简单起见,假设类 = 5。

您向 dataset.padded_batch(1000,shape) 提供什么形状,以便沿宽度/高度/3 轴填充图像?

我尝试了以下方法:

tf.TensorShape([[None,None,None,3],[None,5]])
[tf.TensorShape([None,None,None,3]),tf.TensorShape([None,5])]
[[None,None,None,3],[None,5]]
([None,None,None,3],[None,5])
(tf.TensorShape([None,None,None,3]),tf.TensorShape([None,5])‌​)

每次引发 TypeError

The docs状态:

padded_shapes: A nested structure of tf.TensorShape or tf.int64 vector tensor-like objects representing the shape to which the respective component of each input element should be padded prior to batching. Any unknown dimensions (e.g. tf.Dimension(None) in a tf.TensorShape or -1 in a tensor-like object) will be padded to the maximum size of that dimension in each batch.

相关代码:

dataset = tf.data.Dataset.from_generator(generator,tf.float32)
shapes = (tf.TensorShape([None,None,None,3]),tf.TensorShape([None,5]))
batch = dataset.padded_batch(1,shapes)

最佳答案

感谢 marry 找到解决方案。事实证明,from_generator 中的类型必须与条目中的张量数量相匹配。

新代码:

dataset = tf.data.Dataset.from_generator(generator,(tf.float32,tf.float32))
shapes = (tf.TensorShape([None,None,None,3]),tf.TensorShape([None,5]))
batch = dataset.padded_batch(1,shapes)

关于python - 如何使用嵌套形状的 tf.data.Dataset.padded_batch?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47103249/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com