3 回答
TA贡献1998条经验 获得超6个赞
我无法弄清楚如何让线性回归线(即最佳拟合线)横跨整个图形的宽度。目前它只从最左边的数据点延伸到最右边的数据点,不再向外延伸。该如何解决这个问题?
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
from scipy.interpolate import *
import MySQLdb
# Connect to the MySQL `world` database, fetch (GNP, Population) for a set of
# countries, and scatter-plot them with a linear regression line that spans
# the full width of the x axis.
_DEFAULT_COUNTRIES = (
    'United States', 'Canada', 'United Kingdom', 'Russia', 'Germany',
    'Poland', 'Italy', 'China', 'India', 'Japan', 'Brazil',
)


def mysql_select_all(host='localhost', user='root', passwd='XXXXX',
                     db='world', countries=_DEFAULT_COUNTRIES):
    """Plot Population against GNP with a best-fit line spanning the axes.

    Args:
        host, user, passwd, db: MySQL connection settings. The defaults are
            the values this function previously hard-coded; pass real
            credentials instead of editing them into the source.
        countries: Country names matched with ``LIKE`` against
            ``country.Name``. Defaults to the original eleven countries.

    Raises:
        ValueError: if ``countries`` is empty (no WHERE clause can be built).
    """
    if not countries:
        raise ValueError("countries must contain at least one name")

    # One placeholder per country; MySQLdb uses the 'format' paramstyle, so
    # the driver escapes the values instead of them being spliced into SQL.
    where = " OR ".join(["Name LIKE %s"] * len(countries))
    sql = "SELECT GNP, Population FROM country WHERE " + where

    conn = MySQLdb.connect(host=host, user=user, passwd=passwd, db=db)
    try:
        cursor = conn.cursor()
        try:
            cursor.execute(sql, tuple(countries))
            result = cursor.fetchall()
        finally:
            cursor.close()
    finally:
        # The connection was previously never closed; release it even if
        # the query fails.
        conn.close()

    # Convert directly with float(); the old '%r' -> float() round-trip
    # breaks on values whose repr is not a plain number (e.g. Decimal).
    list_x = [float(gnp) for gnp, _ in result]
    list_y = [float(pop) for _, pop in result]

    _, ax1 = plt.subplots()
    ax1.xaxis.labelpad = 50
    ax1.yaxis.labelpad = 50

    ax1.scatter(list_x, list_y, color='darkgreen', s=100)
    ax1.set_xlabel("GNP (US dollars)", fontsize=30)
    # NOTE(review): Population is plotted without any scaling, so the
    # "in billions" label may not match the values shown -- confirm units.
    ax1.set_ylabel("Population(in billions)", fontsize=30)
    # set_xticks widens the view limits to include every tick, so it must
    # run before the limits are read below.
    ax1.set_xticks([1000000, 2000000, 3000000, 4000000, 5000000, 6000000,
                    7000000, 8000000, 9000000])
    ax1.tick_params(axis='x', labelrotation=45, labelsize=14)
    ax1.tick_params(axis='y', labelsize=14)

    # Least-squares fit of degree 1: p1 = [slope, intercept].
    p1 = np.polyfit(list_x, list_y, 1)
    # Evaluate the fit at the axis edges, not at the data points, so the
    # line reaches both sides of the plot. Then pin the limits so adding the
    # line does not re-trigger autoscaling (which would add margins again).
    x_left, x_right = ax1.get_xlim()
    x_line = np.array([x_left, x_right])
    ax1.plot(x_line, np.polyval(p1, x_line), 'r-')
    ax1.set_xlim(x_left, x_right)

    plt.show()


mysql_select_all()
而延长之后的效果如下(原图已缺失):
TA贡献1906条经验 获得超10个赞
如果您希望 x 轴的范围不超出数据的范围(这样从最左到最右数据点绘制的回归线就会横跨整个图宽),只需执行以下操作:
fig, ax = plt.subplots()
# Drop the default horizontal padding so the x-axis limits coincide with the
# smallest and largest x values; a line drawn between those points then runs
# edge to edge.
ax.margins(x=0)
# Call the Axes methods (not pyplot's plt.* state machine) so every call is
# applied to `ax` explicitly. p1, list_x and list_y come from the question.
ax.plot(list_x, np.polyval(p1, list_x), 'r-')
ax.scatter(list_x, list_y, color='darkgreen', s=100)
ax.set_xlabel("GNP (US dollars)", fontsize=30)
ax.set_ylabel("Population(in billions)", fontsize=30)
# Axes.set_xticks only applies Text keyword arguments (rotation, fontsize)
# when explicit labels are passed, so style the tick labels via tick_params.
ax.set_xticks([1000000, 2000000, 3000000, 4000000, 5000000, 6000000, 7000000, 8000000, 9000000])
ax.tick_params(axis='x', labelrotation=45, labelsize=14)
ax.tick_params(axis='y', labelsize=14)
# set_xticks expands the view limits to include every tick; re-pin them to
# the data range so a tick beyond the data does not leave a gap after the
# regression line.
ax.set_xlim(min(list_x), max(list_x))
添加回答
举报