gpt4 book ai didi

python - 从 Spark (pyspark) 管道内的 StringIndexer 阶段获取标签

转载 作者:太空狗 更新时间:2023-10-30 00:59:02 27 4
gpt4 key购买 nike

我正在使用 Sparkpyspark 并且我有一个 pipeline 设置了一堆 StringIndexer 对象,我用它来将字符串列编码为索引列:

indexers = [StringIndexer(inputCol=column, outputCol=column + '_index').setHandleInvalid('skip')
for column in list(set(data_frame.columns) - ignore_columns)]
pipeline = Pipeline(stages=indexers)
new_data_frame = pipeline.fit(data_frame).transform(data_frame)

问题是,我需要为每个 StringIndexer 对象获取标签列表。对于没有管道的单个列和单个 StringIndexer,这是一项简单的任务。在 DataFrame 上安装索引器后,我可以访问 labels 属性:

indexer = StringIndexer(inputCol="name", outputCol="name_index")
indexer_fitted = indexer.fit(data_frame)
labels = indexer_fitted.labels
new_data_frame = indexer_fitted.transform(data_frame)

但是当我使用管道时,这似乎是不可能的,或者至少我不知道该怎么做。

所以我想我的问题归结为:有没有办法访问在索引过程中为每个列使用的标签?

或者我是否必须在这个用例中放弃管道,例如循环遍历 StringIndexer 对象列表并手动执行? (我相信这是可能的。但是使用管道会更好)

最佳答案

示例数据和管道:

from pyspark.ml.feature import StringIndexer, StringIndexerModel

df = spark.createDataFrame([("a", "foo"), ("b", "bar")], ("x1", "x2"))

pipeline = Pipeline(stages=[
StringIndexer(inputCol=c, outputCol='{}_index'.format(c))
for c in df.columns
])

model = pipeline.fit(df)

阶段中提取:

# Accessing _java_obj shouldn't be necessary in Spark 2.3+
{x._java_obj.getOutputCol(): x.labels
for x in model.stages if isinstance(x, StringIndexerModel)}
{'x1_index': ['a', 'b'], 'x2_index': ['foo', 'bar']}

来自转换后的 DataFrame 的元数据:

indexed = model.transform(df)

{c.name: c.metadata["ml_attr"]["vals"]
for c in indexed.schema.fields if c.name.endswith("_index")}
{'x1_index': ['a', 'b'], 'x2_index': ['foo', 'bar']}

关于python - 从 Spark (pyspark) 管道内的 StringIndexer 阶段获取标签,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/45885044/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com