为了账号安全,请及时绑定邮箱和手机立即绑定

PyTorch性能优化:有效收集度量,避免TorchMetrics的陷阱

PyTorch 模型性能分析与优化(第 7 部分)

照片由 Darling Arias 提供,来自 Unsplash

指标收集是每个机器学习项目不可或缺的一部分,让我们能够追踪模型表现并监控训练进展。理想情况下,指标的收集和计算应不增加训练过程的额外开销。然而,就像训练循环中的其他组件一样,低效的指标计算会引入不必要的开销,从而增加训练步所需的时间,增加训练成本。

这是我们在PyTorch中的性能剖析和优化系列中的第七篇文章。该系列旨在强调性能分析和优化在机器学习开发过程中的关键角色。每篇文章都聚焦于训练流程的不同阶段,展示了实用的工具和技巧,用于分析并提高资源利用效率和减少运行时间。

这一期的重点在于指标收集。我们将展示一个简单的指标收集实现方式如何损害运行时性能,并介绍一些分析和优化指标收集的方法和工具。

为了实现我们的指标收集任务,我们将使用TorchMetrics,一个流行的库,用于简化和标准化PyTorch中的指标计算,详情请参阅TorchMetrics。我们的目标将是以下几点:

  1. 展示由简单实现引起的指标收集的运行时开销
  2. 使用PyTorch Profiler来定位指标计算带来的性能瓶颈
  3. 展示减少指标收集开销的优化技术

为了便于讨论,我们将定义一个示例的PyTorch模型,并评估度量收集功能如何影响其运行时性能。我们将在NVIDIA A40 GPU上进行实验,使用 PyTorch 2.5.1 docker 镜像以及 TorchMetrics 1.6.1

需要注意的是,度量收集的行为方式会因硬件、运行时环境和模型架构的不同而可能会有很大不同。文中提供的代码示例仅作为演示使用。请不要误解我们提到的任何工具或技术是推荐使用的。

简单的ResNet模型

在下面的代码里,我们定义了一个简单的基于ResNet-18的基础网络的图像分类模型。

# 导入时间模块
import time  
# 导入torch库
import torch  
# 导入torchvision库
import torchvision  

# 设备设置为GPU
device = "cuda"  

# 加载预训练的ResNet18模型,并将其移动到指定设备
model = torchvision.models.resnet18().to(device)  
# 定义损失函数为交叉熵损失
criterion = torch.nn.CrossEntropyLoss()  
# 定义优化器为随机梯度下降算法
optimizer = torch.optim.SGD(model.parameters())

我们定义了一个合成的数据集,我们会用它来训练我们的玩具模型。

    from torch.utils.data import Dataset, DataLoader  

    # 这是一个包含随机图像和标签的数据集  
    class FakeDataset(Dataset):  
        def __len__(self):  
            return 100000000  

        def __getitem__(self, index):  
            rand_image = torch.randn([3, 224, 224], dtype=torch.float32)  
            label = torch.tensor(data=index % 1000, dtype=torch.int64)  
            return rand_image, label  

    train_set = FakeDataset()  

    batch_size = 128  # 批次大小为128  
    num_workers = 12  # 工作线程数为12  

    train_loader = DataLoader(  
        dataset=train_set,  
        batch_size=batch_size,  
        num_workers=num_workers,  
        pin_memory=True  # 启用内存固定  
    )  # 创建数据加载器

我们定义了一系列来自TorchMetrics的标准指标,并使用一个开关来启用或禁用计算指标。

    from torchmetrics import (  
        MeanMetric,  
        Accuracy,  
        Precision,  
        Recall,  
        F1Score,  
    )  

    # 切换以启用或禁用指标收集  
    capture_metrics = False  

    if capture_metrics:  
            metrics = {  
            "avg_loss": MeanMetric(),  
            "accuracy": Accuracy(task="多类别", num_classes=1000),  
            "precision": Precision(task="多类别", num_classes=1000),  
            "recall": Recall(task="多类别", num_classes=1000),  
            "f1_score": F1Score(task="多类别", num_classes=1000),  
        }  

        # 将指标移到设备上  
        metrics = {name: metric.to(device) for name, metric in metrics.items()}

接下来,我们定义一个PyTorch Profiler实例,以及一个控制开关,允许我们启用或禁用性能分析。有关使用PyTorch Profiler的详细指南,请参阅本系列的第一篇博客文章。

    从torch导入profiler

    # 启用/禁用分析
    enable_profiler = True

    if enable_profiler:
        prof = profiler.profile(
            schedule=profiler.schedule(wait=10, warmup=2, active=3, repeat=1),
            on_trace_ready=profiler.tensorboard_trace_handler("./logs/"),
            profile_memory=True,
            with_stack=True
        )
        prof.start()

