gpt4 book ai didi

python - 自定义转换器添加附加列

转载 作者:行者123 更新时间:2023-12-05 05:49:41 26 4
gpt4 key购买 nike

我正在尝试将我的 lambda 函数复制到我的管道中

def determine_healthy(_list):
if ('no' in _list['smoker'] and (_list['bmi'] >= 18.5) and (_list['bmi']<= 24.9)):
return True
else:
return False

df['healthy'] = df.apply(lambda row: determine_healthy(row), axis=1)

当我将它集成到我的管道中时,问题就来了,我不确定问题是否在于添加了一个额外的“健康”列。当我尝试转换我的 X_train 时抛出此错误

from sklearn.base import BaseEstimator, TransformerMixin

class HealthyAttributeAdder(BaseEstimator, TransformerMixin):
def __init__(self, items=None):
if items is None: items = []
self.l = items
def fit(self, X , y=None):
return self
def transform(self, X):
#X = X.copy()
temp_cols = X.columns.to_list()
temp_cols = temp_cols.append('healthy')
new_cols = {k:v for k,v in zip(range(len(temp_cols)),temp_cols)}
healthy = X.apply(lambda row: determine_healthy(row), axis=1)
combined_df = pd.DataFrame(np.c_[X, healthy]).rename(columns=new_cols)
return combined_df

num_col = ['age','bmi']
cat_col = ['sex', 'smoker','region','children','healthy']
y = df.pop('charges')
X = df
all_col = X.columns
X_train, X_test, y_train, y_test = train_test_split(X,y , test_size=0.2, random_state = 42)

transform_pipeline = ColumnTransformer([
('healthy', HealthyAttributeAdder(), all_col),
('ss', StandardScaler(), num_col),
('ohe', OneHotEncoder(drop='first'), cat_col),
])

price_pipeline = Pipeline([
('transform', transform_pipeline),
('lasso',Lasso())
])

health_transform = HealthyAttributeAdder()
health_transform.fit_transform(X_train)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
~\AppData\Local\Temp/ipykernel_19796/500623650.py in <module>
----> 1 health_transform.fit_transform(X_train)

~\Venv\hdbtest\lib\site-packages\sklearn\base.py in fit_transform(self, X, y, **fit_params)
850 if y is None:
851 # fit method of arity 1 (unsupervised transformation)
--> 852 return self.fit(X, **fit_params).transform(X)
853 else:
854 # fit method of arity 2 (supervised transformation)

~\AppData\Local\Temp/ipykernel_19796/3713134512.py in transform(self, X)
11 temp_cols = X.columns.to_list()
12 temp_cols = temp_cols.append('healthy')
---> 13 new_cols = {k:v for k,v in zip(range(len(temp_cols)),temp_cols)}
14 healthy = X.apply(lambda row: determine_healthy(row), axis=1)
15 combined_df = pd.DataFrame(np.c_[X, healthy]).rename(columns=new_cols)

TypeError: object of type 'NoneType' has no len()

当我用它来预测时出错:

price_pipeline.fit(X_train,y_train)
y_pred = price_pipeline.predict(X_test)
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
~\Venv\hdbtest\lib\site-packages\pandas\core\indexes\base.py in get_loc(self, key, method, tolerance)
3360 try:
-> 3361 return self._engine.get_loc(casted_key)
3362 except KeyError as err:

~\Venv\hdbtest\lib\site-packages\pandas\_libs\index.pyx in pandas._libs.index.IndexEngine.get_loc()

~\Venv\hdbtest\lib\site-packages\pandas\_libs\index.pyx in pandas._libs.index.IndexEngine.get_loc()

pandas\_libs\hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()

pandas\_libs\hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()

KeyError: 'healthy'

The above exception was the direct cause of the following exception:

KeyError Traceback (most recent call last)
~\Venv\hdbtest\lib\site-packages\sklearn\utils\__init__.py in _get_column_indices(X, key)
432 for col in columns:
--> 433 col_idx = all_columns.get_loc(col)
434 if not isinstance(col_idx, numbers.Integral):

~\Venv\hdbtest\lib\site-packages\pandas\core\indexes\base.py in get_loc(self, key, method, tolerance)
3362 except KeyError as err:
-> 3363 raise KeyError(key) from err
3364

KeyError: 'healthy'

The above exception was the direct cause of the following exception:

ValueError Traceback (most recent call last)
~\AppData\Local\Temp/ipykernel_19796/993407432.py in <module>
----> 1 price_pipeline.fit(X_train,y_train)
2 y_pred = price_pipeline.predict(X_test)

~\Venv\hdbtest\lib\site-packages\sklearn\pipeline.py in fit(self, X, y, **fit_params)
388 """
389 fit_params_steps = self._check_fit_params(**fit_params)
--> 390 Xt = self._fit(X, y, **fit_params_steps)
391 with _print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)):
392 if self._final_estimator != "passthrough":

~\Venv\hdbtest\lib\site-packages\sklearn\pipeline.py in _fit(self, X, y, **fit_params_steps)
346 cloned_transformer = clone(transformer)
347 # Fit or load from cache the current transformer
--> 348 X, fitted_transformer = fit_transform_one_cached(
349 cloned_transformer,
350 X,

~\Venv\hdbtest\lib\site-packages\joblib\memory.py in __call__(self, *args, **kwargs)
347
348 def __call__(self, *args, **kwargs):
--> 349 return self.func(*args, **kwargs)
350
351 def call_and_shelve(self, *args, **kwargs):

~\Venv\hdbtest\lib\site-packages\sklearn\pipeline.py in _fit_transform_one(transformer, X, y, weight, message_clsname, message, **fit_params)
891 with _print_elapsed_time(message_clsname, message):
892 if hasattr(transformer, "fit_transform"):
--> 893 res = transformer.fit_transform(X, y, **fit_params)
894 else:
895 res = transformer.fit(X, y, **fit_params).transform(X)

~\Venv\hdbtest\lib\site-packages\sklearn\compose\_column_transformer.py in fit_transform(self, X, y)
670 self._check_n_features(X, reset=True)
671 self._validate_transformers()
--> 672 self._validate_column_callables(X)
673 self._validate_remainder(X)
674

~\Venv\hdbtest\lib\site-packages\sklearn\compose\_column_transformer.py in _validate_column_callables(self, X)
350 columns = columns(X)
351 all_columns.append(columns)
--> 352 transformer_to_input_indices[name] = _get_column_indices(X, columns)
353
354 self._columns = all_columns

~\Venv\hdbtest\lib\site-packages\sklearn\utils\__init__.py in _get_column_indices(X, key)
439
440 except KeyError as e:
--> 441 raise ValueError("A given column is not a column of the dataframe") from e
442
443 return column_indices

ValueError: A given column is not a column of the dataframe

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com