
python - How to run a Python user-defined function on the partitions of an RDD using mapPartitions?


I am trying to run a Python UDF on the partitions of an RDD. This is how I create the RDD:

import numpy as np

# sc is an existing SparkContext (e.g. the one provided by pyspark)
text_file = open("/home/zeinab/Desktop/inputFile.txt", "r")
lines = text_file.read().strip().split("\n")
linestofloat = []
for l in lines:
    linestofloat.append(float(l))
linestofloat = np.array(linestofloat)
data = sc.parallelize(linestofloat)
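
For reference, the same RDD can also be built without reading the whole file on the driver first, by letting Spark parse the lines itself. A minimal sketch, assuming the same file path:

# Sketch: let Spark read and parse the file directly instead of
# loading all lines on the driver (same path as above assumed).
data = (sc.textFile("/home/zeinab/Desktop/inputFile.txt")
        .map(lambda line: line.strip())
        .filter(lambda line: line)   # drop empty lines
        .map(float))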

The input text file looks like this:

26.000000
-8.000000
-28.000000
-6.000000
-18.000000
...

The function I am trying to run with mapPartitions is the following:

def classic_sta_lta_py(a, nsta, nlta):
    """
    Computes the standard STA/LTA from a given input array a. The length of
    the STA is given by nsta in samples, respectively is the length of the
    LTA given by nlta in samples. Written in Python.

    .. note::

        There exists a faster version of this trigger wrapped in C
        called :func:`~obspy.signal.trigger.classic_sta_lta` in this module!

    :type a: NumPy :class:`~numpy.ndarray`
    :param a: Seismic Trace
    :type nsta: int
    :param nsta: Length of short time average window in samples
    :type nlta: int
    :param nlta: Length of long time average window in samples
    :rtype: NumPy :class:`~numpy.ndarray`
    :return: Characteristic function of classic STA/LTA
    """
    # The cumulative sum can be exploited to calculate a moving average (the
    # cumsum function is quite efficient)
    print("Hello!!!")
    #a = [x for x in floatelems.toLocalIterator()]
    #a = np.array(a)
    print("a array is: {} ".format(a))
    sta = np.cumsum(a ** 2)
    #print("{}. sta array is: ".format(sta))

    # Convert to float
    sta = np.require(sta, dtype=np.float)

    # Copy for LTA
    lta = sta.copy()

    # Compute the STA and the LTA
    sta[nsta:] = sta[nsta:] - sta[:-nsta]
    sta /= nsta
    lta[nlta:] = lta[nlta:] - lta[:-nlta]
    lta /= nlta

    # Pad zeros
    sta[:nlta - 1] = 0

    # Avoid division by zero by setting zero values to tiny float
    dtiny = np.finfo(0.0).tiny
    idx = lta < dtiny
    lta[idx] = dtiny

    return sta / lta

But I keep getting an error when I run the following line:

stalta_ratio = data.mapPartitions(lambda i: classic_sta_lta_py(i, 2, 30))

The error:

TypeError: unsupported operand type(s) for ** or pow(): 'itertools.chain' and 'int'

at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:298)
at org.apache.spark.api.python.PythonRunner$$anon$1.read(PythonRunner.scala:438)
at org.apache.spark.api.python.PythonRunner$$anon$1.read(PythonRunner.scala:421)
at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:252)
at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
at scala.collection.Iterator$class.foreach(Iterator.scala:893)
at org.apache.spark.InterruptibleIterator.foreach(InterruptibleIterator.scala:28)
at scala.collection.generic.Growable$class.$plus$plus$eq(Growable.scala:59)
at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:104)
at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:48)
at scala.collection.TraversableOnce$class.to(TraversableOnce.scala:310)
at org.apache.spark.InterruptibleIterator.to(InterruptibleIterator.scala:28)
at scala.collection.TraversableOnce$class.toBuffer(TraversableOnce.scala:302)
at org.apache.spark.InterruptibleIterator.toBuffer(InterruptibleIterator.scala:28)
at scala.collection.TraversableOnce$class.toArray(TraversableOnce.scala:289)
at org.apache.spark.InterruptibleIterator.toArray(InterruptibleIterator.scala:28)
at org.apache.spark.rdd.RDD$$anonfun$collect$1$$anonfun$12.apply(RDD.scala:939)
at org.apache.spark.rdd.RDD$$anonfun$collect$1$$anonfun$12.apply(RDD.scala:939)
at org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:2074)
at org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:2074)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
at org.apache.spark.scheduler.Task.run(Task.scala:109)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
... 1 more

Does anyone know what I am doing wrong?

Thank you.

Best Answer

The argument your lambda receives inside mapPartitions is an iterator, but according to your function's documentation it expects a numpy.ndarray. If your dataset is small enough to be handled by a single executor, you can convert it easily. Try this:

data.mapPartitions(
    lambda i: classic_sta_lta_py(np.array(list(i)), 2, 30)
)
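
Put together, an end-to-end version of the call might look like the sketch below. np.fromiter is used here only as a slightly more direct way to turn the partition iterator into a float array (an assumption on my part, any equivalent conversion works), and take(10) just materializes a few results on the driver:

# Sketch: convert each partition's iterator to a NumPy array before
# calling the UDF, then pull a few results back to the driver.
stalta_ratio = data.mapPartitions(
    lambda i: classic_sta_lta_py(np.fromiter(i, dtype=float), 2, 30)
)
print(stalta_ratio.take(10))

Note that mapPartitions applies the function to each partition independently, so the STA/LTA windows are computed per partition and never span partition boundaries.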

Regarding "python - How to run a Python user-defined function on the partitions of an RDD using mapPartitions?", a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/51314144/
