gpt4 book ai didi

python - Scikit-learn微调: Postprocess predicted labels before evaluation

转载 作者:行者123 更新时间:2023-11-30 09:33:05 25 4
gpt4 key购买 nike

我想知道是否有一种方法可以在 sklearn 中对预测标签进行后处理。我的训练数据具有以下形式的地面实况标签0, 1

但是,问题是我目前正在使用隔离森林,它预测:

  • -1 表示异常值,相当于真实标签 1
  • 1 表示普通数据,相当于真实标签 0

如果我要编写一个函数来对预测进行后处理,那会非常简单:

def process_anomaly_labels(raw_y_pred):
y_pred = raw_y_pred.copy()
y_pred[raw_y_pred == 1] = 0
y_pred[raw_y_pred == -1] = 1
return y_pred

但是当我使用 RandomSearchCV 微调模型时,我不知道如何对预测标签进行后处理:

from sklearn.model_selection import RandomizedSearchCV
# fine tuning
forest_params = {
"n_estimators": [50, 200, 800],
"max_samples": [1000, 4000, 16000, 64000, 120000],
"max_features": [1, 5, 15, 30],
"contamination": [0.001, 0.1, 0.2, 0.5]
}
forest_grid_search = RandomizedSearchCV(
IsolationForest(),
param_distributions=forest_params,
scoring="f1",
n_jobs=8,
n_iter=50,
cv=3,
verbose=2
)
forest_grid_search.fit(X_train_trans, y_train)

我无法将真实标签转换为与预测标签匹配,因为我想在评估时使用二进制 F1 分数。

最佳答案

按照评论中的建议,编写一个执行所需映射的自定义记分器。

示例代码

from sklearn.metrics import make_scorer, f1_score
from sklearn.ensemble import IsolationForest
from sklearn.datasets import make_blobs
from sklearn.model_selection import RandomizedSearchCV
import numpy as np

def relabeled_f1_score(y_true, y_pred):
y_pred_c = y_pred.copy()
y_pred_c[y_pred_c == 1] = 0
y_pred_c[y_pred_c == -1] = 1
return f1_score(y_true=y_true, y_pred=y_pred_c)

n_samples = 1000
n_features = 40

X, _ = make_blobs(n_samples=n_samples, n_features=n_features)
y = np.random.choice([0, 1], n_samples) # 1 = outlier, 0 = inliner

param_grid = {
"n_estimators": [50, 200, 800],
"max_samples": [1000, 4000, 16000, 64000, 120000],
"max_features": [1, 5, 15, 30],
"contamination": [0.001, 0.1, 0.2, 0.5]
}

custom_scorer = make_scorer(score_func=relabeled_f1_score, greater_is_better=True)
my_rs = RandomizedSearchCV(IsolationForest(), param_distributions=param_grid, scoring=custom_scorer, verbose=3)

my_rs.fit(X, y)

关于python - Scikit-learn微调: Postprocess predicted labels before evaluation,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50828900/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com