gpt4 book ai didi

python - 将 tf.dataset 中的每个样本映射到 id

转载 作者:行者123 更新时间:2023-12-01 06:20:53 24 4
gpt4 key购买 nike

出于测试目的,我想为 tf.dataset 中的每个样本附加一个 id。只需向上计数就足够了。

我的数据集的类型为 FlatMapDataset fwiw。

for entry in img_ds:
print(entry.shape)

(128, 128, 3)
(128, 128, 3)
(128, 128, 3)
(128, 128, 3)
...

我尝试的是有一个映射函数,在其内部定义一个计数器并向上计数:

@staticmethod
def map_to_id(img):
try:
ExperimentalPipeline.map_to_id.id_counter += 1
except AttributeError:
ExperimentalPipeline.map_to_id.id_counter = 0
return img, ExperimentalPipeline.map_to_id.id_counter

然后使用 tf.data 中的 Dataset.map 将 id 附加到每个样本:

img_ds = img_ds.map(ExperimentalPipeline.map_to_id)

不幸的是,这不起作用,每个样本的 id 为零:

for i, id in img_ds:
print(f"{i.shape}, {id}")

(128, 128, 3), 0
(128, 128, 3), 0
(128, 128, 3), 0
(128, 128, 3), 0
...

我还注意到我的 map_to_id 函数仅被调用一次。

@staticmethod
def map_to_id(img):
print("enter map_to_id")
try:
ExperimentalPipeline.map_to_id.id_counter += 1
except AttributeError:
print("caught exception")
ExperimentalPipeline.map_to_id.id_counter = np.random.randint(1000)
return img, ExperimentalPipeline.map_to_id.id_counter

enter map_to_id
caught exception
(128, 128, 3), 889
(128, 128, 3), 889
(128, 128, 3), 889
(128, 128, 3), 889

我想我不明白 Dataset.map 应该如何工作。我认为它会获取正在调用的数据集中的每个样本,并以该样本作为参数调用提供的函数。
有人可以帮我解决这个问题吗?

最佳答案

TensorFlow 将运行一次 map 函数,将该函数编译为 TensorFlow 操作。然后这些操作(而不是原始的 python 函数)将应用于数据集的每个元素。如果你想为每个元素运行原始的 python 函数,你可以使用 py_function相反。

在这种特定情况下,您想要附加元素 ID,可以使用 Dataset.enumerate实现您的目标:

img_ds = img_ds.enumerate()

关于python - 将 tf.dataset 中的每个样本映射到 id,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60365843/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com