为了账号安全,请及时绑定邮箱和手机立即绑定

将小型语言模型转换为支持KV缓存的ONNX格式:从PyTorch到ONNX的半精度优化之旅

由 Francesc Gispert 与 Eric Soriano 共同撰写

在Esperanto Technologies,我们对使用ONNX格式运行生成式AI模型很感兴趣,原因有三个:

1. ONNX 可以与大多数常见格式进行双向转换,是一种互操作性格式。

2. ONNX 由 ONNXRuntime 框架支持,该框架是开源的,也是今天的主要框架之一。

3. ONNX 格式提供了模型的静态模型视图,这使得Esperanto的机器学习编译器更容易加载和优化模型,从而实现更快的推理。

一个 ONNX 文件存储了一个隐藏的操作图。图中的节点和边通过操作和张量的名字来标记。可以通过 Netron. 查看具体的名字。

图1:ONNX互操作性

在这篇文章中,我们将回顾我们如何将 PyTorch 小型语言模型(基础或微调)转换为支持 fp16 精度的完全功能的 ONNX 模型,并支持键值缓存(KV-cache)功能。

从PyTorch到ONNX的转换

首先,我们假设有一个小型语言模型(SLM:基于Transformer,参数少于10亿),它采用PyTorch的torch格式,这是最流行的ML格式。我们的第一步是使用torch.onnx.export这个命令,将我们的torch模型转换为初步的ONNX模型。为此,我们需要向模型提供输入,以便我们可以追踪构成网络的各种操作。

不启用KV缓存的SLM只需要两个独立的输入:输入文本转换后的输入ID(即标记列表)和注意力掩码(一个掩码,实际被消费的标记处为1,其余为0)。文本生成只需要一个输出,即“logits”。这个输出为输入中的每个标记提供了一个分数,表示预测下一个位置标记的概率分布。在常规用例中,我们会忽略这些列表中的所有内容,仅使用最后一个列表来预测一个新标记。在提供这些输入时,我们还可以定义“onnx_symbols”或动态形状参数(例如批次或序列长度),这些是占位符,可以根据推理的需要进行调整。这些特别有助于从同一个模型中获取不同配置(例如不同的序列长度和批次等)。截至2024年9月,我们只能保证操作集版本在14及以下的模型将得到完全支持,因为某些节点的定义在不同操作集版本之间可能有显著差异。

所以我们最后得到了一个类似的命令:

    dummy_input = torch.ones((1, 128), dtype=torch.int64)  
    symbolic_names = {0: "batch", 1: "sequence"}  
    torch.onnx.export(  
            model,     # 注: torch 模型,  
            (dummy_input, dummy_input),  # 注: 输入  
            "model.onnx",   # 注: 保存的 ONNX 模型路径  
            input_names = ["input_ids", "attention_mask"],  
            output_names = ["logits"],  
            dynamic_axes = {"input_ids": symbolic_names, "attention_mask": symbolic_names, "logits": symbolic_names},  # 动态轴  
            opset_version = 14  # ONNX 操作集版本  
    )

一旦我们通过这条命令得到了初步模型,我们可能想要微调它。

  • 将权重以外部文件的形式保存。
  • 如果PyTorch模型尚未转换为fp16精度,则将其转换为fp16。注意,即使在转换为fp16时,我们也希望LayerNormalization继续使用fp32精度以确保计算的高动态范围。
  • 更改输出的精度:如果PyTorch模型使用了fp16,logits可能仍然会被转换为fp32。这可以通过在ONNX的最后一步移除转换成fp32的Cast操作来轻松解决。

一旦这些步骤都完成了,我们就可以轻松地利用ONNXRuntime进行推理,并验证这个ONNX模型是否功能完整且没有KV缓存。

给序列预测模型(SLMs)添加KV缓存

所考虑的语言模型都使用了仅解码器的Transformer架构。特别是,这些模型除了在自注意力层外,在其他地方都并行处理代表不同文本标记的嵌入向量。由于模型的自回归特性(即新标记的预测仅基于之前的标记),可以在自注意力层存储中间结果,从而减少推理复杂度。这些缓存的张量通常称为键和值。