最后,我们定义一个标准的训练过程:

    model.train()  

    t0 = time.perf_counter()  
    total_time = 0  
    count = 0  

    for idx, (data, target) in enumerate(train_loader):  
        data = data.to(device, non_blocking=True)  
        target = target.to(device, non_blocking=True)  
        optimizer.zero_grad()  
        output = model(data)  
        loss = criterion(output, target)  
        loss.backward()  
        optimizer.step()  

        if capture_metrics:  
            # 更新指标数据  
            metrics["avg_loss"].update(loss)  
            for name, metric in metrics.items():  
                if name != "avg_loss":  
                    metric.update(output, target)  

            if (idx + 1) % 100 == 0:  
                # 计算  
                metric_results = {  
                    name: metric.compute().item()   
                        for name, metric in metrics.items()  
                }  
                # 打印指标  
                print(f"Step {idx + 1}: {metric_results}")  
                # 重置所有指标数据  
                for metric in metrics.values():  
                    metric.reset()  

        elif (idx + 1) % 100 == 0:  
            # 打印最后一个损失值  
            print(f"Step {idx + 1}: Loss = {loss.item():.4f}")  

        batch_time = time.perf_counter() - t0  
        t0 = time.perf_counter()  
        if idx > 10:  # 跳过前几步  
            total_time += batch_time  
            count += 1  

        if enable_profiler:  
            prof.step()  

        if idx > 200:  
            break  

    if enable_profiler:  
        prof.stop()  

    avg_time = total_time / count  
    print(f'平均每步耗时: {avg_time}')  
    print(f'处理速度: {batch_size / avg_time:.2f} images/sec')
指标采集的开销

为了衡量度量指标计算对训练步骤时间的影响,我们进行了实验,分别运行了带有和不带有指标计算的训练脚本两次。结果见下表。

简单指标收集的开销——作者

我们天真地收集指标导致运行时性能几乎下降了10个百分点!! 虽然收集指标对机器学习开发很重要,但它通常只涉及相对简单的数学运算,几乎不会带来如此大的负担。这是为什么呢??

使用 PyTorch 分析器找出性能问题

为了更清楚地找到性能下降的原因,我们再次运行了训练脚本,并启用了PyTorch Profiler工具。下面是得到的跟踪结果:

度量收集实验痕迹,作者的

痕迹显示反复出现的“cudaStreamSynchronize”操作与显着下降的GPU利用率同时发生。这类“CPU-GPU同步”事件在我们系列的第二部分中有详细讨论。在典型的训练步骤中,CPU和GPU并行工作:CPU负责任务管理,比如将数据传输到GPU和加载内核,而GPU则在输入数据上执行模型并更新其权重。理想情况下,我们希望尽量减少CPU和GPU之间的同步点,以最大化其性能。然而在这里,我们看到指标收集触发了一个同步点,通过执行CPU到GPU的数据复制。这要求CPU暂停其处理,直到GPU跟上,反过来又导致GPU等待CPU继续加载后续内核。总的来说,这些同步点导致了CPU和GPU使用效率低下。我们的指标收集实现为每个训练步骤添加了八个这样的同步点。

更仔细地检查痕迹发现,同步事件信号来自于 update 调用,来自 MeanMetric TorchMetric 对象。对于有经验的性能分析专家来说,这可能足以确定根本原因,但我们会更进一步,利用 torch.profiler.record_function 工具来找出具体的有问题的代码行。

使用record_function进行性能分析

为了精准定位同步事件的来源,我们扩展了MeanMetric类,并重写了update方法,使用了record_function上下文块。通过这种方法,我们能够对方法中的每一项操作进行性能分析,从而识别性能瓶颈。

    class ProfileMeanMetric(MeanMetric):  
        def update(self, value, weight=1.0):  
            # 将权重广播到与值相同的形状  
            with profiler.record_function("处理值"):  
                if not isinstance(value, torch.Tensor):  
                    value = torch.as_tensor(value, dtype=self.dtype,  
                                            device=self.device)  
            with profiler.record_function("处理权重"):  
                if weight is not None and not isinstance(weight, torch.Tensor):  
                    weight = torch.as_tensor(weight, dtype=self.dtype,  
                                             device=self.device)  
            with profiler.record_function("广播权重"):  
                weight = torch.broadcast_to(weight, value.shape)  
            with profiler.record_function("类型转换及NAN值检查"):  
                value, weight = self._cast_and_nan_check_input(value, weight)  

            # 如果值的数量为0,则直接返回
            if value.numel() == 0:  
                return  

            with profiler.record_function("更新值"):  
                self.mean_value += (value * weight).sum()  
            with profiler.record_function("更新权值"):  
                self.weight += weight.sum()

