gpt4 book ai didi

python - 有没有办法在 GridSearchCV 中查看交叉验证的折叠?

转载 作者:太空狗 更新时间:2023-10-30 01:26:47 29 4
gpt4 key购买 nike

我目前正在使用 Python 中的 GridSearchCV 进行 3 倍 cv 优化超参数。我只是想知道是否有任何方法可以查看 GridSearchCV 中使用的 cv 中训练和测试数据的索引?

最佳答案

如果您不想在 CV 阶段折叠之前打乱样本,则可以。您可以将 KFold(或另一个 CV 类)的实例传递给 GridSearchCV 构造函数并像这样访问它的折叠:

import pandas as pd
import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold

params = {'penalty' : ['l1', 'l2'], 'C' : [1,2,3]}
grid = GridSearchCV(LogisticRegression(), params, cv=KFold(n_splits=3))

X = np.array([[1, 2], [3, 4], [1, 2], [3, 4], [5, 6], [7, 8]])

for train, test in grid.cv.split(X):
print('TRAIN: ', train, ' TEST: ', test)

打印:

TRAIN:  [2 3 4 5]  TEST:  [0 1]
TRAIN: [0 1 4 5] TEST: [2 3]
TRAIN: [0 1 2 3] TEST: [4 5]

对于未打乱的 CV,折叠始终相同,因此您可以确定这些是在网格搜索期间使用的折叠。

如果您想在折叠之前打乱样本,则稍微复杂一些,因为每次调用 cv.split() 都会生成不同的拆分。我可以想到两种方法:

  1. 您可以为 CV 对象提供固定的 random_state,例如KFold(n_splits=3, shuffle=True, random_state=42)

  2. 在创建 GridSearchCV 对象之前,从 KFold 迭代器创建一个列表。

因此,对于第二种方法,请执行以下操作:

grid = GridSearchCV(LogisticRegression(), params, 
cv=list(KFold(n_splits=3, shuffle=True).split(X)))

除了迭代器,列表是一个固定对象,除非您手动操作它,否则它将在所有 GridSearch 迭代中保持相同的值。

关于python - 有没有办法在 GridSearchCV 中查看交叉验证的折叠?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/42011850/

29 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com