为了账号安全,请及时绑定邮箱和手机立即绑定

Transformer模型处理变长输入序列的优化技巧

PyTorch NestedTensors、FlashAttention2 (一种技术) 和 xFormers (一种模型) 怎样提高性能并减少 AI 成本

照片由Tanja ZöllnerUnsplash拍摄

随着生成式AI(genAI)模型的流行度和规模的增加,它们的训练和部署所需的计算资源和成本也随之增加。优化这些模型对于提升它们的运行性能和降低运营成本至关重要。在现代genAI系统中,Transformer架构及其注意力机制是核心,这一机制特别消耗计算资源。

之前的帖子中,我们展示了使用优化的注意力核如何可以显著提升Transformer模型的性能。在这篇文章中,我们将继续这一探讨,解决可变长度输入序列的问题,这是文档、代码、时间序列等多种现实世界数据固有的特性。

处理变长输入的挑战

在典型的深度学习工作负载中,个体样本会被分批处理,然后复制到GPU中并输入到AI模型中。分批处理可以提高计算效率,并且通常有助于模型在训练过程中收敛。通常,批次处理涉及将所有样本张量沿新的批次维度进行堆叠。然而,在长度不同的序列中,torch.stack无法满足要求。

填充及其低效率问题

传统的方法是将输入序列填充到固定长度,然后进行堆叠(stacking)。这种方法需要在模型中适当添加掩码层,以确保输出不受无关张量元素的影响。在注意力层中,填充掩码会指示哪些标记是填充,不应被关注(例如,参见PyTorch MultiheadAttention)。然而,填充会浪费大量的GPU内存,增加成本并减缓开发进度。这对于大规模AI模型来说更是如此。

别垫料,直接连

一种避免填充的方法是拼接(sequence)序列到现有维度而不是堆叠到新维度。与torch.stack不同,torch.cat允许输入形状不同。拼接(sequence)的输出是一个长度等于所有单独序列长度之和的单一序列。为了让这种方法有效,为此单一序列添加一个注意力掩码,以确保每个标记仅关注同一原始序列中的其他标记,这一过程有时被称为文档掩码(Document Masking)。设所有单独序列长度之和为N,采用”大O表示法”,这个掩码的大小需要是O(N²),而一个简单的注意力层(仅在计算注意力分数后应用掩码)的计算复杂度同样为O(N²),这使得这种方法效率极低。

优化注意力层

解决方案是以专门的注意力层的形式来解决这个问题。与标准注意力层不同,标准注意力层会计算全套的 O(N²) 注意力得分,然后忽略无关的得分,这些优化的注意力核只计算相关的得分。在这篇文章中,我们将探讨几种具有各自特点的解决方案,每种解决方案都有其独特的特点。这些包括:

如何把现有的HuggingFace模型整合进来

对于使用预训练模型的团队来说,过渡到这些优化可能看起来有些挑战。这里我们将展示HuggingFace提供的API如何简化这一过程,让操作更简单,让开发人员可以轻松地用最少的代码修改集成这些技术。

这里有一些免责声明
  • 请不要将我们使用任何平台、库或优化技术视为对其使用的认可。最适合您的选项将很大程度上取决于您具体的使用场景。
  • 本文中讨论的一些API处于原型或测试阶段,未来可能会有所改动。
  • 提供的代码示例仅用于演示目的。我们不对它们的准确性、最优性或健壮性作任何保证或承诺。

特别感谢以下人士:Yitzhak LeviPeleg Nahaliel 的对此文的贡献。

玩具版LLM模型

为了方便讨论,我们定义一个简单的生成模型(部分灵感来源于GPT模型,并参考了这里的定义)。如需更全面的构建语言模型的指导,请参考网上众多优秀的教程(例如这里)。

Transformer 模块

