
python - How to skip erroneous elements at the IO level in Apache Beam with Dataflow?


I am doing some analysis on tfrecords stored in GCP, but some of the tfrecords inside the files are corrupted, so when I run my pipeline and hit more than four errors, it breaks due to this. I believe this is a constraint of the DataflowRunner rather than of Beam itself.

Here is my processing script:

import argparse
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.metrics.metric import Metrics

from apache_beam.runners.direct import direct_runner
import tensorflow as tf

input_ = "path_to_bucket"


def _parse_example(serialized_example):
    """Return inputs and targets Tensors from a serialized tf.Example."""
    data_fields = {
        "inputs": tf.io.VarLenFeature(tf.int64),
        "targets": tf.io.VarLenFeature(tf.int64)
    }
    parsed = tf.io.parse_single_example(serialized_example, data_fields)
    inputs = tf.sparse.to_dense(parsed["inputs"])
    targets = tf.sparse.to_dense(parsed["targets"])
    return inputs, targets


class MyFnDo(beam.DoFn):

    def __init__(self):
        beam.DoFn.__init__(self)
        self.input_tokens = Metrics.distribution(self.__class__, 'input_tokens')
        self.output_tokens = Metrics.distribution(self.__class__, 'output_tokens')
        self.num_examples = Metrics.counter(self.__class__, 'num_examples')
        self.decode_errors = Metrics.counter(self.__class__, 'decode_errors')

    def process(self, element):
        # inputs = element.features.feature['inputs'].int64_list.value
        # outputs = element.features.feature['outputs'].int64_list.value
        try:
            inputs, outputs = _parse_example(element)
            self.input_tokens.update(len(inputs))
            self.output_tokens.update(len(outputs))
            self.num_examples.inc()
        except Exception:
            self.decode_errors.inc()



def main(argv):
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', dest='input', default=input_, help='input tfrecords')
    # parser.add_argument('--output', dest='output', default='gs://', help='output file')

    known_args, pipeline_args = parser.parse_known_args(argv)
    pipeline_options = PipelineOptions(pipeline_args)

    with beam.Pipeline(options=pipeline_options) as p:
        tfrecords = p | "Read TFRecords" >> beam.io.ReadFromTFRecord(
            known_args.input,
            coder=beam.coders.ProtoCoder(tf.train.Example))
        tfrecords | "count mean" >> beam.ParDo(MyFnDo())


if __name__ == '__main__':
    main(None)

So basically, how can I skip the corrupted tfrecords during my analysis, and log how many of them there are?

Best Answer

There was a conceptual problem: beam.io.ReadFromTFRecord reads a single tfrecord (which may be sharded across multiple files), whereas I was passing it a list of many individual tfrecords, and that is what caused the error. Switching from ReadFromTFRecord to ReadAllFromTFRecord solved my problem.

from apache_beam.io.tfrecordio import ReadAllFromTFRecord

p = beam.Pipeline(runner=direct_runner.DirectRunner())
tfrecords = (p
             | beam.Create(tf.io.gfile.glob(input_))
             | ReadAllFromTFRecord(coder=beam.coders.ProtoCoder(tf.train.Example)))
tfrecords | beam.ParDo(MyFnDo())
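
As a side note on the original "skip and count" part of the question, a common Beam pattern for this is a dead-letter output: catch the decode failure in a DoFn and route the bad element to a tagged side output, so the main output only carries good records. The sketch below is an illustration of that pattern under stated assumptions, not part of the accepted answer; the ParseOrTag DoFn is a hypothetical name, and it reuses input_ and MyFnDo from the question's script.

import apache_beam as beam
import tensorflow as tf


class ParseOrTag(beam.DoFn):
    """Validate raw tfrecord bytes; route corrupted ones to a 'corrupted' output."""

    def process(self, element):
        try:
            # element is the raw record bytes read by ReadAllFromTFRecord;
            # FromString raises if the proto is corrupted.
            tf.train.Example.FromString(element)
            yield element  # pass the validated raw bytes downstream unchanged
        except Exception:
            # Instead of failing the bundle, tag the bad record and move on.
            yield beam.pvalue.TaggedOutput('corrupted', element)


with beam.Pipeline() as p:
    raw = (p
           | beam.Create(tf.io.gfile.glob(input_))
           | beam.io.ReadAllFromTFRecord())  # default coder yields raw bytes
    outputs = raw | beam.ParDo(ParseOrTag()).with_outputs(
        'corrupted', main='parsed')
    # Good records continue into the analysis; bad ones are only counted here.
    outputs.parsed | "count mean" >> beam.ParDo(MyFnDo())
    (outputs.corrupted
     | beam.combiners.Count.Globally()
     | beam.Map(lambda n: print('corrupted records:', n)))

This keeps corrupted records out of the main analysis while still making them observable; instead of printing the count, they could also be written out with WriteToTFRecord for later inspection.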

Regarding "python - How to skip erroneous elements at the IO level in Apache Beam with Dataflow?", a similar question was found on Stack Overflow: https://stackoverflow.com/questions/60396249/
