gpt4 book ai didi

python - 使用 tf.data.Dataset.batch 时出现的问题

转载 作者:行者123 更新时间:2023-11-30 09:18:10 26 4
gpt4 key购买 nike

我想弄清楚 tf.data.Dataset.batch 如何处理我的数据集。数据集如下:

dataset = tf.convert_to_tensor([[5.1, 3.3, 1.7, 0.5, ],
[5.9, 3.0, 4.2, 1.5],
[6.9, 3.1, 5.4, 2.1],
[2.3, 1.3, 6.4, 9.3]])

然后我使用批处理方法:

dataset = dataset.batch(2)

并迭代数据集一次。

x = tfe.Iterator(dataset).next()

正如我所想,结果应该是一个 2*4 数组,但它返回整个 4*4 数据集。

有人能给我一些关于如何应用batch方法的详细信息吗?

最佳答案

您需要将您的数据集 Tensor转换为TensorSliceDataset,即告诉Tensorflow对张量进行切片并制作它的数据集。

import tensorflow as tf

data = tf.convert_to_tensor([[5.1, 3.3, 1.7, 0.5],
[5.9, 3.0, 4.2, 1.5],
[6.9, 3.1, 5.4, 2.1],
[2.3, 1.3, 6.4, 9.3]])

dataset = tf.data.Dataset.from_tensor_slices(data).batch(2)
batch_iterator = dataset.make_one_shot_iterator().get_next()

sess = tf.InteractiveSession()
batch = sess.run(batch_iterator)
print(batch)
# [[ 5.1 3.3 1.7 0.5 ]
# [ 5.9 3. 4.2 1.5 ]]

关于python - 使用 tf.data.Dataset.batch 时出现的问题,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49874590/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com