为了账号安全,请及时绑定邮箱和手机立即绑定

注意力机制:GPT等大模型的基石

1 啥是注意力?

人类观察事物,能快速判断一种事物,是因为大脑能很快把注意力放在事物最具辨识度的部分从而作出判断,而非从头到尾一览无遗观察一遍才能有判断。基于这样的观察实践,产生了注意力机制(Attention Mechanism)。

想象你在人群中找一个穿红衣服的人。你不会一一检查每个人的鞋子、裤子、头发,而是直接把目光锁定在衣服颜色,因为那是“最有辨识度的特征”。大脑就是这么高效工作的。注意力机制是模仿这个过程,让机器也能快速抓住“重点”,而非傻乎乎扫描所有信息。

Java 类比: 假设你有一个 List<Person>,里面存了一堆人的信息(身高、体重、衣服颜色等)。你想找红衣服的人,传统方法是遍历整个列表,检查每个属性。但如果用“注意力”,就像加了个过滤器 focusOnColor(“red”),只关注衣服颜色,忽略其他无关信息,效率一下子就上来了。

2 啥是注意力计算规则

需三个指定的输入Q(query)、K(key)、V(value),再通过计算公式得到注意力的结果,这个结果代表query在key和value作用下的注意力表示。 当输入的Q=K=V时,称作自注意力计算规则。

  • Q(Query):你要找啥,就像搜索框里输入的关键词
  • K(Key):一堆可以匹配的东西,就像数据库里的索引
  • V(Value):实际的内容,就像数据库里对应的详细记录

通过公式计算,得出 Q 对哪些 K 感兴趣(注意力权重),再根据权重从 V 里提取信息。结果就是 Q 的“注意力表示”,也就是它最终关注到的东西。

若 Q、K、V 是同一个东西(比如同一个句子),这就是“自注意力”(Self-Attention),相当于自己跟自己比对,找内部关系。

Javaer 想象一个搜索功能:

class Attention {
    String query; // Q:搜索词,比如 "红衣服"
    List<String> keys; // K:索引列表,比如 ["衣服颜色", "身高", "体重"]
    List<String> values; // V:内容列表,比如 ["红色", "180cm", "70kg"]
    
    // 计算注意力:Q 跟每个 K 比对,得出权重,再从 V 里拿结果
    String computeAttention() {
        // 伪代码:比对 query 和 keys,算权重,提取 values
        return "红色"; // 输出 Q 关注的重点
    }
}

自注意力就像 query = keys = values,自己跟自己玩。

2.1 常见的注意力计算规则

将Q,K进行纵轴拼接,做一次线性变化,再用softmax处理获得结果最后与V做张量乘法。
Attention(Q,K,V)=Softmax(Linear([Q,K]))⋅V Attention(Q, K, V) = Softmax(Linear([Q , K])) \cdot V Attention(Q,K,V)=Softmax(Linear([Q,K]))V
把 Q 和 K 拼在一起(像两个字符串拼接),然后通过一个“公式”(线性变化)加工,再用 softmax(一种归一化方法,把结果变成概率分布,如 [0.7, 0.2, 0.1]),最后拿这个概率去“加权”提取 V 的内容。

Java 类比:

String q = "红衣服";
String k = "衣服颜色";
String combined = q + k; // 拼接
double[] weights = softmax(linearTransform(combined)); // 加工成概率
String v = "红色";
String result = weightedSum(weights, v); // 加权提取

将Q,K进行纵轴拼接, 做一次线性变化后再使用tanh函数激活, 然后再进行内部求和, 最后使用softmax处理获得结果再与V做张量乘法。
Attention(Q,K,V)=Softmax(∑(tanh(Linear(Q,K))))⋅V Attention(Q, K, V) = Softmax(\sum(tanh(Linear(Q, K)))) \cdot V Attention(Q,K,V)=Softmax((tanh(Linear(Q,K))))V
跟上面差不多,但多了几步“调味”:tanh 是另一种数学函数,把结果限制在 [-1, 1],求和是把中间结果汇总一下,最后还是用 softmax 弄成概率,再去提取 V。

Java 类比:

double[] temp = tanh(linearTransform(q + k)); // 加点“tanh”调味
double sum = sum(temp); // 汇总
double[] weights = softmax(sum);
String result = weightedSum(weights, v);

