为了账号安全,请及时绑定邮箱和手机立即绑定

尝试为我的模型找到最佳多项式回归度时出现错误

尝试为我的模型找到最佳多项式回归度时出现错误

汪汪一只猫 2023-12-26 17:04:37
from sqlalchemy import create_engineimport osimport pandas as pdsnowflake_username = os.environ['SNOWFLAKE_USERNAME']snowflake_password = os.environ['SNOWFLAKE_PASSWORD']snowflake_account = os.environ['SNOWFLAKE_ACCOUNT']snowflake_warehouse = os.environ['SNOWFLAKE_WAREHOUSE']snowflake_database = 'test_db'snowflake_schema = 'public'if __name__ == '__main__':    engine = create_engine(        'snowflake://{user}:{password}@{account}/{db}/{schema}?warehouse={warehouse}'.format(            user=snowflake_username,            password=snowflake_password,            account=snowflake_account,            db=snowflake_database,            schema=snowflake_schema,            warehouse=snowflake_warehouse,        )    )    df = pd.DataFrame([('Mark', 10), ('Luke', 20)], columns=['name', 'balance'])    df.to_sql('TEST_TABLE', con=engine, schema='public', index=False, if_exists='append')每次我运行上面的脚本时,马可福音和路加福音的记录都会附加到我的test_db.public.test_table表中。
查看完整描述

1 回答

?
ABOUTYOU

TA贡献1812条经验 获得超5个赞

如果你的 x 是一维列向量,np.polyfit()并且np.polyval()将完成工作。np.polyfit(x,y,order,full=True)返回残差(我相信是残差平方和)供您检查最佳顺序。您不需要第二次回归拟合来获得残差。

注意,您选择最小残差的逻辑在工程师方面是可行的,但在数学上并不合理。这是因为误差平方和 (SSE) 始终会随着回归量数量的增加而减小,因此您始终会从最大多项式阶次获得结果。您必须尝试使用带有惩罚的公式,以便在选择模型时添加更多项(例如AIC或BIC标准)。然而,这部分完全由研究者自由选择,当然超出了问题本身的范围。

import numpy as np

import matplotlib.pyplot as plt


# get your x and y as np array

x = np.random.uniform(-1,1, 100)

y = x - x**2 + x**3 - x**4 + np.random.normal(0, 0.1, 100)


def poly_fit(n):

  ls_res=[]

  ls_coeff = []

  for i in range(2, n):

    coeff, res, _, _, _ = np.polyfit(x, y, i, full=True)

    ls_res.append(res)

    ls_coeff.append(coeff)


  # argmin should be taken from a penalized loss function

  # add it here

  return ls_coeff[np.argmin(ls_res)]


plt.scatter(x, y, color='red')

coeff = poly_fit(6)

plt.plot(np.sort(x), np.polyval(coeff, np.sort(x)), color='blue')

plt.title('Polynomial Regression results')

plt.xlabel('Position/level')

plt.ylabel('Salary')

plt.show()

https://img1.sycdn.imooc.com/658a974f000183db06390478.jpg

查看完整回答
反对 回复 2023-12-26
  • 1 回答
  • 0 关注
  • 98 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信