python - InvalidArgumentError when using tensorflow.data.Dataset.from_generator

Reposted. Author: 太空宇宙. Updated: 2023-11-04 04:45:43

I am trying to generate my own image dataset so that I can measure inference performance with the TensorFlow Dataset API on a single GPU:

import random

import tensorflow as tf  # TensorFlow 1.x API

resolutions = [
    (2048, 1080)  # (width, height)
]

def generate_image(size, channels):
    # Build a constant-valued image tensor of shape [1, height, width, channels].
    image_value = random.random()
    image_shape = [1, size[1], size[0], channels]
    return tf.constant(
        value=image_value,
        shape=image_shape,
        dtype=tf.float32)

def generate_single_input(size):
    source = generate_image(size, 3)
    target = generate_image(size, 3)
    return source, target

def input_generator_fn():
    # Yield 10 (source, target) pairs per resolution.
    for res in resolutions:
        for i in range(10):
            yield generate_single_input(res)


def benchmark():
    ...
    ds = tf.data.Dataset.from_generator(
        generator=input_generator_fn,
        output_types=(tf.float32, tf.float32),
        output_shapes=(tf.TensorShape([1, 1080, 2048, 3]),
                       tf.TensorShape([1, 1080, 2048, 3])))
    iterator = ds.make_one_shot_iterator()
    next_record = iterator.get_next()

    inputs = next_record[0]
    outputs = next_record[1]

    predictions = {
        'input_images': inputs,
        'output_images': outputs
    }
    session = tf.Session()
    with session:
        session.run(tf.global_variables_initializer())
        for res in resolutions:
            for i in range(10):
                session.run(predictions)
    ...

But when I run it, I get the following exception:

2018-04-06 13:38:44.050448: W tensorflow/core/framework/op_kernel.cc:1198] Invalid argument: ValueError: setting an array element with a sequence.

2018-04-06 13:38:44.050581: W tensorflow/core/framework/op_kernel.cc:1198] Invalid argument: ValueError: setting an array element with a sequence.
[[Node: PyFunc = PyFunc[Tin=[DT_INT64], Tout=[DT_FLOAT, DT_FLOAT], token="pyfunc_1"](arg0)]]

Traceback (most recent call last):
File "tensorflow/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1350, in _do_call
return fn(*args)

File "tensorflow/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1329, in _run_fn
status, run_metadata)

File "tensorflow/lib/python3.5/site-packages/tensorflow/python/framework/errors_impl.py", line 473, in __exit__
c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: ValueError: setting an array element with a sequence.
[[Node: PyFunc = PyFunc[Tin=[DT_INT64], Tout=[DT_FLOAT, DT_FLOAT], token="pyfunc_1"](arg0)]]
[[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[1,1080,2048,3], [1,1080,2048,3]], output_types=[DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator)]]

Best answer

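In short, the reason is that from_generator can flatten (convert) yielded NumPy arrays into its declared output types, but it cannot do the same for tf.Tensor objects.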

Here is a shorter piece of code that reproduces the error:

import tensorflow as tf
import numpy as np

print(tf.__version__)

def g():
    img = tf.random_uniform([3])  # a float32 tf.Tensor, not a NumPy array
    # img = np.random.rand(3)
    # img = tf.convert_to_tensor(img)
    yield img

dataset = tf.data.Dataset.from_generator(g, tf.float64, tf.TensorShape([3]))
iterator = dataset.make_one_shot_iterator()
next_iterator = iterator.get_next()

sess = tf.Session()
sess.run(next_iterator)  # raises InvalidArgumentError

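The error message in version 1.14 is quite helpful. (The exact code lines differ between versions, but I checked that the cause is the same in 1.12 and 1.13, which I used.)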

InvalidArgumentError: TypeError: `generator` yielded an element that could not be converted to the expected type. The expected type was float64, but the yielded element was Tensor("random_uniform:0", shape=(3,), dtype=float32).
Traceback (most recent call last):

File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 530, in generator_py_func
ret, dtype=dtype.as_numpy_dtype))

File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/script_ops.py", line 169, in _convert
result = np.asarray(value, dtype=dtype, order="C")

File "/usr/local/lib/python3.6/dist-packages/numpy/core/numeric.py", line 538, in asarray
return array(a, dtype, copy=False, order=order)

ValueError: setting an array element with a sequence.

When the yielded element is a Tensor, from_generator tries to flatten it into output_types, and that conversion does not work.
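The conversion step is visible in the traceback above: `generator_py_func` passes each yielded value through `_convert`, which calls `np.asarray(value, dtype=dtype, order="C")`. Below is a minimal NumPy-only sketch of just that one call; the helper name `convert_like_from_generator` is hypothetical and this is not TensorFlow's actual code. It shows that a NumPy array (even a float32 one) converts cleanly to a float64 output type, while a value NumPy cannot turn into a regular numeric array raises the same `ValueError` text seen in the logs. A ragged list is used only because it triggers that message without TensorFlow installed.

```python
import numpy as np


def convert_like_from_generator(value, dtype):
    """Hypothetical helper mirroring the np.asarray call shown in the traceback."""
    return np.asarray(value, dtype=dtype, order="C")


# A NumPy array converts cleanly to the declared output type.
ok = convert_like_from_generator(np.random.rand(3).astype(np.float32), np.float64)
print(ok.dtype, ok.shape)  # float64 (3,)

# A value that is not a regular numeric array fails inside np.asarray itself.
try:
    convert_like_from_generator([[1.0, 2.0], [3.0]], np.float64)
except ValueError as err:
    print("ValueError:", err)
```

The exact wording after "setting an array element with a sequence." varies between NumPy versions.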

To fix this, simply do not use from_generator when your generator yields tensors. Use from_tensors or from_tensor_slices instead:

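import tensorflow as tf

img = tf.random_uniform([3])

# from_tensors wraps the tensor directly; no Python generator is involved.
dataset = tf.data.Dataset.from_tensors(img).repeat()
iterator = dataset.make_initializable_iterator()
next_iterator = iterator.get_next()

sess = tf.Session()
sess.run(iterator.initializer)  # the iterator must be initialized before get_next()
sess.run(next_iterator)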
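Going by the short explanation at the top of this answer, another option is to keep from_generator but have the generator yield NumPy arrays instead of tf.Tensor objects. The sketch below adapts the question's generator functions that way; the `_np` function names are hypothetical, and it has not been checked against a specific TensorFlow version. It only builds the data, so it runs without TensorFlow.

```python
import random

import numpy as np

resolutions = [(2048, 1080)]  # (width, height), as in the question


def generate_image_np(size, channels):
    """Return a constant-valued float32 array of shape (1, height, width, channels)."""
    image_value = random.random()
    image_shape = (1, size[1], size[0], channels)
    return np.full(image_shape, image_value, dtype=np.float32)


def generate_single_input_np(size):
    source = generate_image_np(size, 3)
    target = generate_image_np(size, 3)
    return source, target


def input_generator_np():
    # Yield 10 (source, target) pairs per resolution, as NumPy arrays.
    for res in resolutions:
        for _ in range(10):
            yield generate_single_input_np(res)
```

The generator would then be passed to from_generator exactly as in the question, i.e. with output_types=(tf.float32, tf.float32) and the two [1, 1080, 2048, 3] shapes, since the yielded float32 arrays already match those declared types.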

Regarding "python - InvalidArgumentError when using tensorflow.data.Dataset.from_generator", a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/49691315/