将Q与K的转置做点积运算, 然后除以一个缩放系数, 再使用softmax处理获得结果最后与V做张量乘法。
Attention(Q,K,V)=Softmax(Q⋅KTdk)⋅V Attention(Q, K, V) = Softmax(\frac{Q \cdot K^T}{\sqrt{d_k}}) \cdot V Attention(Q,K,V)=Softmax(dkQKT)V
这次不拼接了,直接拿 Q 和 K 做“点积”(一种向量乘法,衡量相似度),怕数字太大就除以一个缩放系数(如除以 8),然后 softmax 变概率,再提取 V。这是最常用的方法,尤其在 Transformer 模型。

Java 类比:

double similarity = dotProduct(q, k); // 点积算相似度
double scaled = similarity / 8; // 缩放
double[] weights = softmax(scaled);
String result = weightedSum(weights, v);

小结

这三种方法就像做菜的不同配方,核心都是算出 Q 对 K 的“关注度”(权重),然后用权重从 V 里拿东西。第三种(点积)最流行,因为简单高效。

2.2 bmm(Batch Matrix Multiply)批量矩阵乘法

当注意力权重矩阵和V都是三维张量且第一维代表为batch条数时,则做bmm运算。一种特殊的张量乘法运算。

# 如果参数1形状是(b × n × m), 参数2形状是(b × m × p), 则输出为(b × n × p)
>>> input = torch.randn(10, 3, 4)
>>> mat2 = torch.randn(10, 4, 5)
>>> res = torch.bmm(input, mat2)
>>> res.size()
torch.Size([10, 3, 5])

平时我们做矩阵乘法是二维的(行 × 列),但实际数据可能三维,如一堆矩阵叠在一起(batch)。bmm 就是一次性处理这堆矩阵的乘法。

Java 类比: 假设你有 10 个订单(batch=10),每个订单有 3 件商品(n=3),每件商品有 4 个属性(m=4)。你还有一个 10×4×5 的矩阵表示额外信息。bmm 就像:

// 伪代码
double[][][] input = new double[10][3][4]; // 10 个 3×4 矩阵
double[][][] mat2 = new double[10][4][5];  // 10 个 4×5 矩阵
double[][][] result = bmm(input, mat2);    // 输出 10 个 3×5 矩阵

注意力机制是注意力计算规则能应用的深度学习网络的载体,同时包括一些必要的全连接层及相关张量处理, 使其与应用网络融为一体。

使用自注意力计算规则的注意力机制称为自注意力机制。

NLP领域中,当前的注意力机制大多应用于seq2seq架构,即编码器和解码器模型。

2.3 注意力机制的作用

解码器端

能根据模型目标有效的聚焦编码器的输出结果,当其作为解码器的输入时提升效果。改善以往编码器输出是单一定长张量,无法存储过多信息的情况。

编码器把一句话压缩成一个向量,但信息有限。注意力机制就像放大镜,让解码器盯着编码器输出的“重点部分”看,而不是一股脑全用。

Javaer 可理解为从数据库查出一堆数据,注意力帮你挑出最相关的几条。

编码器端

主要解决表征问题,相当于特征提取过程,得到输入的注意力表示。一般使用自注意力(self-attention)。

把输入(比如句子)的每个词都“加工”一遍,找出它们之间的关系(自注意力),提取更牛的特征。

Javaer 可理解为相当于给每个字段加个“标签”,标出它跟其他字段的关系。

2.4 注意力机制实现步骤

  1. 根据注意力计算规则,对Q,K,V进行相应计算
  2. 根据第一步采用的计算方法,如果是拼接方法,则需要将Q与第二步的计算结果再进行拼接,如果是转置点积,一般是自注意力,Q与V相同,则不需要进行与Q的拼接
  3. 最后为了使整个attention机制按照指定尺寸输出,使用线性层作用在第二步的结果上做一个线性变换, 得到最终对Q的注意力表示

通俗解释:

  1. 先算 Q 和 K 的“匹配度”,得出权重
  2. 如果是拼接法,把 Q 和结果再拼一下;如果是点积法(自注意力),直接用
  3. 最后用个“转换器”(线性层)把结果调整成想要的形状

Javaer 理解

double[] weights = computeWeights(q, k); // 第 1 步
String intermediate = (method == "concat") ? q + weightedSum(weights, v) : weightedSum(weights, v); // 第 2 步
String output = linearTransform(intermediate); // 第 3 步