我们首先构建一个基础的Transformer模块,该模块专门用于实验不同的注意力机制和优化。尽管我们的模块执行的计算与标准Transformer模块相同,但我们对常用的操作符做了一些小的调整,以支持PyTorch NestedTensor 输入(详情请参见此处)。

    # 通用导入  
    import time, functools  

    # torch 导入  
    import torch  
    from torch.utils.data import Dataset, DataLoader  
    import torch.nn as nn  

    # Transformer 的一些配置  
    BATCH_SIZE = 32  
    NUM_HEADS = 16  
    HEAD_DIM = 64  
    DIM = NUM_HEADS * HEAD_DIM  
    DEPTH = 24  
    NUM_TOKENS = 1024  
    MAX_SEQ_LEN = 1024  
    PAD_ID = 0  
    DEVICE = 'cuda'  

    class MyAttentionBlock(nn.Module):  
        def __init__(  
                self,  
                attn_fn,  
                dim,  
                num_heads,  
                format=None,  
                **kwargs  
        ):  
            super().__init__()  
            self.attn_fn = attn_fn  
            self.num_heads = num_heads  
            self.dim = dim  
            self.head_dim = dim // num_heads  
            self.norm1 = nn.LayerNorm(dim, bias=False)  
            self.norm2 = nn.LayerNorm(dim, bias=False)  
            self.qkv = nn.Linear(dim, dim * 3) # 定义 qkv 线性变换  
            self.proj = nn.Linear(dim, dim) # 定义 proj 线性变换  

            # MLP 层  
            self.fc1 = nn.Linear(dim, dim * 4)  
            self.act = nn.GELU()  
            self.fc2 = nn.Linear(dim * 4, dim)  

            self.permute = functools.partial(torch.transpose, dim0=1, dim1=2)  
            if format == '批量-头-维度-深度':  
                self.permute = nn.Identity()  

        def mlp(self, x):  
            x = self.fc1(x)  
            x = self.act(x)  
            x = self.fc2(x)  
            return x  

        def reshape_and_permute(self, x, batch_size):  
            x = x.view(batch_size, -1, self.num_heads, self.head_dim)  
            return self.permute(x)  

        def forward(self, x_in, attn_mask=None):  
            batch_size = x_in.size(0)  
            x = self.norm1(x_in)  
            qkv = self.qkv(x)  

            # 我们首先分割后重新排列 q, k, v,以支持 PyTorch 的 Nested Tensors  
            q, k, v = qkv.chunk(3, -1)  
            q = self.reshape_and_permute(q, batch_size)  
            k = self.reshape_and_permute(k, batch_size)  
            v = self.reshape_and_permute(v, batch_size)  

            # 调用 attn_fn 并传入 attn_mask  
            x = self.attn_fn(q, k, v, attn_mask=attn_mask)  

            # 重新排列输出  
            x = self.permute(x).reshape(batch_size, -1, self.dim)  
            x = self.proj(x)  
            x = x + x_in  
            x = x + self.mlp(self.norm2(x))  
            return x
Transformer: 解码器模型

基于我们可编程的Transformer模块,我们构建了典型的Transformer解码器模型。

    class MyDecoder(nn.Module):  
        def __init__(  
                self,  
                block_fn,  
                num_tokens,  
                dim,  
                num_heads,  
                num_layers,  
                max_seq_len,  
                pad_idx=None  
        ):  
            super().__init__()  
            self.num_heads = num_heads  
            self.pad_idx = pad_idx  
            self.embedding = nn.Embedding(num_tokens, dim, padding_idx=pad_idx)  
            self.positional_embedding = nn.Embedding(max_seq_len, dim)  
            self.blocks = nn.ModuleList([  
                block_fn(  
                    dim=dim,  
                    num_heads=num_heads  
                )  
                for _ in range(num_layers)])  
            self.output = nn.Linear(dim, num_tokens)  

        def embed_tokens(self, input_ids, position_ids=None):  
            x = self.embedding(input_ids)  
            if position_ids is None:  
                position_ids = torch.arange(input_ids.shape[1],  
                                            device=x.device)  
            x = x + self.positional_embedding(position_ids)  
            return x  

        def forward(self, input_ids, position_ids=None, attn_mask=None):  
            # 将token嵌入并加上位置编码
            x = self.embed_tokens(input_ids, position_ids)  
            if self.pad_idx is not None:  
                # 确保attn_mask为None
                assert attn_mask is None  
                # 创建填充掩码 - 我们假设使用布尔掩码
                attn_mask = (input_ids != self.pad_idx)  
                attn_mask = attn_mask.view(BATCH_SIZE, 1, 1, -1) \  
                    .expand(-1, self.num_heads, -1, -1)  

            for b in self.blocks:  
                # 对于self.blocks中的每一层b:
                x = b(x, attn_mask)  

            logits = self.output(x)  
            # 返回logits
            return logits