接着,我们将 avg_loss 指标更新为使用新创建的 ProfileMeanMetric 指标,并再次运行了训练脚本。

使用record_function记录度量收集的痕迹(通过性能调试使用分析器)(作者提供)

更新后的跟踪显示同步事件起源于以下行:

weight = torch.as_tensor(weight, dtype=self.dtype, device=self.device) # 将weight转换为torch.tensor类型,并指定数据类型和运行设备

此操作将默认的标量值 weight=1.0 转换成 PyTorch 张量并放到 GPU 上。触发同步事件是因为这个操作会触发从 CPU 到 GPU 的数据复制,CPU 需要等待 GPU 处理复制的数据。

优化一:指定权重

现在我们找到了问题的根源,我们只需在更新调用中设置一个 权重 值,就能轻松解决这个问题。这就能避免运行时将默认标量 weight=1.0 转换为 GPU 上的张量,从而避免同步事件的发生。

    # 更新指标值
    # 如果需要捕获指标
    if capture_metric:  
        metrics["avg_loss"].update(loss, weight=torch.ones_like(loss))
        # 更新平均损失,权重设置为与损失相同形状的全1张量

在重新运行脚本并应用此更改后,我们成功地消除了最初的同步事件……却发现了一个新的同步事件,这次是由 _cast_and_nan_check_input 函数引起的:

度量收集跟踪(优化一)(作者:某某)

利用record_function进行性能剖析

为了探索我们的新同步活动,我们扩展了自定义指标,并增加了一些额外的性能探针,并重新运行了脚本。

    class ProfileMeanMetric(MeanMetric):  
        def update(self, value, weight = 1.0):  
            # 将权重广播到值的大小  
            with profiler.record_function("处理值"):  
                if not isinstance(value, torch.Tensor):  
                    value = torch.as_tensor(value, dtype=self.dtype,  
                                            device=self.device)  
            with profiler.record_function("处理权重值"):  
                if weight is not None and not isinstance(weight, torch.Tensor):  
                    weight = torch.as_tensor(weight, dtype=self.dtype,  
                                             device=self.device)  
            with profiler.record_function("广播权重"):  
                weight = torch.broadcast_to(weight, value.shape)  
            with profiler.record_function("NaN检查和转换"):  
                value, weight = self._cast_and_nan_check_input(value, weight)  

            if value.numel() == 0:  
                return  

            with profiler.record_function("更新值和权重"):  
                self.mean_value += (value * weight).sum()  
            with profiler.record_function("更新权重"):  
                self.weight += weight.sum()  

        def _cast_and_nan_check_input(self, x, weight = None):  
            """将输入 `x` 转换为张量并检查其是否包含NaN值。"""  
            with profiler.record_function("处理x"):  
                if not isinstance(x, torch.Tensor):  
                    x = torch.as_tensor(x, dtype=self.dtype,  
                                        device=self.device)  
            with profiler.record_function("处理权重"):  
                if weight is not None and not isinstance(weight, torch.Tensor):  
                    weight = torch.as_tensor(weight, dtype=self.dtype,  
                                             device=self.device)  
                nans = torch.isnan(x)  
                if weight is not None:  
                    nans_weight = torch.isnan(weight)  
                else:  
                    nans_weight = torch.zeros_like(nans).bool()  
                    weight = torch.ones_like(x)  

            with profiler.record_function("是否存在NaN"):  
                anynans = nans.any() or nans_weight.any()  

            with profiler.record_function("检查和处理NaN"):  
                if anynans:  
                    if self.nan策略 == "error":  
                        raise RuntimeError("在张量中检测到 `nan` 值")  
                    if self.nan策略 in ("ignore", "warn"):  
                        if self.nan策略 == "warn":  
                            print("在张量中检测到 `nan` 值。这些值将被忽略。")  
                        x = x[~(nans | nans_weight)]  
                        weight = weight[~(nans | nans_weight)]  
                    else:  
                        if not isinstance(self.nan策略, float):  
                            raise ValueError(f"nan策略 应该是浮点数,但您传入了 {self.nan策略}")  
                        x[nans | nans_weight] = self.nan策略  
                        weight[nans | nans_weight] = self.nan策略  

            with profiler.record_function("返回处理后的张量"):  
                retval = x.to(self.dtype), weight.to(self.dtype)  
            return retval

