gpt4 book ai didi

python - 如何优化大型数据集的标签编码(sci-kit learn)

转载 作者:行者123 更新时间:2023-11-30 09:35:06 25 4
gpt4 key购买 nike

我正在使用 sci-kit learn 的标签编码类将字符串列表列表编码为整数代码。即

[[a,b,c],[b,c,d],[c,f,z]...,[a,v,z]]]

LabelEncoder 已实例化并适合标签名称。我想做的是迭代列表列表并转换每个列表。

我的第一个解决方案是强力迭代列表。

for list in list_of_lists:
label_encoder.transform(list)

当规模扩大到数万时,它变得非常慢。

我尝试将列表列表转换为 Pandas 数据帧,并将 Pandas 中的 .map 方法应用到数据集,但仍然很慢。

有什么方法可以优化标签编码器的转换吗?我不知道为什么这么慢。

最佳答案

您可以尝试使用纯 numpy,而不是使用 scikit-learn 循环,我确信这会更快。

如果内部列表中的元素数量始终相同(3?),那么您可以尝试以下操作:

<强>1。准备一些数据:

n=5
xs = np.random.choice(list("qwertyuiopasdfghjklzxcvbnm"),3*n).reshape((-1,3))
xs
array([['z', 'h', 'd'],
['g', 'k', 'y'],
['t', 'c', 'o'],
['f', 'b', 's'],
['x', 'n', 'z']],
dtype='<U1')

<强>2。编码

np.unique(xs, return_inverse=True)[1].reshape((-1,3))
array([[13, 5, 2],
[ 4, 6, 12],
[10, 1, 8],
[ 3, 0, 9],
[11, 7, 13]])

<强>3。时机

n = 1000000
xs = np.random.choice(list("qwertyuiopasdfghjklzxcvbnm"),3*n).reshape((-1,3))

%timeit np.unique(xs, return_inverse=True)[1].reshape((-1,3))
849 ms ± 39.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

不到一秒...

如果您可以显示您的完整代码,我们可以比较运行时间。

编辑:通过编码来回移动

由于 @JCDJulian 的评论(见下文),问题略有变化,我添加了代码片段,以在字典的帮助下在数据处理的任何点显示编码/解码:

首先,如果您想编码,您需要dic:

labels = np.unique(xs, return_inverse=True)[1]
dic = dict(zip(xs.flatten(),labels))

编码过程本身是:

ys = np.reshape([dic[v] for list in xs for v in list], (-1,3))
ys
array([[13, 5, 2],
[ 4, 6, 12],
[10, 1, 8],
[ 3, 0, 9],
[11, 7, 13]])

为了解码,您需要reverse_dic:

reverse_dic = dict(zip(labels, xs.flatten()))
np.reshape([reverse_dic[v] for list in ys for v in list], (-1,3))
array([['z', 'h', 'd'],
['g', 'k', 'y'],
['t', 'c', 'o'],
['f', 'b', 's'],
['x', 'n', 'z']],
dtype='<U1')

编辑2:随机形状数组

为了完整起见,随机形状数组的解决方案

编码:

labels = np.unique(xs, return_inverse=True)[1]
dic = dict(zip(xs.flatten(),labels))
np.vectorize(dic.get)(xs)
array([[13, 5, 2],
[ 4, 6, 12],
[10, 1, 8],
[ 3, 0, 9],
[11, 7, 13]])

解码:

reverse_dic = dict(zip(labels, xs.flatten()))
np.vectorize(reverse_dic.get)(ys)
array([['z', 'h', 'd'],
['g', 'k', 'y'],
['t', 'c', 'o'],
['f', 'b', 's'],
['x', 'n', 'z']],
dtype='<U1')

请注意,数组的形状不会出现在代码中的任何地方!

关于python - 如何优化大型数据集的标签编码(sci-kit learn),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45321999/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com