gpt4 book ai didi

python-3.x - 如何在 TensorFlow 中访问数据集的特征字典

转载 作者:行者123 更新时间:2023-12-05 02:11:57 26 4
gpt4 key购买 nike

借助 tensorflow-datasets,我将 MNIST 数据集集成到 Tensorflow 中,现在想使用 Matplotlib 可视化单个图像。我是根据本指南完成的:https://www.tensorflow.org/datasets/overview

不幸的是,我在执行过程中收到一条错误消息。但它在指南中效果很好。

根据指南,您必须使用 take() 函数创建一个只有一张图像的新数据集。然后在指南中访问这些功能。在我尝试的过程中,我总是收到一条错误消息。

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import matplotlib.pyplot as plt
import numpy as np
import tensorflow.compat.v1 as tf

import tensorflow_datasets as tfds



mnist_train, info = tfds.load(name="mnist", split=tfds.Split.TRAIN, with_info=True)
assert isinstance(mnist_train, tf.data.Dataset)

mnist_example = mnist_train.take(50)

#The error is raised in the next line.
image = mnist_example["image"]
label = mnist_example["label"]

plt.imshow(image.numpy()[:, :, 0].astype(np.float32), cmap=plt.get_cmap("gray"))
print("Label: %d" % label.numpy())

这是错误信息:

Traceback (most recent call last):
File "D:/mnist/model.py", line 24, in <module>
image = mnist_example["image"]
TypeError: 'DatasetV1Adapter' object is not subscriptable

有谁知道我该如何解决这个问题?经过大量研究,我仍然没有找到解决方案。

最佳答案

急切执行

写代码tf.enable_eager_execution()

为什么?

因为如果你不这样做,你将需要创建图表并执行 session.run() 以获得一些样本

eager execution 定义 ( reference ):

TensorFlow's eager execution is an imperative programming environment that evaluates >operations immediately, without building graphs: operations return concrete values >instead of constructing a computational graph to run later

然后

如何访问数据集对象中的样本

您只需要遍历 DatasetV1Adapter 对象

通过转换为 numpy 访问一些示例的几种方法:

1.

mnist_example = mnist_train.take(50)
for sample in mnist_example:
image, label = sample["image"].numpy(), sample["label"].numpy()
plt.imshow(image[:, :, 0].astype(np.uint8), cmap=plt.get_cmap("gray"))
plt.show()
print("Label: %d" % label)

2.

mnist_example = tfds.as_numpy(mnist_example, graph=None)
for sample in mnist_example:
image, label = sample["image"], sample["label"]
plt.imshow(image[:, :, 0].astype(np.uint8), cmap=plt.get_cmap("gray"))
plt.show()
print("Label: %d" % label)

注意 1:如果您想要 numpy 数组中的所有 50 个样本,您可以创建一个空数组,例如 np.zeros((28, 28, 50), dtype=np .uint8) 数组并将这些图像分配给它的元素。

注意2:为了显示效果,不要转成np.float32,没用的,图片是uint8格式/范围(默认不归一化) )

关于python-3.x - 如何在 TensorFlow 中访问数据集的特征字典,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56759226/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com