捕获到的轨迹如下所示:

使用 record_function 进行度量收集的追踪 — 第二部分(作者撰写)

这直接指向了错误的那一行。

    anynans = nans.any() or nans_weight.any() # 检查是否存在任何缺失值

此操作会查找输入张量中的NaN值,但是这会导致需要CPU和GPU之间昂贵的同步,因为该操作需要将数据从GPU复制到CPU。

仔细查看了TorchMetric的BaseAggregator 类后,我们发现了几种处理NAN值更新的方式,所有这些处理方式都会经过那个有问题的代码行。然而,对于我们计算平均损失指标的需求来说,这个检查对我们来说是不必要的,也不值得为此牺牲运行性能。

优化2:关闭NAN值检查

为了消除这种开销,我们建议通过重写 _cast_and_nan_check_input 函数来禁用 NaN 值检验。我们实现了一个动态方法,而不是静态重写的方式,它可以灵活地应用于任何 BaseAggregator 类的子类。

    从 torchmetrics.aggregation 导入 BaseAggregator  

    def suppress_nan_check(MetricClass):  
        assert issubclass(MetricClass, BaseAggregator), MetricClass  
        # 定义一个名为 DisableNanCheck 的类,该类继承自 MetricClass。
        class DisableNanCheck(MetricClass):  
            # 定义一个名为 _cast_and_nan_check_input 的方法,该方法用于转换输入并检查是否包含 NaN 值。
            def _cast_and_nan_check_input(self, x, weight=None):  
                if not isinstance(x, torch.Tensor):  
                    x = torch.as_tensor(x, dtype=self.dtype,   
                                        device=self.device)  
                if weight is not None and not isinstance(weight, torch.Tensor):  
                    weight = torch.as_tensor(weight, dtype=self.dtype,  
                                             device=self.device)  
                if weight is None:  
                    weight = torch.ones_like(x)  
                return x.to(self.dtype), weight.to(self.dtype)  
        return DisableNanCheck  

    NoNanMeanMetric = suppress_nan_check(MeanMetric)  

    # 将 NoNanMeanMetric 实例添加到 metrics 字典中,并将其设置为指定的设备上。
    metrics["avg_loss"] = NoNanMeanMetric().to(device)
优化发布结果:成功了

在实施了两项优化措施——指定权重值和禁用NaN检查之后,我们发现每步时间性能和GPU利用率与我们的基线实验一致。此外,PyTorch Profiler生成的跟踪显示,所有与指标收集相关的“cudaStreamSynchronize”事件已被消除。通过几次小的改动,我们已经将训练成本减少了大约10%,同时保持指标收集行为不变。

在下一节中,我们将探讨另一个度量收集的优化。

示例:优化度量设备的布局

在前面的部分里,度量值存放在GPU上,因此在GPU上存取和计算这些度量值更为合理。然而,当我们要聚合的值在CPU上时,将度量值移到CPU上会更方便,以避免不必要的设备传输。

在下面的代码片段中,我们对脚本进行了修改,使用[MeanMetric](在CPU上)来计算每步平均时间。此更改对我们的训练步骤的运行时性能没有影响。

    avg_time = 初始化平均时间计算指标()  
    t0 = time.perf_counter()  

    for idx, (data, target) in enumerate(train_loader):  
        # 非阻塞地将数据移动到设备  
        data = data.to(device, non_blocking=True)  
        target = target.to(device, non_blocking=True)  

        optimizer.zero_grad()  
        output = model(data)  
        loss = criterion(output, target)  
        loss.backward()  
        optimizer.step()  

        if capture_metrics:  
            # 如果开启指标捕捉  
            metrics["avg_loss"].update(loss)  
            for name, metric in metrics.items():  
                if name != "avg_loss":  
                    metric.update(output, target)  

            if (idx + 1) % 100 == 0:  
                # 计算指标  
                metric_results = {  
                    name: metric.compute().item()  
                        for name, metric in metrics.items()  
                }  
                # 打印步骤和指标  
                print(f"步骤 {idx + 1}: {metric_results}")  
                # 将这些指标重置为初始状态  
                for metric in metrics.values():  
                    metric.reset()  

        elif (idx + 1) % 100 == 0:  
            # 打印最后一个损失值  
            print(f"步骤 {idx + 1}: 损失 = {loss.item():.4f}")  

        batch_time = time.perf_counter() - t0  
        t0 = time.perf_counter()  
        if idx > 10:  # 跳过前几个批次  
            avg_time.update(batch_time)  

        if enable_profiler:  
            # 如果启用性能分析器  
            prof.step()  

        if idx > 200:  
            break  

    if enable_profiler:  
        prof.stop()  

    avg_time = avg_time.compute().item()  
    print(f'平均步骤耗时: {avg_time}')  
    print(f'每秒处理图像数: {batch_size/avg_time:.2f} 张图像')  

