为了账号安全,请及时绑定邮箱和手机立即绑定

对齐上下轴的刻度坐标

对齐上下轴的刻度坐标

牧羊人nacy 2021-11-02 19:29:54
我想构建一个图形,在其底部 x 轴上有一组刻度,在其顶部 x 轴上有另一组与底部刻度对齐的刻度。特别是在我的情况下,这些是批次和时代。对于n底部的每个批次点(不一定是刻度),我希望顶部有一个纪元刻度。考虑这个例子:import numpy as npimport matplotlib.pyplot as pltbatches = np.arange(1,101)epoch_ends = batches[[(i*10)-1 for i in range(1,11)]]accuracy = np.apply_along_axis(arr=batches, axis=0, func1d=lambda x : x/len(batches))loss = np.apply_along_axis(arr=batches, axis=0, func1d=lambda x : 1 - (x/len(batches)))fig, ax1 = plt.subplots( nrows=1, ncols=1 )ax2 = ax1.twinx()ax3 = ax1.twiny()ax1.set_xlabel('batches')ax1.set_xticks(np.arange(1, len(batches)+1, 9))ax1.set_ylabel('accuracy')ax1.grid()ax2.set_ylabel('loss')ax2.set_yticklabels(np.linspace(3, 10, 9))ax3.set_xlabel('epochs')ax3.set_xticks(epoch_ends)ax3.set_xticklabels(range(1, len(epoch_ends)+1))acc_plt = ax1.plot(batches, accuracy, color='red', label='acc')loss_plt = ax2.plot(batches, loss, color='blue', label='loss')lns = acc_plt+loss_pltlabs = [l.get_label() for l in lns]ax1.legend(lns, labs, loc=2)plt.show()batches和epoch_ends分别是这样的[  1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17  18  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35  36  37  38  39  40  41  42  43  44  45  46  47  48  49  50  51  52  53  54  55  56  57  58  59  60  61  62  63  64  65  66  67  68  69  70  71  72  73  74  75  76  77  78  79  80  81  82  83  84  85  86  87  88  89  90  91  92  93  94  95  96  97  98  99 100][ 10  20  30  40  50  60  70  80  90 100]所以我希望纪元刻度 1 与批次 x-coorrdiante 10、2 与 20 等对齐。但是正如您在图片中看到的,它们并没有排成一行。我需要在我的代码中更改什么才能使其正常工作?
查看完整描述

2 回答

?
青春有我

TA贡献1784条经验 获得超8个赞

这是对齐它们的一种方法。思路如下:


首先使用 ( ax1)在下 x 轴上绘制数据

然后使用设置上 x 轴的限制与下 x 轴相同 ax3.set_xlim(ax1.get_xlim())

然后在对应于下 x 轴值 (10, 20, 30, ..., 90, 100) 的位置设置上 x 轴的刻度

最后,使用 重命名刻度标签ax3.set_xticklabels()。

这是代码:我正在用注释替换代码中已有的部分#。


# imports and define data and compute accuracy and loss here


# Initiate figure and axis objects here


ax1.set_xlabel('batches')

ax1.set_xticks(np.arange(1, len(batches)+1, 9))

ax1.set_ylabel('accuracy')

ax1.grid()


acc_plt = ax1.plot(batches, accuracy, color='red', label='acc')

loss_plt = ax2.plot(batches, loss, color='blue', label='loss')


ax2.set_ylabel('loss')

ax2.set_yticklabels(np.linspace(3, 10, 9))


new_tick_locations = np.arange(1, 11)*10

new_tick_labels = range(1, 11)


ax3.set_xlabel('epochs')

ax3.set_xlim(ax1.get_xlim())

ax3.set_xticks(new_tick_locations)

ax3.set_xticklabels(new_tick_labels)


# Set legends and show the plot

//img1.sycdn.imooc.com//6181216a0001069305490405.jpg

查看完整回答
反对 回复 2021-11-02
?
侃侃尔雅

TA贡献1801条经验 获得超16个赞

这会失去您对刻度的控制(您没有说是否需要),但会像您所说的那样排列两个 x 轴:(注释新行)


import numpy as np

import matplotlib.pyplot as plt


batches = np.arange(1,101)

epoch_ends = batches[[(i*10)-1 for i in range(1,11)]]

accuracy = np.apply_along_axis(arr=batches, axis=0, func1d=lambda x : x/len(batches))

loss = np.apply_along_axis(arr=batches, axis=0, func1d=lambda x : 1 - (x/len(batches)))


fig, ax1 = plt.subplots( nrows=1, ncols=1 )

ax2 = ax1.twinx()

ax3 = ax1.twiny()


ax1.set_xlabel('batches')

#ax1.set_xticks(np.arange(1, len(batches)+1, 9))

ax1.set_xlim(0, 100)                               # Set xlim

ax1.set_ylabel('accuracy')

ax1.grid()


ax2.set_ylabel('loss')

ax2.set_yticklabels(np.linspace(3, 10, 9))


ax3.set_xlabel('epochs')

#ax3.set_xticks(epoch_ends)

#ax3.set_xticklabels(range(1, len(epoch_ends)+1))

ax3.set_xlim(0, 10)                                # Set xlim


acc_plt = ax1.plot(batches, accuracy, color='red', label='acc')

loss_plt = ax2.plot(batches, loss, color='blue', label='loss')


lns = acc_plt+loss_plt

labs = [l.get_label() for l in lns]

ax1.legend(lns, labs, loc=2)


plt.show()


查看完整回答
反对 回复 2021-11-02
  • 2 回答
  • 0 关注
  • 205 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信