可变长度的序列输入

接下来,我们随便选择了固定的序列长度分布,创建一个包含长度各异的序列的数据集,每个序列由随机生成的标记组成。为了简单起见,这样做。在现实场景中,序列长度的分布通常反映了数据的性质,例如文档或音频片段的长度。需要注意的是,长度分布直接影响了因填充带来的计算效率问题。

    # 使用随机数据  
    # 将数据移动到指定设备上的函数定义  
    class FakeDataset(Dataset):  
        # 返回数据集的长度  
        def __len__(self):  
            return 1000000  

        # 获取指定索引的数据项  
        def __getitem__(self, index):  
            length = torch.randint(1, MAX_SEQ_LEN, (1,))  
            sequence = torch.randint(1, NUM_TOKENS, (length + 1,))  
            inputs = sequence[:-1]  
            targets = sequence[1:]  
            return inputs, targets  

    # 对序列进行填充的函数定义  
    def pad_sequence(sequence, length, pad_val):  
        # 填充序列到指定长度  
        return torch.nn.functional.pad(  
            sequence,  
            (0, length - sequence.shape[0]),  # 填充长度计算  
            value=pad_val  
        )  

    # 使用填充处理批次数据的函数定义  
    def collate_with_padding(batch):  
        padded_inputs_list = []  # 用于存储填充后的输入序列  
        padded_targets_list = []  # 用于存储填充后的目标序列  
        for b in batch:  
            padded_inputs_list.append(pad_sequence(b[0], MAX_SEQ_LEN, PAD_ID))  
            padded_targets_list.append(pad_sequence(b[1], MAX_SEQ_LEN, PAD_ID))  
        padded_inputs = torch.stack(padded_inputs_list, dim=0)  
        padded_targets = torch.stack(padded_targets_list, dim=0)  
        return {  
            'inputs': padded_inputs,  
            'targets': padded_targets  
        }  

    # 将数据传输到指定设备上的函数定义  
    def data_to_device(data, device):  
        if isinstance(data, dict):  
            return {  
                key: data_to_device(val,device)  
                for key, val in data.items()  
            }  
        elif isinstance(data, (list, tuple)):  
            return type(data)(  
                data_to_device(val, device) for val in data  
            )  
        elif isinstance(data, torch.Tensor):  
            return data.to(device=device, non_blocking=True)  # 将张量传输到指定设备  
        else:  
            return data.to(device=device)  # 将数据传输到指定设备上
训练/评估环节

最后一步,我们实现一个名为_main_的函数,它可以在不同长度的输入序列上对训练和评估进行执行。

    def main(  
        block_fn,   
        data_collate_fn=collate_with_padding,  
        pad_idx=None,  
        train=True,  
        compile=False  
    ):  
        torch.random.manual_seed(0)  
        device = torch.device(DEVICE)  
        torch.set_float32_matmul_precision("high")  

        # 创建数据集和数据加载器实例  
        dataset = FakeDataset()  
        dataloader = DataLoader(  
            dataset,  
            batch_size=BATCH_SIZE,  
            collate_fn=data_collate_fn,  
            num_workers=12,  
            pin_memory=True,  
            drop_last=True  
        )  

        model = MyDecoder(  
            block_fn=block_fn,  
            num_tokens=NUM_TOKENS,  
            dim=DIM,  
            num_heads=NUM_HEADS,  
            num_layers=DEPTH,  
            max_seq_len=MAX_SEQ_LEN,  
            pad_idx=pad_idx  
        ).to(device)  

        if compile:  
            model = torch.compile(model)  

        # 定义损失函数和优化器  
        criterion = torch.nn.CrossEntropyLoss(ignore_index=PAD_ID)  
        optimizer = torch.optim.SGD(model.parameters())  

        def train_step(model, inputs, targets,   
                       position_ids=None, attn_mask=None):  
            with torch.amp.autocast(DEVICE, dtype=torch.bfloat16):  
                outputs = model(inputs, position_ids, attn_mask)  
                outputs = outputs.view(-1, NUM_TOKENS)  
                targets = targets.flatten()  
                loss = criterion(outputs, targets)  
            optimizer.zero_grad(set_to_none=True)  
            loss.backward()  
            optimizer.step()  

        @torch.no_grad()  
        def eval_step(model, inputs, targets,   
                      position_ids=None, attn_mask=None):  
            with torch.amp.autocast(DEVICE, dtype=torch.bfloat16):  
                outputs = model(inputs, position_ids, attn_mask)  
                if outputs.is_nested:  
                    outputs = outputs.data._values  
                    targets = targets.data._values  
                else:  
                    outputs = outputs.view(-1, NUM_TOKENS)  
                    targets = targets.flatten()  
                loss = criterion(outputs, targets)  
            return loss  

        if train:  
            model.train()  
            step_fn = train_step  
        else:  
            model.eval()  
            step_fn = eval_step  

        t0 = time.perf_counter()  
        summ = 0  
        count = 0  

        for step, data in enumerate(dataloader):  
            # 将数据复制到设备  
            data = data_to_device(data, device=device)  
            step_fn(model, data['inputs'], data['targets'],  
                           position_ids=data.get('indices'),  
                           attn_mask=data.get('attn_mask'))  

            # 记录步时间  
            batch_time = time.perf_counter() - t0  
            if step > 20 and step < 100:  # 跳过最初的几步  
                summ += batch_time  
                count += 1  
            t0 = time.perf_counter()  
            if step >= 100:  
                break  
        print(f'每步的平均时间: {summ / count}')
