2 回答
data:image/s3,"s3://crabby-images/ed041/ed04106b0e9cfcdd62c40dbdad6c2cee15d70575" alt="?"
TA贡献1859条经验 获得超6个赞
PS:我根据@ImportanceOfBeingEarnest指出的问题编辑了这个答案(感谢他)。请阅读答案下方的评论以了解我的意思。
新的解决方案是获取该特定轴的显示刻度并将它们格式化为小数点后两位。
new_labels = [round(float(i.get_text()), 2) for i in axes[0,0].get_yticklabels()]
axes[0,0].set_yticklabels(new_labels)
旧答案(仍保留为历史记录,因为您将看到下图中生成的 y 刻度不正确)
问题是您使用axobject 来格式化标签,但ax返回scatter_matrix的不是单轴对象。它是一个包含 9 轴(3x3 子图)的对象。如果绘制axes变量的形状,则可以证明这一点。
axes = scatter_matrix(df, alpha=0.3, figsize=(9,9), diagonal='kde')
print (axes.shape)
# (3, 3)
该解决方案是要么迭代通过所有的轴或只改变有问题的情况下,格式化。PS:下图与你的不匹配,因为我只是使用了你发布的小DataFrame。
以下是如何为所有 y 轴执行此操作
from pandas.plotting import scatter_matrix
from matplotlib.ticker import FormatStrFormatter
axes = scatter_matrix(df, alpha=0.3, figsize=(9,9), diagonal='kde')
for ax in axes.flatten():
ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
或者,您可以只选择一个特定的轴。在这里可以使用访问您的左上角子图axes[0,0]
axes[0,0].yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
data:image/s3,"s3://crabby-images/3f1e7/3f1e7d1dd6080d1a634e344279d1f34a58f5ed9b" alt="?"
TA贡献1836条经验 获得超13个赞
pandas.scatter_matrix遭受不幸的设计选择。也就是说,它在轴的对角线上绘制 kde 或直方图,显示行的其余部分的刻度。然后这需要伪造刻度和标签以适合数据。在此过程中使用了 aFixedLocator和 a FixedFormatter。因此,刻度标签的格式直接从数字的字符串表示中接管。
我会在这里提出一个完全不同的设计。也就是说,对角线轴应该保持空白,而是使用双轴来显示直方图或 kde 曲线。因此,问题中的问题不会发生。
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
def scatter_matrix(df, axes=None, **kw):
n = df.columns.size
diagonal = kw.pop("diagonal", "hist")
if not axes:
fig, axes = plt.subplots(n,n, figsize=kw.pop("figsize", None),
squeeze=False, sharex="col", sharey="row")
else:
flax = axes.flatten()
fig = flax[0].figure
assert len(flax) == n*n
# no gaps between subplots
fig.subplots_adjust(wspace=0, hspace=0)
hist_kwds = kw.pop("hist_kwds", {})
density_kwds = kw.pop("density_kwds", {})
import itertools
p = itertools.permutations(df.columns, r=2)
n = itertools.permutations(np.arange(len(df.columns)), r=2)
for (i,j), (y,x) in zip(n,p):
axes[i,j].scatter(df[x].values, df[y].values, **kw)
axes[i,j].tick_params(left=False, labelleft=False,
bottom=False, labelbottom=False)
diagaxes = []
for i, c in enumerate(df.columns):
ax = axes[i,i].twinx()
diagaxes.append(ax)
if diagonal == 'hist':
ax.hist(df[c].values, **hist_kwds)
elif diagonal in ('kde', 'density'):
from scipy.stats import gaussian_kde
y = df[c].values
gkde = gaussian_kde(y)
ind = np.linspace(y.min(), y.max(), 1000)
ax.plot(ind, gkde.evaluate(ind), **density_kwds)
if i!= 0:
diagaxes[0].get_shared_y_axes().join(diagaxes[0], ax)
ax.axis("off")
for i,c in enumerate(df.columns):
axes[i,i].tick_params(left=False, labelleft=False,
bottom=False, labelbottom=False)
axes[i,0].set_ylabel(c)
axes[-1,i].set_xlabel(c)
axes[i,0].tick_params(left=True, labelleft=True)
axes[-1,i].tick_params(bottom=True, labelbottom=True)
return axes, diagaxes
df = pd.DataFrame(np.random.randn(1000, 4), columns=['A','B','C','D'])
axes,diagaxes = scatter_matrix(df, diagonal='kde', alpha=0.5)
plt.show()
添加回答
举报