为了解释这个问题,我们修改了模型定义,使其支持分布式数据并行(DDP)。


    # 切换以启用或禁用DDP功能
    use_ddp = True

    if use_ddp:
        import os
        import torch.distributed as dist
        from torch.nn.parallel import DistributedDataParallel as DDP
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = "29500"
        dist.init_process_group("nccl", rank=0, world_size=1)
        torch.cuda.set_device(0)
        model = DDP(torchvision.models.resnet18().to(device))
    else:
        model = torchvision.models.resnet18().to(device)

    # 插入训练循环

    # 在脚本末尾,如果启用了DDP,则添加以下代码:
    if use_ddp:
        # 销毁进程组以结束分布式训练
        dist.destroy_process_group()

DDP 修改引起了以下的问题:

运行时错误:与 CPU 设备相关的后端未定义

默认情况下,分布式训练中的指标默认会在这所有使用的设备之间同步。然而,由于 DDP 使用的同步机制不支持存储在 CPU 上的指标,因此不支持这些存放在 CPU 上的指标。

一种解决办法是关闭跨设备的指标同步。

    avg_time = NoNanMeanMetric(sync_on_compute=False) # 这里计算平均时间,忽略NaN值

在这种情况下,我们测量平均时间时,这种解决方案是可以接受的。然而,在某些情况下,同步指标是必不可少的,我们可能别无选择,只能将指标移动到GPU。

    avg_time = NoNanMeanMetric().to(device) # 计算平均时间,跳过非数值(NaN)数据,并将结果移动到指定设备

代码更新导致一个新的CPU和GPU同步事件,这个事件是由于更新函数。

按作者记录avg_time指标收集情况

这个同步事件几乎不会让人感到意外——毕竟,我们正用位于CPU上的值更新GPU指标,这需要一次内存复制。然而,在标量指标的情况下,通过简单的优化,可以完全避免数据传输。

优化 3:用张量而不是标量来更新指标值

解决方案很简单:不再用浮点数直接更新指标,而是在调用update函数之前将值转换成张量。具体来说,

    # 将batch_time转换为张量
    batch_time = torch.as_tensor(batch_time)

    # 使用与batch_time相同形状的全1张量更新avg_time
    avg_time.update(batch_time, torch.ones_like(batch_time))

这个小改动消除了同步事件,并将步长时间恢复到了原来的水平。

乍一看,这个结果看起来可能有点出乎意料:我们可能会认为用CPU张量更新GPU的指标仍然需要进行内存复制。然而,PyTorch 对标量运算进行了优化,通过使用一个专用的内核,该内核可以高效地完成加法而不需要明确的数据传输。这避免了原本昂贵的同步事件。

总结

在这篇文章中,我们探讨了使用TorchMetrics时一种初级的方法如何引入CPU-GPU同步事件,并显著降低了PyTorch的训练性能。使用PyTorch Profiler,我们识别并优化了导致这些同步事件的代码行。

我们已经在TorchMetrics的GitHub页面上创建了一个专门的拉取请求,涵盖了本文中讨论的一些优化。欢迎随时贡献您自己的改进和优化!

点击查看更多内容
TA 点赞

若觉得本文不错,就分享一下吧!

评论

作者其他优质文章

正在加载中
  • 推荐
  • 评论
  • 收藏
  • 共同学习,写下你的评论
感谢您的支持,我会继续努力的~
扫码打赏,你说多少就多少
赞赏金额会直接到老师账户
支付方式
打开微信扫一扫,即可进行扫码打赏哦
今天注册有机会得

100积分直接送

付费专栏免费学

大额优惠券免费领

立即参与 放弃机会
意见反馈 帮助中心 APP下载
官方微信

举报

0/150
提交
取消