PyTorch 带填充的自定义注意力机制(SDPA)

为了进行基线实验,我们将Transformer模块配置为使用PyTorch的SDPA机制。在实验中,我们分别进行了训练和评估,既包括启用了torch.compile的情况,也包括未启用的情况。这些实验是在装备了CUDA 12.4PyTorch 2.5.1的NVIDIA H100上进行的。

    from torch.nn.functional import scaled_dot_product_attention as sdpa  
    block_fn = functools.partial(MyAttentionBlock, attn_fn=sdpa)  
    causal_block_fn = functools.partial(  
        MyAttentionBlock,  
        attn_fn=functools.partial(sdpa, is_causal=True)  
    )  

    for mode in ['eval', 'train']:  
        for compile in [False, True]:  
            block_func = causal_block_fn if mode == 'train' else block_fn  
            print(f'{mode} with {collate}, ',  
                  f'{"编译过" if compile else "未编译"}')  
            main(block_fn=block_func,  
                 pad_idx=PAD_ID,  
                 train=mode=='train',  
                 compile=compile)

性能表现:

  • 评估时间:未启用 torch.compile 时为 132 毫秒,启用 torch.compile 时为 130 毫秒,
  • 训练时间:未启用 torch.compile 时为 342 毫秒,启用 torch.compile 时为 299 毫秒,
针对变长输入的优化

在这节里,我们将来看看几种用于处理Transformer模型中变长输入序列的优化方法。

填充优化技术

我们第一次优化的不是注意力核,而是填充机制。我们不再把每个批次的序列填充到固定的长度,而是填充到该批次中最长序列。下面的代码块包括我们修订的整理函数(collation function)和更新的实验。

    def collate_pad_to_longest(batch):  
        padded_inputs = []  
        padded_targets = []  
        max_length = max([b[0].shape[0] for b in batch])  
        for b in batch:  
            padded_inputs.append(pad_sequence(b[0], max_length, PAD_ID))  
            padded_targets.append(pad_sequence(b[1], max_length, PAD_ID))  
        padded_inputs = torch.stack(padded_inputs, dim=0)  
        padded_targets = torch.stack(padded_targets, dim=0)  
        return {  
            'inputs': padded_inputs,  
            'targets': padded_targets  
        }  

    for mode in ['eval', 'train']:  
        for compile in [False, True]:  
            # 如果是 'train' 模式,则使用 causal_block_fn,否则使用 block_fn
            block_func = causal_block_fn if mode == 'train' else block_fn  
            print(f'{mode} 使用 {collate},{"编译" if compile else "未编译"}')  
            main(block_fn=block_func,  
                 data_collate_fn=collate_pad_to_longest,  
                 pad_idx=PAD_ID,  
                 train=mode=='train',  
                 compile=compile)

