gpt4 book ai didi

numpy - 带有字符串输入的 Tensorflow 数据集不保留数据类型

转载 作者:行者123 更新时间:2023-12-04 02:37:19 31 4
gpt4 key购买 nike

全部 可重现 下面的代码在 Google Colab 上使用 TF 2.2.0-rc2 运行。

改编 documentation 中的简单示例用于从简单的 Python 列表创建数据集:

import numpy as np
import tensorflow as tf
tf.__version__
# '2.2.0-rc2'
np.version.version
# '1.18.2'

dataset1 = tf.data.Dataset.from_tensor_slices([1, 2, 3])
for element in dataset1:
print(element)
print(type(element.numpy()))

我们得到结果
tf.Tensor(1, shape=(), dtype=int32)
<class 'numpy.int32'>
tf.Tensor(2, shape=(), dtype=int32)
<class 'numpy.int32'>
tf.Tensor(3, shape=(), dtype=int32)
<class 'numpy.int32'>

其中所有数据类型都是 int32 ,正如预期的那样。

但是改变这个简单的例子来提供一个字符串列表而不是整数:
dataset2 = tf.data.Dataset.from_tensor_slices(['1', '2', '3']) 
for element in dataset2:
print(element)
print(type(element.numpy()))

给出结果
tf.Tensor(b'1', shape=(), dtype=string)
<class 'bytes'>
tf.Tensor(b'2', shape=(), dtype=string)
<class 'bytes'>
tf.Tensor(b'3', shape=(), dtype=string)
<class 'bytes'>

令人惊讶的是,尽管张量本身是 dtype=string ,他们的评估类型为 bytes .

这种行为不仅限于 .from_tensor_slices方法;这是 .list_files 的情况(以下代码段在新的 Colab 笔记本中直接运行):
disc_data = tf.data.Dataset.list_files('sample_data/*.csv') # 4 csv files
for element in disc_data:
print(element)
print(type(element.numpy()))

结果是:
tf.Tensor(b'sample_data/california_housing_test.csv', shape=(), dtype=string)
<class 'bytes'>
tf.Tensor(b'sample_data/mnist_train_small.csv', shape=(), dtype=string)
<class 'bytes'>
tf.Tensor(b'sample_data/california_housing_train.csv', shape=(), dtype=string)
<class 'bytes'>
tf.Tensor(b'sample_data/mnist_test.csv', shape=(), dtype=string)
<class 'bytes'>

同样,评估张量中的文件名返回为 bytes , 而不是 string ,尽管张量本身是 dtype=string .
.from_generator 也观察到了类似的行为。方法(此处未显示)。

最后的演示:如 .as_numpy_iterator所示方法 documentation ,以下等式条件被评估为 True :
dataset3 = tf.data.Dataset.from_tensor_slices({'a': ([1, 2], [3, 4]), 
'b': [5, 6]})

list(dataset3.as_numpy_iterator()) == [{'a': (1, 3), 'b': 5},
{'a': (2, 4), 'b': 6}]
# True

但是如果我们改变 b 的元素作为字符串,相等条件现在令人惊讶地评估为 False !
dataset4 = tf.data.Dataset.from_tensor_slices({'a': ([1, 2], [3, 4]), 
'b': ['5', '6']}) # change elements of b to strings

list(dataset4.as_numpy_iterator()) == [{'a': (1, 3), 'b': '5'}, # here
{'a': (2, 4), 'b': '6'}] # also
# False

可能是由于不同的数据类型,因为值本身显然是相同的。

我不是通过学术实验偶然发现这种行为的。我正在尝试使用自定义函数将我的数据传递给 TF 数据集,这些函数从表单的磁盘读取文件对
f = ['filename1', 'filename2']

哪些自定义函数可以很好地独立工作,但通过 TF 数据集映射给出
RuntimeError: not a string

如果返回的数据类型确实是 bytes,那么在此挖掘之后,这似乎不是无法解释的。而不是 string .

那么,这是一个错误(看起来),还是我在这里遗漏了什么?

最佳答案

这是一个已知的行为:

发件人:https://github.com/tensorflow/tensorflow/issues/5552#issuecomment-260455136

TensorFlow converts str to bytes in most places, including sess.run, and this is unlikely to change. The user is free to convert back, but unfortunately it's too large a change to add a unicode dtype to the core. Closing as won't fix for now.



我想 TensorFlow 2.x 没有任何改变 - 仍有一些地方将字符串转换为字节,您必须手动处理。

来自 issue你已经打开了自己,似乎他们将这个主题视为 Numpy 的问题,而不是 Tensorflow 本身的问题。

关于numpy - 带有字符串输入的 Tensorflow 数据集不保留数据类型,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61131730/

31 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com