python - Tensorflow: Custom data reader using py_func

Reposted. Author: 行者123. Updated: 2023-11-30 22:41:56

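I am trying to queue data from hdf5 files. TensorFlow does not support hdf5, so I wrote a Python function that reads examples from an hdf5 file and raises tf.errors.OutOfRangeError when it reaches the end of the file. I then wrap this Python function with tf.py_func and use it as the enqueue op of a queue.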

Here is my code:

import h5py
import tensorflow as tf
from tensorflow.python.framework import errors
import numpy as np

def read_from_hdf5(hdf5_file, batch_size):
    # Open read-only: the reader never writes to the file
    h5py_handle = h5py.File(hdf5_file, 'r')

    # Check shapes from the hdf5 file so that we can set the tensor shapes
    feature_shape = h5py_handle['features'].shape[1:]
    label_shape = h5py_handle['labels'].shape[1:]

    # Generator that produces examples for training. It will be wrapped by tf.py_func to simulate a reader
    def example_generator(h5py_handle):
        for i in xrange(0, h5py_handle['features'].shape[0]-batch_size+1, batch_size):
            # Cast so the returned arrays match the declared Tout dtypes (tf.float32)
            features = h5py_handle['features'][i:i+batch_size].astype(np.float32)
            labels = h5py_handle['labels'][i:i+batch_size].astype(np.float32)
            yield [features, labels]
        raise errors.OutOfRangeError(node_def=None, op=None, message='completed all examples in %s'%hdf5_file)

    [features_tensor, labels_tensor] = tf.py_func(
        example_generator(h5py_handle).next,
        [],
        [tf.float32, tf.float32],
        stateful=True)

    # Set the shape so that we can infer sizes etc in later layers.
    features_tensor.set_shape([batch_size, feature_shape[0], feature_shape[1], feature_shape[2]])
    labels_tensor.set_shape([batch_size, label_shape[0]])

    return features_tensor, labels_tensor


def load_data_from_filename_list(hdf5_files, batch_size, shuffle_seed=0):
    example_list = [read_from_hdf5(hdf5_file, batch_size) for hdf5_file in hdf5_files]
    min_after_dequeue = 10000
    # min_after_dequeue + (num_threads + a small safety margin) * batch_size
    capacity = min_after_dequeue + (len(example_list)+1) * batch_size
    features, labels = tf.train.shuffle_batch_join(example_list, batch_size, capacity=capacity, min_after_dequeue=min_after_dequeue, seed=shuffle_seed, enqueue_many=True)
    return features, labels
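The generator logic can be checked on its own, without TensorFlow or h5py. The plain-Python sketch below makes two assumptions: a list stands in for the HDF5 datasets, and EndOfData is a hypothetical stand-in for tf.errors.OutOfRangeError. It shows two things. First, the stride drops a trailing partial batch: with 7 items and a batch size of 3, item 6 is never read. Second, once the generator has raised, every further next() call raises StopIteration instead.

```python
class EndOfData(Exception):
    """Hypothetical stand-in for tf.errors.OutOfRangeError."""


def example_generator(data, batch_size):
    # Same stride as in the question: only full batches are yielded,
    # so a trailing partial batch is silently dropped.
    for i in range(0, len(data) - batch_size + 1, batch_size):
        yield data[i:i + batch_size]
    raise EndOfData("completed all examples")


def drain(data, batch_size, calls=5):
    """Call next() repeatedly, as repeated runs of a stateful py_func would."""
    gen = example_generator(data, batch_size)
    batches, raised = [], []
    for _ in range(calls):
        try:
            batches.append(next(gen))
        except (EndOfData, StopIteration) as exc:
            raised.append(type(exc).__name__)
    return batches, raised


batches, raised = drain(list(range(7)), batch_size=3)
print(batches)  # [[0, 1, 2], [3, 4, 5]]
print(raised)   # ['EndOfData', 'StopIteration', 'StopIteration']
```

In the question's setup, each run of the enqueue op calls the generator's next method once. The call made after the last full batch is therefore the one that raises the end-of-data error seen in the traceback.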

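I expected the tf.errors.OutOfRangeError to be handled by the QueueRunner. Instead, I get the following error and the program crashes. Is it possible to read data like this from a py_func, and if so, what am I doing wrong? If not, what approach should I use?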

Traceback (most recent call last):
  File "/users/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/script_ops.py", line 85, in __call__
    ret = func(*args)
  File "build/bdist.linux-x86_64/egg/tronn/datalayer.py", line 27, in example_generator
    raise errors.OutOfRangeError(node_def=None, op=None, message='completed all examples in %s'%hdf5_file)
tensorflow.python.framework.errors_impl.OutOfRangeError: completed all examples
W tensorflow/core/framework/op_kernel.cc:993] Internal: Failed to run py callback pyfunc_13: see error log.

Best answer

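It looks like py_func does not support this kind of exception handling. An exception raised inside the Python function is not passed on to the runtime as the corresponding TensorFlow error.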

Consider this code from py_func.cc:

// Invokes the trampoline.
PyObject* result = PyEval_CallObject(trampoline, args);
Py_DECREF(args);
if (result == nullptr) {
  if (PyErr_Occurred()) {
    // TODO(zhifengc): Consider pretty-print error using LOG(STDERR).
    PyErr_Print();
  }
  return errors::Internal("Failed to run py callback ", call->token,
                          ": see error log.");
}

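When the Python function raises an exception, PyEval_CallObject returns nullptr and the Python error indicator (checked by PyErr_Occurred) is set. The traceback is printed, and the op fails with the generic Internal error "Failed to run py callback", whatever the type of the original exception was.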
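The quoted branch can be modelled in a few lines of Python to show why the exception type is lost. This is a simplified sketch, not TensorFlow code. The names run_py_callback and EndOfData, and the (code, message, value) tuples, are all hypothetical. Any exception from the callback is printed, as PyErr_Print does, and then replaced by a generic INTERNAL status.

```python
import traceback


class EndOfData(Exception):
    """Hypothetical stand-in for tf.errors.OutOfRangeError."""


def run_py_callback(func, token):
    """Toy model of the quoted py_func.cc branch (hypothetical helper).

    Returns a (code, message, value) tuple in place of a C++ Status.
    """
    try:
        return ("OK", "", func())
    except Exception:
        # Like PyErr_Print(): the Python traceback only goes to stderr.
        traceback.print_exc()
        # Whatever the exception type was, the caller only sees INTERNAL.
        message = "Failed to run py callback %s: see error log." % token
        return ("INTERNAL", message, None)


def reader():
    raise EndOfData("completed all examples")


status = run_py_callback(reader, "pyfunc_13")
print(status[:2])
# ('INTERNAL', 'Failed to run py callback pyfunc_13: see error log.')
```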

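py_func is an odd creature because its body runs in the Python client environment. Normally, when an op (such as a reader op) fails, the TF runtime returns a not-OK status to the Python client. The client then converts that status into a Python exception in raise_exception_on_not_ok_status (in client.py, session.run). Because the py_func body runs inside the Python client, TensorFlow would have to be modified to handle PyErr_Occurred and turn the Python exception into a matching bad status in the TensorFlow runtime.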
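A fix along these lines would need a round trip. The Python exception is mapped to a matching status code, and the client turns that status back into the matching exception. The plain-Python sketch below models that round trip. The names to_status, run_op and enqueue_loop, and the exception classes, are hypothetical stand-ins, not TensorFlow APIs.

enqueue_loop imitates a QueueRunner thread. By default, a QueueRunner treats tf.errors.OutOfRangeError raised during an enqueue as a closed queue. In this sketch, any other error propagates to the caller.

```python
class OutOfRange(Exception):
    """Hypothetical stand-in for tf.errors.OutOfRangeError."""


class InternalError(Exception):
    """Hypothetical stand-in for tf.errors.InternalError."""


def to_status(exc, token):
    """Runtime side: turn a Python exception into a (code, message) status."""
    # Proposed change: keep the meaning of end-of-data exceptions ...
    if isinstance(exc, (OutOfRange, StopIteration)):
        return ("OUT_OF_RANGE", str(exc))
    # ... and keep the generic error for everything else.
    return ("INTERNAL", "Failed to run py callback %s: see error log." % token)


# Client side: map a non-OK status code back to a Python exception class.
STATUS_TO_EXCEPTION = {"OUT_OF_RANGE": OutOfRange, "INTERNAL": InternalError}


def run_op(callback, token):
    """Run the py_func body and re-raise any failure as a typed exception."""
    try:
        return callback()
    except Exception as exc:
        code, message = to_status(exc, token)
    raise STATUS_TO_EXCEPTION[code](message)


def enqueue_loop(callback, token="pyfunc_0"):
    """Simplified QueueRunner thread: OUT_OF_RANGE means the input is exhausted."""
    queue = []
    while True:
        try:
            queue.append(run_op(callback, token))
        except OutOfRange:
            return queue  # a real QueueRunner would close the queue here
        # Any other exception (e.g. InternalError) propagates to the caller.


batches = iter([[0, 1, 2], [3, 4, 5]])
print(enqueue_loop(lambda: next(batches)))  # [[0, 1, 2], [3, 4, 5]]
```

Under these assumptions, the end-of-data error ends the loop cleanly. With the behaviour quoted from py_func.cc, the same situation would instead surface as an InternalError.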

Regarding python - Tensorflow: Custom data reader using py_func, a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/42322176/
