
tensorflow - Spark Dataframe to Tensorflow Dataset (tf.data API)


I am trying to convert a Spark DataFrame into TensorFlow records and then read them back from TensorFlow as a dataset to feed my model. It is not working.

My attempt is as follows:

1) Get a SparkSession with the spark-tensorflow-connector library jar:

from pyspark import SparkConf
from pyspark.sql import SparkSession

spark = SparkSession.builder.config(conf=SparkConf().set("spark.jars", "path/to/spark-tensorflow-connector_2.11-1.6.0.jar")).getOrCreate()
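
An alternative (a sketch, assuming the machine can reach Maven Central and that org.tensorflow:spark-tensorflow-connector_2.11:1.6.0 is the artifact matching your Scala and Spark versions) is to let Spark resolve the connector itself:

from pyspark.sql import SparkSession

# Let Spark download the connector from Maven Central instead of pointing at a local jar.
# The coordinates here are an assumption; pick the artifact matching your Scala version.
spark = (SparkSession.builder
         .config("spark.jars.packages",
                 "org.tensorflow:spark-tensorflow-connector_2.11:1.6.0")
         .getOrCreate())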

2) Save the DataFrame as a TFRecord (I use a toy dataset here as an example):
df = spark.createDataFrame([(1, 120), (2, 130), (2, 140)], ['A', 'B'])

path='path/example.tfrecord'
df.write.format("tfrecords").mode("overwrite").option("recordType", "Example").save(path)
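
As a quick sanity check (a sketch assuming a local filesystem path), you can list what Spark actually wrote; besides the part files it also emits an empty _SUCCESS marker, which becomes relevant below:

import tensorflow as tf

# List everything Spark wrote into the output directory.
# Expect one or more part files plus an empty _SUCCESS marker file.
print(tf.io.gfile.glob('path/example.tfrecord/*'))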

3) Load the TFRecord files into the tf.data API (for simplicity I only take 'A' as a feature):
path2 = "path/example.tfrecord/*"
dataset=tf.data.TFRecordDataset(tf.compat.v1.gfile.Glob(path2))

def parse_func(buff):
    features = {'A': tf.compat.v1.FixedLenFeature(shape=[5], dtype=tf.int64)}
    tensor_dict = tf.compat.v1.parse_single_example(buff, features)
    return tensor_dict['A']

train_dataset = dataset.map(parse_func).batch(1)

But when I try to print the elements of the dataset:
for x in train_dataset:
    print(x)

I get the following error:
2020-05-21 06:43:53.579843: W tensorflow/core/framework/op_kernel.cc:1655] OP_REQUIRES failed at iterator_ops.cc:941 : Data loss: corrupted record at 0
Traceback (most recent call last):
File "/home/patrizio/PycharmProjects/pyspark-config/venv/lib/python3.6/site-packages/tensorflow_core/python/eager/context.py", line 1897, in execution_mode
2020-05-21 06:43:53.580090: W tensorflow/core/framework/op_kernel.cc:1655] OP_REQUIRES failed at example_parsing_ops.cc:93 : Invalid argument: Key: A. Can't parse serialized Example.
2020-05-21 06:43:53.580567: W tensorflow/core/framework/op_kernel.cc:1655] OP_REQUIRES failed at example_parsing_ops.cc:93 : Invalid argument: Key: A. Can't parse serialized Example.
yield
File "/home/patrizio/PycharmProjects/pyspark-config/venv/lib/python3.6/site-packages/tensorflow_core/python/data/ops/iterator_ops.py", line 659, in _next_internal
output_shapes=self._flat_output_shapes)
File "/home/patrizio/PycharmProjects/pyspark-config/venv/lib/python3.6/site-packages/tensorflow_core/python/ops/gen_dataset_ops.py", line 2479, in iterator_get_next_sync
_ops.raise_from_not_ok_status(e, name)
File "/home/patrizio/PycharmProjects/pyspark-config/venv/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py", line 6606, in raise_from_not_ok_status
six.raise_from(core._status_to_exception(e.code, message), None)
File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.DataLossError: corrupted record at 0 [Op:IteratorGetNextSync]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/snap/pycharm-community/194/plugins/python-ce/helpers/pycharm/_jb_unittest_runner.py", line 35, in <module>
sys.exit(main(argv=args, module=None, testRunner=unittestpy.TeamcityTestRunner, buffer=not JB_DISABLE_BUFFERING))
File "/usr/lib/python3.6/unittest/main.py", line 94, in __init__
self.parseArgs(argv)
File "/usr/lib/python3.6/unittest/main.py", line 141, in parseArgs
self.createTests()
File "/usr/lib/python3.6/unittest/main.py", line 148, in createTests
self.module)
File "/usr/lib/python3.6/unittest/loader.py", line 219, in loadTestsFromNames
suites = [self.loadTestsFromName(name, module) for name in names]
File "/usr/lib/python3.6/unittest/loader.py", line 219, in <listcomp>
suites = [self.loadTestsFromName(name, module) for name in names]
File "/usr/lib/python3.6/unittest/loader.py", line 204, in loadTestsFromName
test = obj()
File "/home/patrizio/PycharmProjects/pyspark-config/tests/python/output/test_output.py", line 75, in test_TFRecord_new
for x in train_dataset:
File "/home/patrizio/PycharmProjects/pyspark-config/venv/lib/python3.6/site-packages/tensorflow_core/python/data/ops/iterator_ops.py", line 630, in __next__
return self.next()
File "/home/patrizio/PycharmProjects/pyspark-config/venv/lib/python3.6/site-packages/tensorflow_core/python/data/ops/iterator_ops.py", line 674, in next
return self._next_internal()
File "/home/patrizio/PycharmProjects/pyspark-config/venv/lib/python3.6/site-packages/tensorflow_core/python/data/ops/iterator_ops.py", line 665, in _next_internal
return structure.from_compatible_tensor_list(self._element_spec, ret)
File "/usr/lib/python3.6/contextlib.py", line 99, in __exit__
self.gen.throw(type, value, traceback)
File "/home/patrizio/PycharmProjects/pyspark-config/venv/lib/python3.6/site-packages/tensorflow_core/python/eager/context.py", line 1900, in execution_mode
executor_new.wait()
File "/home/patrizio/PycharmProjects/pyspark-config/venv/lib/python3.6/site-packages/tensorflow_core/python/eager/executor.py", line 67, in wait
pywrap_tensorflow.TFE_ExecutorWaitForAllPendingNodes(self._handle)
tensorflow.python.framework.errors_impl.DataLossError: corrupted record at 0

Does anyone know how to deal with this?

Many thanks in advance.

Best Answer

I hope this is still relevant.
Your glob expression is incorrect. Spark must have created a _SUCCESS file when saving the examples to TFRecord. Include the extension in the pattern:

path2 = "path/example.tfrecord/*.tfrecord"
You can also check the list of files Python is going to read by simply evaluating:
tf.io.gfile.glob(path2)
I would use the current API instead of the old compat.v1. The shape in tf.io.FixedLenFeature is also wrong: each value is a scalar, not a vector of length 5, so the correct shape is [].
def parse_func(buff):
    features = {'A': tf.io.FixedLenFeature(shape=[], dtype=tf.int64)}
    tensor_dict = tf.io.parse_single_example(buff, features)
    return tensor_dict

train_dataset = dataset.map(parse_func).batch(3)
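
Iterating the fixed pipeline should now print batched scalar tensors; as an illustration (the record order may vary with Spark's partitioning), something like:

for x in train_dataset:
    print(x['A'])
# e.g. tf.Tensor([1 2 2], shape=(3,), dtype=int64)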
If you really want to get fancy, use tf.io.parse_example instead; it is better because it performs vectorized parsing. However, you then need to batch before parsing:
def parse_func(buff):
    features = {'A': tf.io.FixedLenFeature(shape=[], dtype=tf.int64)}
    tensor_dict = tf.io.parse_example(buff, features)
    return tensor_dict

train_dataset = dataset.batch(3).map(parse_func)

One might see performance advantages by batching Example protos with parse_example instead of using this function directly. (source)
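
Putting the pieces together, here is a minimal end-to-end sketch (assuming the records were written as above and that the part files match the *.tfrecord pattern):

import tensorflow as tf

# Glob only the actual record files, skipping the _SUCCESS marker.
files = tf.io.gfile.glob("path/example.tfrecord/*.tfrecord")
dataset = tf.data.TFRecordDataset(files)

def parse_func(buff):
    # Each value of column 'A' was written as a single int64 scalar.
    features = {'A': tf.io.FixedLenFeature(shape=[], dtype=tf.int64)}
    return tf.io.parse_example(buff, features)

# Batch first so that parse_example parses a whole batch at once.
train_dataset = dataset.batch(3).map(parse_func)

for x in train_dataset:
    print(x['A'])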

About tensorflow - Spark Dataframe to Tensorflow Dataset (tf.data API): we found a similar question on Stack Overflow: https://stackoverflow.com/questions/61927639/
