为了账号安全,请及时绑定邮箱和手机立即绑定

大型语言模型中的缓存增强生成(CAG):一步一步教你如何实现

检索增强生成(RAG)技术 是一种强大的方法,它可以将外部知识库与大语言模型连接起来,并在用户提问时检索相关上下文,但这种方法也会因为检索延迟而减慢大语言模型的响应速度。

缓存增强生成(CAG) 提供了一种更快的替代方案;它不是进行实时检索,而是将相关文档提前加载到模型的上下文中,并存储这种推理状态,也就是所谓的键值(KV)缓存。这种方法则消除了检索延迟,使模型能即时访问已预加载的信息,从而实现更快、更高效的响应。

了解更多关于 CAG 的技术解释,可以看看这篇文章

在这份教程中,我们将展示如何构建一个简单的CAG环境,预先嵌入所有知识,快速回答用户的多条询问,并确保每次重置缓存时无需重新加载整个上下文。

先决条件
  1. 一个 HuggingFace 账户和一个访问令牌

2. 一个关于你的句子的document.txt文件。

项目启动

我们引入一些必要的库,如下:

  • torch 用于 PyTorch。
  • transformers 用于 Hugging Face。
  • DynamicCache 用于存储模型的键值状态。
    import torch  # 导入PyTorch库  
    from transformers import AutoTokenizer, AutoModelForCausalLM  # 导入transformers库中的AutoTokenizer和AutoModelForCausalLM模型  
    from transformers.cache_utils import DynamicCache  # 导入DynamicCache类  
    import os  # 导入操作系统库
生成函数(Generate Function)

我们接下来定义 generate 函数。

函数generate使用贪婪解码来处理逐词生成,使用缓存的知识。

贪婪解码法是一种简单的文本生成方法,在这种方法中,每一步中,选择概率最高的词元(即 logits 中概率最大的那个)作为下一个词。

我们输入如下这些参数:

  • model: 我们在这个教程中使用的大型语言模型(LLM)是 Mistral-7B。
  • input_ids: 包含分词输入序列的一个张量。
  • past_key_values: CAG 的核心组件。通过缓存之前计算的注意力值来加速推理过程,避免重复计算。
  • max_new_tokens: 要生成的新标记的最大数量,默认为 50。

该函数在一个循环中运行,最多迭代 max_new_tokens 次或在生成序列结束标记(如果存在的话)时提前终止。

每次迭代:

  • 模型处理当前输入令牌,并结合缓存的 past_key_values,生成下一个令牌的 logits。
  • 通过贪心解码来找出概率最高的令牌。
  • 将此新令牌添加到输出序列中,并将缓存 (past_key_values) 更新为包含当前上下文信息。
  • 新生成的令牌作为下一次迭代的输入。
    def generate(model, input_ids: torch.Tensor, past_key_values, max_new_tokens: int = 50) -> torch.Tensor:  
        device = model.model.embed_tokens.weight.device  
        origin_len = input_ids.shape[-1]  
        input_ids = input_ids.to(device)  
        output_ids = input_ids.clone()  
        next_token = input_ids  

        with torch.no_grad():  
            for _ in range(max_new_tokens):  
                out = model(  
                    input_ids=next_token,  
                    past_key_values=past_key_values,  
                    use_cache=True  
                )  
                logits = out.logits[:, -1, :]  
                token = torch.argmax(logits, dim=-1, keepdim=True)  
                output_ids = torch.cat([output_ids, token], dim=-1)  
                past_key_values = out.past_key_values  
                next_token = token.to(device)  

                if model.config.eos_token_id is not None and token.item() == model.config.eos_token_id:  
                    break  
        return output_ids[:, origin_len:]
DynamicCache 的设置

接下来,我们将定义 get_kv_cache 函数,该函数为变压器模型的注意力机制准备一个可以重复利用的键值缓存,以及 clean_up 函数,该函数通过移除不必要的条目来清理缓存,以确保你可以回答多个独立问题而不会“污染(pollution)”缓存。

get_kv_cache 将提示信息(例如来自 document.txt 的知识)通过模型一次,创建一个 KV 缓存(Key-Value 缓存),记录每一层的隐藏状态。

get_kv_cache 通常传入以下参数:

  • model: 用于编码提示的Transformer模型。
  • tokenizer:将提示转换为token ID的分词器。
  • prompt:用作提示的字符串。

并返回一个 DynamicCache 类型的对象。

get_kv_cache函数首先使用分词器对提供的提示进行分词处理,然后将分词结果转换为输入ID,接着初始化一个DynamicCache对象来存储键值对,然后启用缓存(use_cache=True)并通过模型执行前向传递。这会将模型计算出的键值对填充到缓存。