将每个批次的序列填充到最长长度,可以稍微提升性能。

  • 评估时间:未使用 torch.compile 时为 129 毫秒,使用 torch.compile 时减少到 116 毫秒
  • 训练时间:未使用 torch.compile 时为 337 毫秒,使用 torch.compile 时减少到 294 毫秒
SDPA (PyTorch NestedTensors) 的使用

接下来,我们利用SDPA在评估模式下内置的PyTorch NestedTensors功能。目前,PyTorch NestedTensors是一个原型功能,它允许将长度各异的张量组合在一起。这些张量有时也被称为 jaggedragged 张量。在下面的代码块中,我们定义了一个批处理函数,用于将我们的序列整理成NestedTensors。我们还定义了一个 indices 条目,以便能够正确计算位置嵌入。

PyTorch 的 NestedTensors 只被 有限数量的 PyTorch 运算 支持。绕过这些限制可能需要一些创造力。例如,只有当它们具有完全相同的“锯齿状”形状时,NestedTensors 之间的加法才被支持。在下面的代码中,我们使用了一种变通方法来确保 indices 的形状与模型输入的形状一致。

    def nested_tensor_collate(batch):  
        inputs = torch.nested.as_nested_tensor([b[0] for b in batch],  
                                               layout=torch.jagged)  
        targets = torch.nested.as_nested_tensor([b[1] for b in batch],  
                                                layout=torch.jagged)  
        indices = torch.concat([torch.arange(b[0].shape[0]) for b in batch])  

        # 临时解决方法以创建具有相同“锯齿形”形状的嵌套张量  
        xx = torch.empty_like(inputs)  
        xx.data._values[:] = indices  

        return {  
            'inputs': inputs,  
            'targets': targets,  
            'indices': xx  
        }  

    for compile in [False, True]:  
        print(f'使用嵌套张量进行评估,是否编译为{“编译” if compile else “未编译”}')  
        main(  
            block_fn=block_fn,  
            data_collate_fn=nested_tensor_collate,  
            train=False,  
            compile=compile  
        )

尽管使用了 torch.compile,NestedTensor 优化后的每步时间为 131 毫秒,与基准相似,但在编译模式下,每步时间减少到 42 毫秒,实现了显著的约 3 倍提升。

FlashAttention2

注:FlashAttention2 是一个技术术语,具体含义请参见相关文献或上下文。 (Flash注意力2)

在我们之前的文章里,我们展示了FlashAttention及其对变压器模型性能的影响。在本文中,本文将展示如何使用flash_attn_varlen_func,这是一个专为处理不同长度输入设计的API。为了使用此函数,我们将批次中的所有序列合并成一个单一序列,并创建一个名为_cu_seqlens_的张量,该张量指向拼接张量中每个单独序列的起始索引。下面的代码块包括我们的合并函数,随后是评估和训练实验。需要注意的是,flash_attn_varlen_func目前暂不支持torch.compile(截至撰写本文时)。

    def collate_concat(batch):  
        inputs = torch.concat([b[0] for b in batch]).unsqueeze(0)  
        目标 = torch.concat([b[1] for b in batch]).unsqueeze(0)  
        索引 = torch.concat([torch.arange(b[0].shape[0]) for b in batch])  
        序列长度 = torch.tensor([b[0].shape[0] for b in batch])  
        序列长度 = torch.cumsum(序列长度, dim=0, dtype=torch.int32)  
        cu_seqlens = torch.nn.functional.pad(序列长度, (1, 0))  

        return {  
            'inputs': inputs,  
            'targets': 目标,  
            'indices': 索引,  
            'attn_mask': cu_seqlens  
        }  

    from flash_attn import flash_attn_varlen_func  
    fa_varlen = lambda q, k, v, attn_mask: flash_attn_varlen_func(  
        q.squeeze(0),  
        k.squeeze(0),  
        v.squeeze(0),  
        cu_seqlens_q=attn_mask,  
        cu_seqlens_k=attn_mask,  
        max_seqlen_q=MAX_SEQ_LEN,  
        max_seqlen_k=MAX_SEQ_LEN  
    ).unsqueeze(0)  

    fa_varlen_causal = lambda q, k, v, attn_mask: flash_attn_varlen_func(  
        q.squeeze(0),  
        k.squeeze(0),  
        v.squeeze(0),  
        cu_seqlens_q=attn_mask,  
        cu_seqlens_k=attn_mask,  
        max_seqlen_q=MAX_SEQ_LEN,  
        max_seqlen_k=MAX_SEQ_LEN,  
        causal=True  
    ).unsqueeze(0)  

    block_fn = functools.partial(MyAttentionBlock,  
                                 attn_fn=fa_varlen,  
                                 format='bshd')  

    causal_block_fn = functools.partial(MyAttentionBlock,  
                                        attn_fn=fa_varlen_causal,  
                                        format='bshd')  

    print('闪光注意力评估')  
    main(  
        block_fn=block_fn,  
        data_collate_fn=collate_concat,  
        train=False  
    )  

    print('闪光注意力训练')  
    main(  
        block_fn=causal_block_fn,  
        data_collate_fn=collate_concat,  
        train=True,  
    )

