为了账号安全,请及时绑定邮箱和手机立即绑定

sklearn 分层 k 折 CV 与线性模型,如 ElasticNetCV

sklearn 分层 k 折 CV 与线性模型,如 ElasticNetCV

杨__羊羊 2021-11-23 16:13:03
使用交叉验证 (CV)sklearn非常简单直接。但是cv=5在线性 CV 模型中设置时的默认实现,例如ElasticNetCV或LassoCV是KFoldCV。出于各种原因,我想使用StratifiedKFold. 从文档来看,似乎任何CV 方法都可以用cv=.传递cv=KFold(5)按预期工作,但cv=StratifiedKFold(5)会引发错误:ValueError: 支持的目标类型是: ('binary', 'multiclass')。取而代之的是“连续”。我知道我可以cross_val_score在拟合后使用,但我想StratifiedKFold作为 CV 直接传递给线性模型。我的最低工作示例是:from sklearn.linear_model import ElasticNetCVfrom sklearn.model_selection import KFold, StratifiedKFoldimport numpy as npx = np.arange(100, dtype=np.float64).reshape(-1, 1)y = np.arange(100) + np.random.rand(100)# KFold default implementation:model_default = ElasticNetCV(cv=5)model_default.fit(x, y)  # works fine# KFold given as cv explicitly:model_kfexp = ElasticNetCV(cv=KFold(5))model_kfexp.fit(x, y)  # also works fine# StratifiedKFold given as cv explicitly:model_skf = ElasticNetCV(cv=StratifiedKFold(5))model_skf.fit(x, y)  # THIS RAISES THE ERROR知道如何StratifiedKFold直接设置为 CV 吗?
查看完整描述

1 回答

?
婷婷同学_

TA贡献1844条经验 获得超8个赞

你的问题的根源是这一行:


y = np.arange(100) + np.random.rand(100)

StratifiedKFold无法从连续分布中采样,因此您的错误。尝试更改这一行,您的代码将愉快地执行:


from sklearn.linear_model import ElasticNetCV

from sklearn.model_selection import KFold, StratifiedKFold

import numpy as np


x = np.arange(100, dtype=np.float64).reshape(-1, 1)

y = np.random.choice([0,1], size=100)


# KFold default implementation:

model_default = ElasticNetCV(cv=5)

model_default.fit(x, y)  # works fine

# KFold given as cv explicitly:

model_kfexp = ElasticNetCV(cv=KFold(5))

model_kfexp.fit(x, y)  # also works fine


# StratifiedKFold given as cv explicitly:

model_skf = ElasticNetCV(cv=StratifiedKFold(5))

model_skf.fit(x, y)  # no ERROR

笔记


如果您对连续数据进行采样,请使用KFold. 如果您的目标是明确的,您可以使用两者KFold并 使用StratifiedKFold适合您需要的任何一种。


笔记2


如果您坚持在连续数据上模拟分层抽样,您可能希望应用pandas.cut到您的数据,然后对该数据进行分层抽样,最后将结果(train_id, test_id)生成器传递给cvparam:


x = np.arange(100, dtype=np.float64).reshape(-1, 1)

y = np.arange(100) + np.random.rand(100)


y_cat = pd.cut(y, 10, labels=range(10))

skf_gen = StratifiedKFold(5).split(x, y_cat)


model_skf = ElasticNetCV(cv=skf_gen)

model_skf.fit(x, y)  # no ERROR


查看完整回答
反对 回复 2021-11-23
  • 1 回答
  • 0 关注
  • 259 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信