
python - Can you consistently keep track of column labels using Sklearn's Transformer API?

Reposted. Author: 太空宇宙. Updated: 2023-11-03 11:36:52


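At the moment, every method that uses the transformer API in sklearn returns a numpy array as its result. Usually this is fine, but if you chain together a multi-step process that expands or reduces the number of columns, there is no clean way to track how the resulting columns relate to the original column labels, which makes it hard to use this part of the library to its fullest.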


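import numpy as np
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler, OneHotEncoder

# train and test are pandas DataFrames with the same columns
numeric_columns = train.select_dtypes(include=np.number).columns.tolist()
# 'object' selects the string/categorical columns
cat_columns = train.select_dtypes(include='object').columns.tolist()

numeric_pipeline = make_pipeline(SimpleImputer(strategy='median'), StandardScaler())
cat_pipeline = make_pipeline(SimpleImputer(strategy='most_frequent'), OneHotEncoder())

transformers = [
    ('num', numeric_pipeline, numeric_columns),
    ('cat', cat_pipeline, cat_columns)
]

combined_pipe = ColumnTransformer(transformers)

train_clean = combined_pipe.fit_transform(train)

test_clean = combined_pipe.transform(test)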

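In this example, I split up my dataset using ColumnTransformer and then added additional columns using OneHotEncoder, so the arrangement of my columns is no longer the same as what I started with.

I could easily end up with different arrangements if I used other modules that share the same API, such as OrdinalEncoder, select_k_best, and so on.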


There has been extensive discussion of this here, but I don't think anything has been finalized yet.


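Yes, you are right: there is currently no complete support for tracking feature_names in sklearn. Initially, it was decided to keep things generic at the level of numpy arrays. The latest progress on adding feature names to sklearn estimators can be tracked here.

In any case, we can create wrappers to get the feature names of a ColumnTransformer. I am not sure whether this captures every possible type of ColumnTransformer, but at least it can solve your problem.

From the documentation of ColumnTransformer: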


The order of the columns in the transformed feature matrix follows the order of how the columns are specified in the transformers list. Columns of the original feature matrix that are not specified are dropped from the resulting transformed feature matrix, unless specified in the passthrough keyword. Those columns specified with passthrough are added at the right to the output of the transformers.
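The ordering rule quoted above can be sketched in plain Python. This is only an illustration of the rule, not scikit-learn's implementation: `expected_output_order` is a hypothetical helper, and it assumes each transformer emits exactly one output column per input column, so expanding steps such as OneHotEncoder are not modelled.

```python
def expected_output_order(input_columns, transformer_columns, remainder="drop"):
    """Return the input columns in the order their outputs appear.

    Hypothetical helper, not part of scikit-learn. Assumes one output
    column per input column for every transformer.
    """
    ordered = []
    # Outputs follow the order in which the transformers are listed.
    for _name, columns in transformer_columns:
        ordered.extend(columns)
    if remainder == "passthrough":
        # Unlisted columns are appended on the right, in input order.
        used = set(ordered)
        ordered.extend(c for c in input_columns if c not in used)
    return ordered


input_columns = ["age", "Gender", "income", "sales", "foo", "text"]
spec = [("num", ["age"]), ("cat", ["Gender", "income"]), ("scale", ["sales"])]

# Default remainder: unlisted columns ("foo", "text") are dropped.
print(expected_output_order(input_columns, spec))
# remainder="passthrough": they are kept and placed after everything else.
print(expected_output_order(input_columns, spec, remainder="passthrough"))
```

The wrapper in this answer walks the fitted transformers in the same order, which is why the names it collects line up with the columns of the transformed matrix.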


import pandas as pd
import numpy as np
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import make_pipeline, Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler, OneHotEncoder, MinMaxScaler
from sklearn.feature_extraction.text import _VectorizerMixin
from sklearn.feature_selection._base import SelectorMixin
from sklearn.feature_selection import SelectKBest
from sklearn.feature_extraction.text import CountVectorizer

# Small toy frame with numeric, categorical and text columns plus a target 'y'
train = pd.DataFrame({'age': [23, 12, 12, np.nan],
                      'Gender': ['M', 'F', np.nan, 'F'],
                      'income': ['high', 'low', 'low', 'medium'],
                      'sales': [10000, 100020, 110000, 100],
                      'foo': [1, 0, 0, 1],
                      'text': ['I will test this',
                               'need to write more sentence',
                               'want to keep it simple',
                               'hope you got that these sentences are junk'],
                      'y': [0, 1, 1, 1]})
numeric_columns = ['age']
cat_columns = ['Gender', 'income']

numeric_pipeline = make_pipeline(SimpleImputer(strategy='median'), StandardScaler())
cat_pipeline = make_pipeline(SimpleImputer(strategy='most_frequent'), OneHotEncoder())
text_pipeline = make_pipeline(CountVectorizer(), SelectKBest(k=5))

transformers = [
    ('num', numeric_pipeline, numeric_columns),
    ('cat', cat_pipeline, cat_columns),
    ('text', text_pipeline, 'text'),
    ('simple_transformer', MinMaxScaler(), ['sales']),
]

# 'foo' is not listed above, so remainder='passthrough' keeps it on the right
combined_pipe = ColumnTransformer(
    transformers, remainder='passthrough')

transformed_data = combined_pipe.fit_transform(
    train.drop(columns='y'), train['y'])

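# Written against scikit-learn releases whose estimators expose
# get_feature_names() (the method was removed in scikit-learn 1.2).
def get_feature_out(estimator, feature_in):
    if hasattr(estimator, 'get_feature_names'):
        if isinstance(estimator, _VectorizerMixin):
            # handling all vectorizers
            return [f'vec_{f}'
                    for f in estimator.get_feature_names()]
        else:
            return estimator.get_feature_names(feature_in)
    elif isinstance(estimator, SelectorMixin):
        # keep only the names of the columns the selector retained
        return np.array(feature_in)[estimator.get_support()]
    else:
        # transformers that neither add nor drop columns
        return feature_in


def get_ct_feature_names(ct):
    # handles all estimators, pipelines inside ColumnTransformer
    # remainder == 'passthrough' requires the input column names,
    # which the last branch below looks up on the fitted transformer.
    output_features = []

    for name, estimator, features in ct.transformers_:
        if name != 'remainder':
            if isinstance(estimator, Pipeline):
                # thread the names through every step of the pipeline
                current_features = features
                for step in estimator:
                    current_features = get_feature_out(step, current_features)
                features_out = current_features
            else:
                features_out = get_feature_out(estimator, features)
            output_features.extend(features_out)
        elif estimator == 'passthrough':
            # Assumption: input column names are read from the private
            # _feature_names_in attribute of these older releases
            # (newer releases expose feature_names_in_ instead).
            output_features.extend(ct._feature_names_in[features])

    return output_features


# Label the transformed matrix with the recovered names
pd.DataFrame(transformed_data,
             columns=get_ct_feature_names(combined_pipe))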



A similar question about "python - Can you consistently keep track of column labels using Sklearn's Transformer API?" can be found on Stack Overflow.
