为了账号安全,请及时绑定邮箱和手机立即绑定

2 DataLoader-庖丁解牛之pytorch

标签:
机器学习

数据集已经有了,直接使用不就得了,实际数据加载是一个很大的问题,涉及内存、cpu、GPU的利用关系,因此专门设计一个数据加载类DataLoader,我们先看一看这个类的参数

* dataset (Dataset): 装载的数据集
* batch_size (int, optional): 每批加载批次大小,默认1* shuffle (bool, optional): 每个epoch是否混淆
* sampler (Sampler, optional): 采样器,与shuffle互斥
* batch_sampler (Sampler, optional): 和sampler类似,
* num_workers (int, optional): 多进程并发装载,subprocess工作进程个数,默认0* collate_fn (callable, optional): 合并mini-batch的采样列表
* pin_memory (bool, optional): 锁页内存
* drop_last (bool, optional): 丢弃最后一个不完整的batch
* timeout (numeric, optional): 收集工作批次的等待时间    
* worker_init_fn (callable, optional): 每个工作进程根据worker ID调用

参数一大堆,但是函数就三个

__setattr__(self, attr, val)  设置属性
__iter__(self)                    迭代
__len__(self)                     长度

采样器

我们先看看采样器
采样器有如下几个

Sampler 基本采样器基类SequentialSampler 序列采样器 iter(range(len(self.data_source)))RandomSampler 随机采样器iter(torch.randperm(len(self.data_source)).tolist())SubsetRandomSampler 子集随机采样器 (self.indices[i] for i in torch.randperm(len(self.indices)))WeightRandomSampler 权重随机采样器iter(torch.multinomial(self.weights, self.num_samples, self.replacement))BatchSampler 批处理采样器DistributedSampler 分布采样器
from torch.utils.data import BatchSampler, SequentialSampler
list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))

目前系统的采样器只有这几种,对于DataLoader来说批次也是采样过程,因此都归结为采样器。DataLoader的最重要一个函数是迭代器

迭代器

迭代器根据采样器的处理,利用多线程技术,分批次进行加载,这也是DataLoader的核心,该进程首先申请两类队列,一类是索引队列,一类是工作结果队列,用于存储进程之间的结果。之后引入最重要的工作进程_worker_loop这是一个全局函数,从索引队列中领取任务,将结果放到工作结果队列中,源码如下:

......    while True:        try:            # 领任务
            r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)        except queue.Empty:            if watchdog.is_alive():                continue
            else:                break
        if r is None:            break
        idx, batch_indices = r        try:
            samples = collate_fn([dataset[i] for i in batch_indices]) # 干活
        except Exception:
            data_queue.put((idx, ExceptionWrapper(sys.exc_info()))) # 撂挑子
        else:
            data_queue.put((idx, samples)) # 交结果
            del samples
......

工作管理进程收集上交结果

def _worker_manager_loop(in_queue, out_queue, done_event, pin_memory, device_id):
    if pin_memory:
        torch.cuda.set_device(device_id)    while True:        try:
            r = in_queue.get()        except Exception:            if done_event.is_set():                return
            raise
        if r is None:            break
        if isinstance(r[1], ExceptionWrapper):
            out_queue.put(r)            continue
        idx, batch = r        try:            if pin_memory:
                batch = pin_memory_batch(batch)        except Exception:
            out_queue.put((idx, ExceptionWrapper(sys.exc_info())))        else:
            out_queue.put((idx, batch))

在工作管理进程收集结果的时候有个操作比较特别,pin_memory_batch称锁页内存(pinned memory or page locked memory):创建DataLoader时,设置pin_memory=True,则意味着生成的Tensor数据最开始是属于内存中的锁页内存,这样将内存的Tensor转义到GPU的显存就会更快一些。
主机中的内存,有两种存在方式,一是锁页,二是不锁页,锁页内存存放的内容在任何情况下都不会与主机的虚拟内存进行交换(注:虚拟内存就是硬盘),而不锁页内存在主机内存不足时,数据会存放在虚拟内存中。
而显卡中的显存全部是锁页内存。
当计算机的内存充足的时候,可以设置pin_memory=True。当系统卡住,或者交换内存使用过多的时候,设置pin_memory=False。因为pin_memory与电脑硬件性能有关,pytorch开发者不能确保每一个炼丹玩家都有高端设备,因此pin_memory默认为False。
数据装载部分主要有迭代器来实现,此处代码不清晰,主要过程就是多线程、内存管理、分批读入等



作者:readilen
链接:https://www.jianshu.com/p/a32ae0294223


点击查看更多内容
TA 点赞

若觉得本文不错,就分享一下吧!

评论

作者其他优质文章

正在加载中
  • 推荐
  • 评论
  • 收藏
  • 共同学习,写下你的评论
感谢您的支持,我会继续努力的~
扫码打赏,你说多少就多少
赞赏金额会直接到老师账户
支付方式
打开微信扫一扫,即可进行扫码打赏哦
今天注册有机会得

100积分直接送

付费专栏免费学

大额优惠券免费领

立即参与 放弃机会
意见反馈 帮助中心 APP下载
官方微信

举报

0/150
提交
取消