这种优化效果惊人,评估时间仅为51毫秒,训练时间仅为160毫秒,相比我们的基线实验,性能分别提高了2.6倍和2.1倍。

XFormers 内存高效注意力

在我们之前的帖子中,我们展示了如何使用来自 xFormers (0.0.28)memory_efficient_attention 操作符。在这里,我们将演示如何使用该操作符 BlockDiagonalMask,它是专门用于处理任意长度输入序列的。所需的整理函数(collation函数)就出现在下面的代码块中,下面就展示了该函数以及随后的评估和训练实验。需要注意的是,在训练模式下,torch.compile 会失败。

from xformers.ops import fmha  
from xformers.ops import memory_efficient_attention as mea  

def collate_xformer(batch):  
    inputs = torch.concat([b[0] for b in batch]).unsqueeze(0)  
    targets = torch.concat([b[1] for b in batch]).unsqueeze(0)  
    indices = torch.concat([torch.arange(b[0].shape[0]) for b in batch])  
    seqlens = [b[0].shape[0] for b in batch]  
    batch_sizes = [1 for b in batch]  
    block_diag = fmha.BlockDiagonalMask.from_seqlens(seqlens, device='cpu')  
    block_diag._batch_sizes = batch_sizes  

    return {  
        'inputs': inputs,  
        'targets': targets,  
        'indices': indices,  
        'attn_mask': block_diag  
    }  

mea_eval = lambda q, k, v, attn_mask: mea(  
    q, k, v, attn_bias=attn_mask)  

mea_train = lambda q, k, v, attn_mask: mea(  
    q, k, v, attn_bias=attn_mask.make_causal())  

block_fn = functools.partial(MyAttentionBlock,  
                             attn_fn=mea_eval,  
                             format='bshd')  

causal_block_fn = functools.partial(MyAttentionBlock,  
                             attn_fn=mea_train,  
                             format='bshd')  

print(f'xFormer 注意力机制 ')  
for compile in [False, True]:  
    print(f'使用 xFormer 注意力机制进行评估,{"已编译" if compile else "未编译过"}')  
    main(block_fn=block_fn,  
         train=False,  
         data_collate_fn=collate_xformer,  
         compile=compile)  

print(f'使用 xFormer 注意力机制进行训练')  
main(block_fn=causal_block_fn,  
     train=True,  
     data_collate_fn=collate_xformer)

在未使用torch.compile的情况下,评估的步时间为50毫秒,训练的步时间为159毫秒;使用torch.compile进行评估时,步时间减少到42毫秒。

看来看我们成果

以下是我们优化方法的总结。

每步耗时结果的不同优化算法(数值越低越好) — (作者提供)

我们简化模型的表现最好的是xFormer的memory_efficient_attention,它在评估中的性能提升了大约3倍,在训练中的性能提升了大约2倍。但请注意,不要仅凭这些结果得出结论,因为不同注意力函数的性能差异会根据具体模型和应用场景有很大变化,

优化:HuggingFace模型以适应不同的输入长度

上述描述的工具和技术在从零开始创建模型时很容易实现。然而,如今,ML 开发者通常会采用现有的(预训练)模型,并针对其特定用例进行微调,这种情况相当常见。尽管我们所描述的优化可以在不改变模型权重集和模型行为的情况下进行集成,但要以最佳方式做到这一点并不完全清晰。在一个理想的世界里,我们的 ML 框架应该能够让我们编程实现一种针对可变长度输入进行了优化的注意力机制。本节将演示如何为可变长度输入优化 HuggingFace 模型。