3 常见注意力机制的代码

import torch
import torch.nn as nn
import torch.nn.functional as F

class Attn(nn.Module):
    def __init__(self, query_size, key_size, value_size1, value_size2, output_size):
        """初始化函数中的参数有5个, query_size代表query的最后一维大小
           key_size代表key的最后一维大小, value_size1代表value的导数第二维大小, 
           value = (1, value_size1, value_size2)
           value_size2代表value的倒数第一维大小, output_size输出的最后一维大小"""
        super(Attn, self).__init__()
        # 将以下参数传入类中
        self.query_size = query_size
        self.key_size = key_size
        self.value_size1 = value_size1
        self.value_size2 = value_size2
        self.output_size = output_size

        # 初始化注意力机制实现第一步中需要的线性层.
        self.attn = nn.Linear(self.query_size + self.key_size, value_size1)

        # 初始化注意力机制实现第三步中需要的线性层.
        self.attn_combine = nn.Linear(self.query_size + value_size2, output_size)


    def forward(self, Q, K, V):
        """forward函数的输入参数有三个, 分别是Q, K, V, 根据模型训练常识, 输入给Attion机制的
           张量一般情况都是三维张量, 因此这里也假设Q, K, V都是三维张量"""

        # 第一步, 按照计算规则进行计算, 
        # 我们采用常见的第一种计算规则
        # 将Q,K进行纵轴拼接, 做一次线性变化, 最后使用softmax处理获得结果
        attn_weights = F.softmax(
            self.attn(torch.cat((Q[0], K[0]), 1)), dim=1)

        # 然后进行第一步的后半部分, 将得到的权重矩阵与V做矩阵乘法计算, 
        # 当二者都是三维张量且第一维代表为batch条数时, 则做bmm运算
        attn_applied = torch.bmm(attn_weights.unsqueeze(0), V)

        # 之后进行第二步, 通过取[0]是用来降维, 根据第一步采用的计算方法, 
        # 需要将Q与第一步的计算结果再进行拼接
        output = torch.cat((Q[0], attn_applied[0]), 1)

        # 最后是第三步, 使用线性层作用在第三步的结果上做一个线性变换并扩展维度,得到输出
        # 因为要保证输出也是三维张量, 因此使用unsqueeze(0)扩展维度
        output = self.attn_combine(output).unsqueeze(0)
        return output, attn_weights
      
      
      
# 调用:
query_size = 32
key_size = 32
value_size1 = 32
value_size2 = 64
output_size = 64
attn = Attn(query_size, key_size, value_size1, value_size2, output_size)
Q = torch.randn(1,1,32)
K = torch.randn(1,1,32)
V = torch.randn(1,32,64)
out = attn(Q, K ,V)
print(out[0])
print(out[1])

通俗总结: 这代码就是把 Q、K、V 扔进去,按步骤算权重、提取信息、调整输出。输入输出都是三维数组(batch × 行 × 列),就像处理一堆矩阵。

本文已收录在Github关注我,紧跟本系列专栏文章,咱们下篇再续!

  • 🚀 魔都架构师 | 全网30W+技术追随者
  • 🔧 大厂分布式系统/数据中台实战专家
  • 🏆 主导交易系统亿级流量调优 & 车联网平台架构
  • 🧠 AIGC应用开发先行者 | 区块链落地实践者
  • 🌍 以技术驱动创新,我们的征途是改变世界!
  • 👉 实战干货:编程严选网

本文由博客一文多发平台 OpenWrite 发布!

点击查看更多内容
TA 点赞

若觉得本文不错,就分享一下吧!

评论

作者其他优质文章

正在加载中
JAVA开发工程师
手记
粉丝
1.4万
获赞与收藏
1477

关注作者,订阅最新文章

阅读免费教程

  • 推荐
  • 评论
  • 收藏
  • 共同学习,写下你的评论
感谢您的支持,我会继续努力的~
扫码打赏,你说多少就多少
赞赏金额会直接到老师账户
支付方式
打开微信扫一扫,即可进行扫码打赏哦
今天注册有机会得

100积分直接送

付费专栏免费学

大额优惠券免费领

立即参与 放弃机会
微信客服

购课补贴
联系客服咨询优惠详情

帮助反馈 APP下载

慕课网APP
您的移动学习伙伴

公众号

扫描二维码
关注慕课网微信公众号

举报

0/150
提交
取消