我们通过滑动窗口方法来实现KV缓存。也就是说,模型每次推理时只处理一小段令牌,随着新令牌的生成,这段令牌的范围会向前移动。除此之外,在自注意力层,模型会从KV缓存中获取表示过去令牌的嵌入向量,而其他所有内部张量则会相应地被裁剪。

因此,序列长度被换成几个数量级。

  1. 窗口大小是指单次推理中处理的令牌数量。
  2. 序列长度是指每次推理时模型所能考虑的令牌总数。这是过去缓存中的令牌加上当前窗口令牌的总长度,序列开头可能还有填充令牌。
  3. 最长上下文长度是指缓存可以存储的最大令牌数量。如果需要在不更改缓存数据的情况下增加序列长度,这一点会很有用。

图2:KV缓存的简化图示。灰色部分为填充区,黄色部分为推断填充的窗口

上面的图展示了KV缓存张量的形状。实际上,这幅图进行了简化,因为每个token索引对应的线实际上被分割成了多段。新模型计算黄色部分,对应于当前窗口。为了完整地描述这个场景,我们增加了一个输入张量来表示过去的白色和灰色部分,一个表示整个缓存的输出张量,一个用于将所有部分连接的运算,以及另一个运算用于提取白色和黄色部分(即去除灰色部分表示的填充)。埃斯佩朗托技术公司的机器学习编译器将KV缓存张量保持为设备驻留张量,从而避免了不必要的数据移动和复制。

图3:添加的操作以形成键/值张量

input_ids 张量中表示序列长度的维度替换为窗口大小。通过 ONNX 的形状推断功能,这一变化会影响到模型的其他部分。

图4:应用于输入令牌张量的变换图

关于注意力掩码,有两个张量需要考虑。模型的自注意力块应用的掩码最初是一个维度与序列长度相等的方形矩阵。在缓存版本的模型中,仅使用该矩阵的最后几行(行数与窗口大小相同)。在最简单的使用场景中,自回归语言模型总是提供一个三角形的因果掩码,因为它编码了每个标记的预测依赖于所有之前的标记。实际上,输入注意力掩码张量是一个向量,模型内部将其扩展为一个矩阵。该向量本身保持不变,但扩展它的模型部分需要修改,以适应窗口大小与序列长度的比例。

图5:应用于注意力掩码的变更

旋转向量嵌入

一些SLM包含另一个在我们的实现中需要进行一些仔细调整的部分:旋转位置嵌入。这是一种应用于自注意力层的查询和键张量的转换,在通过点积进行比较之前执行。具体来说,标记从0开始索引,对应于第k个标记的嵌入向量会旋转角度kθ(对于查询)和-kθ(对于键)。当在自注意力机制中组合查询和键时,这些旋转部分相互抵消。总的来说,这些操作有助于捕捉序列中两个标记之间的相对位置。

窗口内的标记不能在每次推断时从0开始索引,因为窗口随着推断过程滑动。相反,我们修改了模型,使其将窗口视为序列的最后部分(即,其第一个索引是序列长度减窗口大小)。然后,在应用旋转嵌入变换之前缓存键张量至关重要。这样,即使给定标记的绝对索引在每次推断时发生变化,索引之间的相对差异得以保持,模型也能正常工作。

结尾

本节中描述的所有工作都是通用的,适用于任何基于解码器架构的SLM。在Esperanto,我们提倡使用优化的ONNX模型,并为此致力于尽可能地开源SLM,使用不同的精度,并扩展以支持不同的优化技术。请注意,这些模型可以用于任何ONNX提供者,但它们特别被定制以利用Esperanto产品的优势,并实现更快的推理速度。

您可以在我们的HuggingFace页面上找到Esperanto转换为fp16精度ONNX格式的基础SLM模型。

点击查看更多内容
TA 点赞

若觉得本文不错,就分享一下吧!

评论

作者其他优质文章

正在加载中
  • 推荐
  • 评论
  • 收藏
  • 共同学习,写下你的评论
感谢您的支持,我会继续努力的~
扫码打赏,你说多少就多少
赞赏金额会直接到老师账户
支付方式
打开微信扫一扫,即可进行扫码打赏哦
今天注册有机会得

100积分直接送

付费专栏免费学

大额优惠券免费领

立即参与 放弃机会
意见反馈 帮助中心 APP下载
官方微信

举报

0/150
提交
取消