clean_up 会修剪 DynamicCache 对象,使其长度与原始序列长度一致,移除处理过程中添加的任何额外令牌。对于缓存的每一层,它会对键和值张量进行切片,只保留序列维度上的前 origin_len 个令牌。

    def get_kv_cache(model, tokenizer, prompt: str) -> DynamicCache:  # 获取键值缓存
        device = model.model.embed_tokens.weight.device
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
        cache = DynamicCache()

        with torch.no_grad():  # 在不计算梯度的情况下执行
            _ = model(
                input_ids=input_ids,
                past_key_values=cache,
                use_cache=True
            )
        return cache  # 返回缓存

    def clean_up(cache: DynamicCache, origin_len: int):  # 清理缓存,保留原始长度的数据
        for i in range(len(cache.key_cache)):
            cache.key_cache[i] = cache.key_cache[i][:, :, :origin_len, :]
            cache.value_cache[i] = cache.value_cache[i][:, :, :origin_len, :]
加载大模型(Mistral)

现在我们将加载Mistral-7B模型,并在可用的GPU上以全精度或半精度(FP16)模式加载分词器和模型。

记得输入你独有的HuggingFace Token YOUR_HF_TOKEN

    model_name = "mistralai/Mistral-7B-Instruct-v0.1"  
    tokenizer = AutoTokenizer.from_pretrained(model_name, token="YOUR_HF_TOKEN", trust_remote_code=True)  
    model = AutoModelForCausalLM.from_pretrained(  
        model_name,  
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,  
        device_map="auto",  
        trust_remote_code=True,  
        token="YOUR_HF_TOKEN"  
    )  
    device = "cuda" if torch.cuda.is_available() else "cpu"  
    model.to(device)  
    print(f"模型 {model_name} 已加载。")
从document.txt生成一个知识摘要

接下来,我们将阅读 document.txt,你可以在这个文档中填写自己的信息。在本教程中,document.txt 包含关于我的信息(Ronan Takizawa)。

在这里,我们构建了一个包含文档信息的简单系统提示,并将其传入 get_kv_cache 以生成KV缓存数据。

    with open("document.txt", "r", encoding="utf-8") as f:  
        doc_text = f.read()  

    system_prompt = f"""  
    <|system|>  
    你是一个提供简短而事实性的答案的助手。  
    <|user|>  
    上下文:  
    {doc_text}  
    问题:  
    """.strip()  

    ronan_cache = get_kv_cache(model, tokenizer, system_prompt)  
    origin_len = ronan_cache.key_cache[0].shape[-2]  
    print("缓存构建完成。")  
利用缓存提问

我们首先运行clean_up来清空缓存(这是CAGs中的一个良好实践)。

接下来,我们将问题转化为input_ids_q1中的token,然后附加到存储在ronan_cache中的知识内容。

最后,我们通过调用 generate 函数生成答案,并用 tokenizer.decode 解码结果。

    question1 = "谁是Ronan Takizawa?"  # 谁是Ronan Takizawa?  
    clean_up(ronan_cache, origin_len)  # 清理ronan_cache和origin_len  
    input_ids_q1 = tokenizer(question1 + "\n", return_tensors="pt").input_ids.to(device)  # 将问题1和换行符转换为输入ID并放置到设备上  
    gen_ids_q1 = generate(model, input_ids_q1, ronan_cache)  # 生成模型的答案ID  
    answer1 = tokenizer.decode(gen_ids_q1[0], skip_special_tokens=True)  # 解码生成的答案ID,跳过特殊标记  
    print("问题1:", question1)  # 打印问题1  
    print("A1:", answer1)  # 打印答案1

你应该收到这样的回复:

    Q1: 谁是隆安·高桥?  
    A1: 答:一个雄心勃勃且成就卓著的科技爱好者,他在软件开发、AI/ML 等领域拥有广泛的技术能力...

完整代码在这里: 链接

别忘了给 ORIGINAL REPO 点个赞哦!⭐️
结论部分

缓存增强生成技术(CAG) 简化了AI架构的设计,通过直接将小型知识库存储在模型的上下文窗口中,从而消除了RAG中的检索循环,减少延迟。这种方法提高了响应速度和效率,并使LLM更好地利用外部知识。通过使用CAG,开发人员能够简化他们的AI系统,实现更快和更高效的知识集成,特别是在数据集稳定且紧凑的任务中。

点击查看更多内容
TA 点赞

若觉得本文不错,就分享一下吧!

评论

作者其他优质文章

正在加载中
  • 推荐
  • 评论
  • 收藏
  • 共同学习,写下你的评论
感谢您的支持,我会继续努力的~
扫码打赏,你说多少就多少
赞赏金额会直接到老师账户
支付方式
打开微信扫一扫,即可进行扫码打赏哦
今天注册有机会得

100积分直接送

付费专栏免费学

大额优惠券免费领

立即参与 放弃机会
意见反馈 帮助中心 APP下载
官方微信

举报

0/150
提交
取消