为了账号安全,请及时绑定邮箱和手机立即绑定

自定义估计器无法通过 cross_val_score 进行深度复制

自定义估计器无法通过 cross_val_score 进行深度复制

交互式爱情 2023-08-08 17:07:51
我有一个自己实现的自定义估计器,但无法使用,我相信这与我的方法cross_val_score()有关。predict()这是完整的错误跟踪:    Traceback (most recent call last):  File "/Users/joann/Desktop/Implementações ML/Adaboost Classifier/test.py", line 30, in <module>    ada2_score = cross_val_score(ada_2, X, y, cv=5)  File "/Users/joann/opt/anaconda3/lib/python3.7/site-packages/sklearn/model_selection/_validation.py", line 390, in cross_val_score    error_score=error_score)  File "/Users/joann/opt/anaconda3/lib/python3.7/site-packages/sklearn/model_selection/_validation.py", line 236, in cross_validate    for train, test in cv.split(X, y, groups))  File "/Users/joann/opt/anaconda3/lib/python3.7/site-packages/joblib/parallel.py", line 1004, in __call__    if self.dispatch_one_batch(iterator):  File "/Users/joann/opt/anaconda3/lib/python3.7/site-packages/joblib/parallel.py", line 835, in dispatch_one_batch    self._dispatch(tasks)  File "/Users/joann/opt/anaconda3/lib/python3.7/site-packages/joblib/parallel.py", line 754, in _dispatch    job = self._backend.apply_async(batch, callback=cb)  File "/Users/joann/opt/anaconda3/lib/python3.7/site-packages/joblib/_parallel_backends.py", line 209, in apply_async    result = ImmediateResult(func)  File "/Users/joann/opt/anaconda3/lib/python3.7/site-packages/joblib/_parallel_backends.py", line 590, in __init__    self.results = batch()  File "/Users/joann/opt/anaconda3/lib/python3.7/site-packages/joblib/parallel.py", line 256, in __call__    for func, args, kwargs in self.items]我的predict(self, X)方法返回一个大小向量n_samples以及参数的预测X。我还做了一个score()功能如下:def score(self, X, y):     scr_pred = self.predict(X)         return sum(scr_pred == y) / X.shape[0]该方法只是计算给定样本的模型的准确性。如果我使用此score()方法或设置 across_val_score(... , scoring="accuracy")它不起作用。
查看完整描述

1 回答

?
有只小跳蛙

TA贡献1824条经验 获得超8个赞

然而,您的问题陈述在这里并不清楚,但是查看错误,您似乎正在尝试多类分类。

这里的问题是,您的代码中可能在某些时候没有正确完成预处理,因为错误是从 inverse_binarize_thresholding 记录的,这是由于 sklearn 预处理的以下功能而引发的:

def _inverse_binarize_thresholding(y, output_type, classes, threshold):   
    if output_type == "binary" and y.ndim == 2 and y.shape[1] > 2: 
           raise ValueError("output_type='binary', but y.shape = {0}". 
                                   format(y.shape))

您的代码中必须缺少一些转换或预处理,并且您必须正确使用 LabelBinarizer

查看完整回答
反对 回复 2023-08-08
  • 1 回答
  • 0 关注
  • 119 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信