为了账号安全,请及时绑定邮箱和手机立即绑定

Matplotlib - 将线性回归线扩展到图的整个宽度

Matplotlib - 将线性回归线扩展到图的整个宽度

神不在的星期二 2021-10-26 15:42:51
我似乎无法弄清楚如何使线性回归线(又名最佳拟合线)跨越图形的整个宽度。它似乎只是上升了左边最远的数据点和右边最远的数据点,没有进一步。我将如何解决这个问题?import matplotlib.pyplot as pltimport numpy as npfrom scipy import statsfrom scipy.interpolate import *import MySQLdb# connect to MySQL databasedef mysql_select_all():    conn = MySQLdb.connect(host='localhost',                           user='root',                           passwd='XXXXX',                           db='world')    cursor = conn.cursor()    sql = """        SELECT            GNP, Population        FROM            country        WHERE            Name LIKE 'United States'                OR Name LIKE 'Canada'                OR Name LIKE 'United Kingdom'                OR Name LIKE 'Russia'                OR Name LIKE 'Germany'                OR Name LIKE 'Poland'                OR Name LIKE 'Italy'                OR Name LIKE 'China'                OR Name LIKE 'India'                OR Name LIKE 'Japan'                OR Name LIKE 'Brazil';    """    cursor.execute(sql)    result = cursor.fetchall()    list_x = []    list_y = []    for row in result:        list_x.append(('%r' % (row[0],)))    for row in result:        list_y.append(('%r' % (row[1],)))    list_x = list(map(float, list_x))    list_y = list(map(float, list_y))    fig = plt.figure()    ax1 = plt.subplot2grid((1,1), (0,0))    p1 = np.polyfit(list_x, list_y, 1)          # this line refers to line of regression    ax1.xaxis.labelpad = 50    ax1.yaxis.labelpad = 50    plt.plot(list_x, np.polyval(p1,list_x),'r-') # this refers to line of regression      plt.scatter(list_x, list_y, color = 'darkgreen', s = 100)    plt.xlabel("GNP (US dollars)", fontsize=30)    plt.ylabel("Population(in billions)", fontsize=30)    plt.xticks([1000000, 2000000, 3000000, 4000000, 5000000, 6000000,                 7000000, 8000000, 9000000],  rotation=45, fontsize=14)    plt.yticks(fontsize=14)    plt.show()    cursor.close()mysql_select_all()
查看完整描述

3 回答

?
米琪卡哇伊

TA贡献1998条经验 获得超6个赞

我似乎无法弄清楚如何使线性回归线(又名最佳拟合线)跨越图形的整个宽度。它似乎只是上升了左边最远的数据点和右边最远的数据点,没有进一步。我将如何解决这个问题?


import matplotlib.pyplot as plt

import numpy as np

from scipy import stats

from scipy.interpolate import *

import MySQLdb


# connect to MySQL database

def mysql_select_all():

    conn = MySQLdb.connect(host='localhost',

                           user='root',

                           passwd='XXXXX',

                           db='world')

    cursor = conn.cursor()

    sql = """

        SELECT

            GNP, Population

        FROM

            country

        WHERE

            Name LIKE 'United States'

                OR Name LIKE 'Canada'

                OR Name LIKE 'United Kingdom'

                OR Name LIKE 'Russia'

                OR Name LIKE 'Germany'

                OR Name LIKE 'Poland'

                OR Name LIKE 'Italy'

                OR Name LIKE 'China'

                OR Name LIKE 'India'

                OR Name LIKE 'Japan'

                OR Name LIKE 'Brazil';

    """


    cursor.execute(sql)

    result = cursor.fetchall()


    list_x = []

    list_y = []


    for row in result:

        list_x.append(('%r' % (row[0],)))


    for row in result:

        list_y.append(('%r' % (row[1],)))


    list_x = list(map(float, list_x))

    list_y = list(map(float, list_y))


    fig = plt.figure()

    ax1 = plt.subplot2grid((1,1), (0,0))


    p1 = np.polyfit(list_x, list_y, 1)          # this line refers to line of regression


    ax1.xaxis.labelpad = 50

    ax1.yaxis.labelpad = 50


    plt.plot(list_x, np.polyval(p1,list_x),'r-') # this refers to line of regression  

    plt.scatter(list_x, list_y, color = 'darkgreen', s = 100)

    plt.xlabel("GNP (US dollars)", fontsize=30)

    plt.ylabel("Population(in billions)", fontsize=30)

    plt.xticks([1000000, 2000000, 3000000, 4000000, 5000000, 6000000, 

                7000000, 8000000, 9000000],  rotation=45, fontsize=14)

    plt.yticks(fontsize=14)


    plt.show()

    cursor.close()


mysql_select_all()

//img1.sycdn.imooc.com//6177b19d0001c4a019581284.jpg

而延长之后,

//img1.sycdn.imooc.com//6177b1aa00016a3619621284.jpg


查看完整回答
反对 回复 2021-10-26
?
隔江千里

TA贡献1906条经验 获得超10个赞

如果您希望绘图不超出 x 轴上的数据,只需执行以下操作:


fig, ax = plt.subplots()

ax.margins(x=0)

# Don't use plt.plot

ax.plot(list_x, np.polyval(p1,list_x),'r-')

ax.scatter(list_x, list_y, color = 'darkgreen', s = 100)

ax.set_xlabel("GNP (US dollars)", fontsize=30)

ax.set_ylabel("Population(in billions)", fontsize=30)

ax.set_xticks([1000000, 2000000, 3000000, 4000000, 5000000, 6000000, 7000000, 8000000, 9000000],  rotation=45, fontsize=14)

ax.tick_params(axis='y', labelsize=14)


查看完整回答
反对 回复 2021-10-26
  • 3 回答
  • 0 关注
  • 343 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信