为了账号安全,请及时绑定邮箱和手机立即绑定

使用 matplotlib 显示 SVG 图片(批量)

载入数据迭代器中一批量图片,并以 SVG 格式显示图片:

from pylab import plt, mpl
from IPython import display


class Loader:
    """
    方法
    ========
    L 为该类的实例
    len(L)::返回 batch 的批数
    iter(L)::即为数据迭代器

    Return
    ========
    可迭代对象(numpy 对象)
    """

    def __init__(self, batch_size, X, Y=None, shuffle=True, name=None):
        '''
        X, Y 均为类 numpy, 可以是 HDF5 
        '''
        if name is not None:
            self.name = name
        self.X = X[:]
        if Y is None:
            # print('不存在标签!')
            self.Y = None
        else:
            self.Y = Y[:]
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        n = len(self.X)
        idx = np.arange(n)

        if self.shuffle:
            np.random.shuffle(idx)

        for k in range(0, n, self.batch_size):
            K = idx[k:min(k + self.batch_size, n)].tolist()
            if self.Y is None:
                yield np.take(self.X[:], K, 0)
            else:
                yield np.take(self.X[:], K, 0), np.take(self.Y[:], K, 0)

    def __len__(self):
        return round(len(self.X) / self.batch_size)

    def use_svg_display(self):
        # 用矢量图显示。
        display.set_matplotlib_formats('svg')

    def show_imgs(self, label_names, imgs, labels, figsize=(7, 7)):
        '''
        展示 多张图片
        '''
        n = imgs.shape[0]
        h, w = 4, int(n / 4)
        self.use_svg_display()
        _, ax = plt.subplots(h, w, figsize=figsize)  # 设置图的尺寸
        K = np.arange(n).reshape((h, w))
        names = np.asanyarray(
            [label_names[label] for label in labels], dtype='U')
        names = names.reshape((h, w))
        for i in range(h):
            for j in range(w):
                img = imgs[K[i, j]]
                ax[i][j].imshow(img)
                ax[i][j].axes.get_yaxis().set_visible(False)
                ax[i][j].axes.set_xlabel(names[i][j])
                ax[i][j].set_xticks([])
        plt.show()
import tables as tb
h5 = tb.open_file('E:/xdata/X.h5')

data = h5.root.cifar10

batch_size = 32
trainset = Loader(batch_size, data.trainX, data.trainY, shuffle=True, name='train')

for imgs, labels in iter(trainset):
    trainset.show_imgs(data.label_names, imgs, labels)
    break
点击查看更多内容
TA 点赞

若觉得本文不错,就分享一下吧!

评论

作者其他优质文章

正在加载中
  • 推荐
  • 评论
  • 收藏
  • 共同学习,写下你的评论
感谢您的支持,我会继续努力的~
扫码打赏,你说多少就多少
赞赏金额会直接到老师账户
支付方式
打开微信扫一扫,即可进行扫码打赏哦
今天注册有机会得

100积分直接送

付费专栏免费学

大额优惠券免费领

立即参与 放弃机会
意见反馈 帮助中心 APP下载
官方微信

举报

0/150
提交
取消