一个玩的模型:GPT2LMHeadModel

为了方便讨论,我们创建一个示例,训练HuggingFace的GPT2LMHead模型,用于处理可变长度的序列。这需要根据HuggingFace的输入要求调整我们的随机生成的数据集和填充数据的方式。

    从transformers导入GPT2Config, GPT2LMHeadModel,  

    # 使用随机生成的数据  
    class HuggingFaceFakeDataset(Dataset):  
        def __len__(self):  
            return 1000000  

        def __getitem__(self, index):  
            length = torch.randint(1, MAX_SEQ_LEN, (1,))  
            input_ids = torch.randint(1, NUM_TOKENS, (length,))  
            labels = input_ids.clone()  
            labels[0] = PAD_ID # 忽略第一个标记  
            return {  
                'input_ids': input_ids,  
                'labels': labels  
            }  

    def hf_collate_with_padding(batch):  
        padded_inputs = []  
        padded_labels = []  
        for b in batch:  
            input_ids = b['input_ids']  
            labels = b['labels']  
            padded_inputs.append(pad_sequence(input_ids, MAX_SEQ_LEN, PAD_ID))  
            padded_labels.append(pad_sequence(labels, MAX_SEQ_LEN, PAD_ID))  
        padded_inputs = torch.stack(padded_inputs, dim=0)  
        padded_labels = torch.stack(padded_labels, dim=0)  
        return {  
            'input_ids': padded_inputs,  
            'labels': padded_labels,  
            'attention_mask': (padded_inputs != PAD_ID)  
        }
训练功能

我们的训练功能根据请求的[GPT2Config]配置来实例化一个[GPT2LMHeadModel],然后在可变长度的序列上训练该模型。

    def hf_main(  
        config,  
        collate_fn=hf_collate_with_padding,  
        compile=False  
    ):  
        torch.random.manual_seed(0)  
        device = torch.device(DEVICE)  
        torch.set_float32_matmul_precision("high")  

        # 创建数据集和数据加载器  
        data_set = HuggingFaceFakeDataset()  
        data_loader = DataLoader(  
            data_set,  
            batch_size=BATCH_SIZE,  
            collate_fn=collate_fn,  
            num_workers=12 if DEVICE == "CUDA" else 0,  
            pin_memory=True,  
            drop_last=True  
        )  

        model = GPT2LMHeadModel(config).to(device)  

        if compile:  
            model = torch.compile(model)  

        # 定义损失函数和优化器  
        criterion = torch.nn.CrossEntropyLoss(ignore_index=PAD_ID)  
        optimizer = torch.optim.SGD(model.parameters())  

        model.train()  

        t0 = time.perf_counter()  
        summ = 0  
        count = 0  

        for step, data in enumerate(data_loader):  
            # 将数据复制到GPU  
            data = data_to_device(data, device=device)  
            input_ids = data['input_ids']  
            labels = data['labels']  
            position_ids = data.get('position_ids')  
            attn_mask = data.get('attention_mask')  
            with torch.amp.autocast(DEVICE, dtype=torch.bfloat16):  
                outputs = model(input_ids=input_ids,  
                                position_ids=position_ids,  
                                attention_mask=attn_mask)  
                logits = outputs.logits[..., :-1, :].contiguous()  
                labels = labels[..., 1:].contiguous()  
                loss = criterion(logits.view(-1, NUM_TOKENS), labels.flatten())  

            optimizer.zero_grad(set_to_none=True)  
            loss.backward()  
            optimizer.step()  

            # 捕获步骤时间  
            batch_time = time.perf_counter() - t0  
            if step > 20:  # 跳过前几步  
                summ += batch_time  
                count += 1  
            t0 = time.perf_counter()  
            if step >= 100:  
                break  
        print(f'每步平均时间: {summ / count}')
填充版SDPA方法

在下面的回调函数中,我们使用默认的序列填充方式来调用训练函数。

    config = GPT2Config(  
            n_layer=DEPTH,  
            n_embd=DIM,  
            n_head=NUM_HEADS,  
            vocab_size=NUM_TOKENS,  
        )  

    for compile in [False, True]:  
        print(f"HF GPT2 使用自定义注意力机制进行训练,compile={compile}")  
        hf_main(config=config, compile=compile)

