gpt4 book ai didi

tensorflow-datasets - 如果 a 被打乱,tf.data.Dataset.zip(a, b) 会改变元素的顺序

转载 作者:行者123 更新时间:2023-12-05 07:12:56 26 4
gpt4 key购买 nike

我正在准备一个数据集,然后在存储输出之前训练一个模型(为了知识蒸馏的目的)

为了以 tfrecords 格式存储它们,我需要使用 .zip() 函数。

我用以下代码重现了错误/错误。我的实际训练文件有数百行,所以我没有在此处包含它们。

我使用 tensorflow 2.1。和 ubuntu 18.04 上的 python 3.7

我无法解决的问题是:

数据被打乱(没关系)。但是在压缩之后,元组之间的顺序不同(这是不对的)。

import tensorflow as tf 
ds = tf.data.Dataset.from_tensor_slices([1,2,3,4, 5])

#prepare dataset for training
batch_size=2
ds = ds.cache().repeat().shuffle(buffer_size=5, reshuffle_each_iteration=True).batch(batch_size)

#create model. here: map identity function
model = tf.keras.models.Sequential([tf.keras.layers.Lambda(lambda x: x , input_shape=(1,))])

#train with model.fit()

#make predictions.
pred = model.predict(ds, steps=5//batch_size)

#prepare for saving to tfrecords
ds = ds.unbatch()
ds = ds.take(5)
pred = tf.data.Dataset.from_tensor_slices(pred)
combined = tf.data.Dataset.zip((ds, pred))

#show unwanted behaviour
for (a),c in combined:
print(a,c)

代码片段的输出显示每行的元素不匹配。 (例如第 1 行:3 应该映射到 3)

tf.Tensor(3, shape=(), dtype=int32) tf.Tensor([4.], shape=(1,), dtype=float32)
tf.Tensor(1, shape=(), dtype=int32) tf.Tensor([1.], shape=(1,), dtype=float32)
tf.Tensor(4, shape=(), dtype=int32) tf.Tensor([1.], shape=(1,), dtype=float32)
tf.Tensor(3, shape=(), dtype=int32) tf.Tensor([2.], shape=(1,), dtype=float32)

最佳答案

Tensorflow 在数据集的每次迭代中应用随机播放。Zip 是这些迭代之一,这就是为什么 model.predict 中的顺序与 zip 中的顺序不匹配的原因(两次都有随机播放)

无论如何,对于 predict 你并不真的需要打乱数据集。预测不应取决于模型在先前预测中看到的内容。

关于tensorflow-datasets - 如果 a 被打乱,tf.data.Dataset.zip(a, b) 会改变元素的顺序,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60270718/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com