gpt4 book ai didi

python - Tensorflow 数据集 API - .from_tensor_slices()/.from_tensor() - 无法创建内容大于 2gb 的张量原型(prototype)

转载 作者:太空宇宙 更新时间:2023-11-03 11:16:15 28 4
gpt4 key购买 nike

所以我想使用数据集 API 来批处理我的大型数据集 (~8GB),因为我在使用 GPU 时遇到大量空闲时间,因为我正在使用 feed_dict 将数据从 python 传递到 Tensorflow。

当我按照此处提到的教程进行操作时:

https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/5_DataManagement/tensorflow_dataset_api.py

运行我的简单代码时:

one_hot_dataset = np.load("one_hot_dataset.npy")
dataset = tf.data.Dataset.from_tensor_slices(one_hot_dataset)

我在使用 TensorFlow 1.8 和 Python 3.5 时收到错误消息:

Traceback (most recent call last):

File "<ipython-input-17-412a606c772f>", line 1, in <module>
dataset = tf.data.Dataset.from_tensor_slices((one_hot_dataset))

File "/anaconda2/envs/tf/lib/python3.5/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 235, in from_tensor_slices
return TensorSliceDataset(tensors)

File "/anaconda2/envs/tf/lib/python3.5/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1030, in __init__
for i, t in enumerate(nest.flatten(tensors))

File "/anaconda2/envs/tf/lib/python3.5/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1030, in <listcomp>
for i, t in enumerate(nest.flatten(tensors))

File "/anaconda2/envs/tf/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1014, in convert_to_tensor
as_ref=False)

File "/anaconda2/envs/tf/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1104, in internal_convert_to_tensor
ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)

File "/anaconda2/envs/tf/lib/python3.5/site-packages/tensorflow/python/framework/constant_op.py", line 235, in _constant_tensor_conversion_function
return constant(v, dtype=dtype, name=name)

File "/anaconda2/envs/tf/lib/python3.5/site-packages/tensorflow/python/framework/constant_op.py", line 214, in constant
value, dtype=dtype, shape=shape, verify_shape=verify_shape))

File "/anaconda2/envs/tf/lib/python3.5/site-packages/tensorflow/python/framework/tensor_util.py", line 496, in make_tensor_proto
"Cannot create a tensor proto whose content is larger than 2GB.")

ValueError: Cannot create a tensor proto whose content is larger than 2GB.

我该如何解决这个问题?我认为原因很明显,但是 tf 开发人员将输入数据限制为 2GB 时是怎么想的?!?我真的无法理解这种合理性,在处理更大的数据集时有什么解决方法?

我在谷歌上搜索了很多,但找不到任何类似的错误消息。当我使用 numpy 数据集的 FITFH 时,上述步骤没有任何问题。

我需要以某种方式告诉 TensorFlow 我实际上将逐批加载数据并且可能想要预取一些批处理以使我的 GPU 保持忙碌。但它似乎试图一次加载整个 numpy 数据集。那么使用数据集 API 的好处是什么,因为我可以通过简单地尝试将我的 numpy 数据集作为 tf.constant 加载到 TensorFlow 图中来重现此错误,这显然不适合并且我收到 OOM 错误。

感谢提示和故障排除提示!

最佳答案

tf.data 用户指南 (https://www.tensorflow.org/guide/datasets) 的“使用 NumPy 数组”部分解决了这个问题。

基本上,创建一个 dataset.make_initializable_iterator() 迭代器并在运行时提供您的数据。

如果由于某种原因这不起作用,您可以将数据写入文件或从 Python 生成器 (https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_generator) 创建数据集,您可以在其中放置任意 Python 代码,包括对 numpy 数组进行切片并生成切片。

关于python - Tensorflow 数据集 API - .from_tensor_slices()/.from_tensor() - 无法创建内容大于 2gb 的张量原型(prototype),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51118565/

28 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com