没有使用 /torch.compile/ 时,步骤时间结果为 815 毫秒;使用 /torch.compile/ 时,步骤时间则缩短至 440 毫秒。

FlashAttention2

我们现在利用HuggingFace内置的FlashAttention2支持,通过将_attn_implementation_参数设置为“flash_attention_2”。在背后,HuggingFace将填充过的数据unpad,然后传递给我们之前提到的优化后的flash_attn_varlen_func函数。

    flash_config = GPT2Config(  
            n_layer=DEPTH,  
            n_embd=DIM,  
            n_head=NUM_HEADS,  
            vocab_size=NUM_TOKENS,  
            attn_implementation='flash_attention_2'  
        )  

    print("使用 Flash 训练 HF GPT2 模型")  
    hf_main(config=flash_config)

时间步长为620毫秒,这相当于提升了30%(在未编译模式下),只需切换一下开关。

FlashAttention2(未填充输入)

当然,仅仅整理函数中填充的序列,然后去除填充,似乎并不合理。在最近的HuggingFace更新中,增加了支持将未填充的拼接序列传递给某些模型的功能。遗憾的是(截至撰写本文时),我们的GPT2模型并未包含在内。但是,只需在modeling_gpt2.py文件中添加五行代码即可支持此功能。以便将序列[_positionids]传递到[flash-attention内核]中。完整的代码补丁如下所示:

    @@ -370,0 +371 @@
    +        position_ids = None
    @@ -444,0 +446 @@
    +            position_ids=position_ids
    @@ -611,0 +614 @@
    +        position_ids=None
    @@ -621,0 +625 @@
    +            position_ids=position_ids
    @@ -1140,0 +1145 @@
    +                    position_ids=position_ids

我们定义了一个拼接函数来合并序列,并在未填充的序列上训练该hugging face模型。(请参见内置的DataCollatorWithFlattening工具。)

    def collate_flatten(batch):  
        input_ids = torch.concat([b['input_ids'] for b in batch]).unsqueeze(0)  
        labels = torch.concat([b['labels'] for b in batch]).unsqueeze(0)  
        position_ids = [torch.arange(b['input_ids'].shape[0]) for b in batch]  
        position_ids = torch.concat(position_ids)  

        return {  
            'input_ids': input_ids,  
            'labels': labels,  
            'position_ids': position_ids  
        }  

    print(f"HF GPT2 使用 flash 进行训练,不进行 padding 操作")  
    hf_main(config=flash_config, collate_fn=collate_flatten)

得到的步时是323毫秒,比在填充输入上运行flash-attention快90%的速度。

结果如下

我们的HuggingFace实验结果如下所示。

作者提供的不同优化方法的每一步所需时间的结果(越低越好)

几乎不费吹灰之力,我们就将运行时性能提高了2.5倍,相比未编译的基线实验;与编译版本相比,性能提高了36%。

本节展示了如何通过HuggingFace的API来利用FlashAttention2的优化内核,从而显著提高现有模型在不同长度序列上的训练效率。

总结

随着AI模型越来越受欢迎和复杂,优化它们的性能已成为减少运行时间和成本的关键。特别是对于计算密集型组件,如注意力层,这一点尤其重要。在这次讨论中,我们继续探讨了如何优化注意力层,并介绍了提高Transformer模型性能的新工具和技术。如需了解更多关于AI模型优化的信息,请查看本系列的第一篇文章以及我们关于这个主题的其他文章。

点击查看更多内容
TA 点赞

若觉得本文不错,就分享一下吧!

评论

作者其他优质文章

正在加载中
  • 推荐
  • 评论
  • 收藏
  • 共同学习,写下你的评论
感谢您的支持,我会继续努力的~
扫码打赏,你说多少就多少
赞赏金额会直接到老师账户
支付方式
打开微信扫一扫,即可进行扫码打赏哦
今天注册有机会得

100积分直接送

付费专栏免费学

大额优惠券免费领

立即参与 放弃机会
意见反馈 帮助中心 APP下载
官方微信

举报

0/150
提交
取消