为了账号安全,请及时绑定邮箱和手机立即绑定

如何为具有自定义管道的自定义估算器运行网格搜索?

如何为具有自定义管道的自定义估算器运行网格搜索?

鸿蒙传说 2021-10-19 15:37:04
我正在尝试为一个自定义案例运行网格搜索,该案例涉及一个包含pipeline作为其构造函数输入之一的估算器。class DefaultEstimator(BaseEstimator, TransformerMixin):  def __init__(self, preprocessor, pipelines):    self.pipelines = pipelines  def fit(self, X, y=None):    for idx, each_pipeline in enumerate(self.pipelines):      each_pipeline.fit(X[idx], y)    return self  def transform(self, X):   transformed_data = []   for idx, each_pipeline in enumerate(self.pipelines):     transformed_data.append(each_pipeline.transform(X[idx))   return sp.hstack(transformed_data)我的管道看起来像这样:pipeline1 = trainer.create_pipeline(num_features=100)pipeline2 = trainer.create_pipeline(num_features=50)复合管道看起来像:aggregated_pipeline = Pipeline([('contextual', DefaultEstimator([pipeline1, pipeline2])),                                ('classifier', Pipeline([('clf', SVM(random_state=1234, probability=True)]))                              ])输入数据有两列,每列都有一个各自的管道(pipeline1和pipeline2)。对于按键grid_params的clf可写成classifier__clf__C,classifier__clf__gamma等等。现在的问题是:如何编写grid_params用于GridSearchCV(...)作为管道的步骤之一是不是一个管道对象,而定制估计对象?
查看完整描述

1 回答

?
LEATH

TA贡献1936条经验 获得超6个赞

GridSearchCV并Pipeline使用估计器set_params设置要测试的参数。所以,你必须在你的 中实现这一点DefaultEstimator,并适当地设置管道参数。scikit 中的一个常见模式是使用双下划线来分隔嵌套对象的参数,例如:


class DefaultEstimator:

    def set_params(self, **kwargs):

        for k, v in kwargs.items():

            parts = k.split('__')

            if parts[0].startswith('pipeline'):

                pipe_num = int(parts[0].split('_')[1])

                param_name = '__'.join(parts[1:])

                self.pipelines[pipe_num].set_params(*{param_name: v})

            else:

                # other logic

这将允许您使用诸如contextual__pipeline_1__num_features(contextual__将被网格搜索剥离,因此无需处理它) 之类的参数。


查看完整回答
反对 回复 2021-10-19
  • 1 回答
  • 0 关注
  • 173 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
微信客服

购课补贴
联系客服咨询优惠详情

帮助反馈 APP下载

慕课网APP
您的移动学习伙伴

公众号

扫描二维码
关注慕课网微信公众号