> 技术文档 > 算法面试准备 - 手撕系列第六期 - 多头注意力机制(包括Self_atten和Cross_atten)_手撕多头注意力

算法面试准备 - 手撕系列第六期 - 多头注意力机制(包括Self_atten和Cross_atten)_手撕多头注意力


算法面试准备 - 手撕系列第六期 - 多头注意力机制(包括Self_atten和Cross_atten)

目录

  • 算法面试准备 - 手撕系列第六期 - 多头注意力机制(包括Self_atten和Cross_atten)
  • 多头注意力机制原理
    • 多头注意力机制原理图像
    • 背景介绍
    • 原理解析
      • 1. 输入与嵌入
      • 2. 多头注意力的计算流程
        • (1) 线性变换
        • (2) 注意力计算
        • (3) 拼接与线性变换
    • 总结公式
    • 优缺点分析
      • 优点
      • 缺点
  • 多头注意力机制代码
    • 第一步,引入相关的库函数
    • 第二步,初始化Multi_atten作为一个类
    • 第三步 测试代码 -分为自注意力和交叉注意力两种注意力
  • 参考

多头注意力机制原理

多头注意力机制原理图像

在这里插入图片描述

多头注意力机制原理图

背景介绍

多头注意力机制(Multi-Head Attention)是 Transformer 架构的核心模块之一,用于捕获输入序列中不同位置的复杂依赖关系。通过多个注意力头,它能够从不同的表示子空间中提取信息,从而提高模型的表达能力。


原理解析

多头注意力机制的核心思想是并行计算多个注意力机制(头),然后将它们的输出连接起来,进一步线性变换得到最终结果。

1. 输入与嵌入

如果为自注意力机制则输入为一个qkv统一源的矩阵 (   X q k v ∈ R l e n ( q k v ) × d \\ X_{qkv} \\in \\mathbb{R}^{len(qkv) \\times d}  XqkvRlen(qkv)×d),交叉注意力机制需要输入两个,kv的源矩阵和q的源矩阵(   X q ∈ R l e n q × d \\ X_q \\in \\mathbb{R}^{lenq \\times d}  XqRlenq×d,   X k v ∈ R l e n ( k v ) × d \\ X_{kv} \\in \\mathbb{R}^{len(kv) \\times d}  XkvRlen(kv)×d),其中:

  • 其中len (n) 是输入对应序列的长度(例如,句子中的词数量)。
  • 其中(d) 是输入向量的维度。

输入被映射到查询(Query)、键(Key)和值(Value)矩阵。


2. 多头注意力的计算流程

(1) 线性变换

对于每个注意力头,使用独立的线性变换得到查询、键和值,为生成查询(Query)、键(Key)和值(Value),通过可学习的权重矩阵进行线性变换,如果为自注意力机制则KQV的计算公式为:
Q h = X q k v W Q h , K h = X q k v W K h , V h = X q k v W V h Q_h = X_{qkv}W_Q^h, \\quad K_h = X_{qkv}W_K^h, \\quad V_h = X_{qkv}W_V^h Qh=XqkvWQh,Kh=XqkvWKh,Vh=XqkvWVh
如果是交叉注意力机制则计算公式为:
Q h = X qW Q h , K h = X k v W K h , V h = X k v W V h Q_h = X_{q}W_Q^h, \\quad K_h = X_{kv}W_K^h, \\quad V_h = X_{kv}W_V^h Qh=XqWQh,Kh=XkvWKh,Vh=XkvWVh
其中:

  • W Q h , W K h , W V h ∈ R d × d k W_Q^h, W_K^h, W_V^h \\in \\mathbb{R}^{d \\times d_k} WQh,WKh,WVhRd×dk 是第 h h h 个头的可学习权重矩阵。
  • d k = d h d_k = \\frac{d}{h} dk=hd 是每个头的向量维度, h h h 是头的数量。
(2) 注意力计算

每个头独立计算自注意力(Scaled Dot-Product Attention):
Attention h ( Q h , K h , V h ) = softmax ( Q h K h ⊤ d k )V h \\text{Attention}_h(Q_h, K_h, V_h) = \\text{softmax}\\left(\\frac{Q_h K_h^\\top}{\\sqrt{d_k}}\\right) V_h Attentionh(Qh,Kh,Vh)=softmax(dk QhKh)Vh

(3) 拼接与线性变换

将所有头的输出拼接在一起,并通过一个线性层进行变换:
MultiHead ( Q , K , V ) = Concat ( head 1 , head 2 , … , head h ) W O \\text{MultiHead}(Q, K, V) = \\text{Concat}(\\text{head}_1, \\text{head}_2, \\dots, \\text{head}_h) W_O MultiHead(Q,K,V)=Concat(head1,head2,,headh)WO
其中:

  • head i = Attention i ( Q i , K i , V i ) \\text{head}_i = \\text{Attention}_i(Q_i, K_i, V_i) headi=Attentioni(Qi,Ki,Vi) 是第 i i i 个头的输出。
  • W O ∈ R d × d W_O \\in \\mathbb{R}^{d \\times d} WORd×d 是输出层的权重矩阵。

总结公式

多头注意力机制的完整公式为:
MultiHead ( Q , K , V ) = Concat ( head 1 , head 2 , … , head h ) W O \\text{MultiHead}(Q, K, V) = \\text{Concat}(\\text{head}_1, \\text{head}_2, \\dots, \\text{head}_h) W_O MultiHead(Q,K,V)=Concat(head1,head2,,headh)WO
其中:
head i = softmax ( Q i K i ⊤ d k )V i \\text{head}_i = \\text{softmax}\\left(\\frac{Q_i K_i^\\top}{\\sqrt{d_k}}\\right) V_i headi=softmax(dk QiKi)Vi


优缺点分析

优点

  • 多样性:多个注意力头可以从不同的子空间学习特征。
  • 捕获长距离依赖:能够建模序列中任意位置之间的关系。
  • 提升表示能力:比单头注意力机制具有更高的表达能力。

缺点

  • 计算开销高:计算多个头的注意力会增加计算量和显存开销。
  • 实现复杂性:需要对多个头进行并行计算和拼接。

多头注意力机制代码

以下是基于 PyTorch 实现多头注意力机制的代码:

第一步,引入相关的库函数

# 该模块实现的是多头注意力机制,和单头不一样的点# 1. 需要把头提取出来,2. 需要对mask进行expand\'\'\'# Part1 引入相关的库函数\'\'\'import torchfrom torch import nnimport math

第二步,初始化Multi_atten作为一个类

\'\'\'# Part 2 设计一个多头注意力的类\'\'\'class Multi_atten(nn.Module): def __init__(self,emd_size,q_k_size,v_size,head): super(Multi_atten,self).__init__() # 输入的x为(batch_size,seq_len,emd_size) # 第一步初始化三个全连接矩阵和头的数量 self.head=head # 初始化是head的倍数,便于提取 self.Wk=nn.Linear(emd_size,q_k_size*head) self.Wq=nn.Linear(emd_size,q_k_size*head) self.Wv=nn.Linear(emd_size,v_size*head) # 初始化Softmax函数 self.softmax=nn.Softmax(dim=-1) # 剩下的等会看看 def forward(self,x_q,x_k_v,mask): # 首先得到kvq q=self.Wq(x_q) # (batch_size,q_seq_len,q_size*head) k=self.Wk(x_k_v) v=self.Wv(x_k_v) # 其次是把头分出来得到多头的kvq q=q.reshape(q.size()[0],q.size()[1],self.head,-1).transpose(1,2) # (batch_size,head,q_seq_len,q_size) k = k.reshape(k.size()[0], k.size()[1], self.head, -1).transpose(1,2) v = v.reshape(q.size()[0], v.size()[1], self.head, -1).transpose(1,2) # 把k进行转置 k=k.transpose(2,3) # (batch_size,head,k_seq_len,q_size)# 进行mask(batch,seq_q,seq_k) if mask is not None: mask.unsqueeze(1).expand(-1,self.head,-1,-1) q_k.masked_fill(mask,1e-9)# mask 后进行softmax得到注意力值 q_k=self.softmax(torch.matmul(q,k)/math.sqrt(k.size()[2])) # 和v相乘 atten=torch.matmul(q_k,v) # (batch_size,head,k_seq_len,k_v_size) # 将其进行返回原来的尺寸 atten.transpose(1,2) # (batch_size,k_seq_len,head,k_v_size) atten=atten.reshape(atten.size()[0],atten.size()[1],-1) # (batch_size, k_seq_len, head*k_v_size) return atten

第三步 测试代码 -分为自注意力和交叉注意力两种注意力

if __name__ == \'__main__\': # 类别1 单头的自注意力机制 # 初始化输入x(batch_size,seq_len,emding) batch_size = 1 # 批量也就是句子的数量 emd_size = 128 # 一个token嵌入的维度 seq_len = 5 # kqv源的token长度 q_k_size = emd_size//8 # q和k的嵌入维度 v_size = emd_size//8 # v的嵌入维度 x = torch.rand(size=(batch_size, seq_len, emd_size), dtype=torch.float) self_atten = Multi_atten(emd_size=emd_size, q_k_size=q_k_size, v_size=v_size,head=8) # 初始化mask(batch,len_k,len_q) mask = torch.randn(size=(batch_size, seq_len, seq_len)) mask = mask.bool() print(\'单头的自注意力结果\', self_atten(x, x, mask).size()) # 类别2 单头的交叉注意力机制 # 初始化输入x(batch_size,seq_len,emding) batch_size = 1 # 批量也就是句子的数量 emd_size = 128 # 一个token嵌入的维度 q_seq_len = 5 # q源的token长度 q_k_size = emd_size//8 # q和k的嵌入维度/head k_v_seq_len = 7 # k_v源的token长度 v_size = emd_size//8 # v的嵌入维度/head head=8 # 头的数量 x_q = torch.rand(size=(batch_size, q_seq_len, emd_size), dtype=torch.float) x_k_v = torch.rand(size=(batch_size, k_v_seq_len, emd_size), dtype=torch.float) cross_atten = Multi_atten(emd_size=emd_size, q_k_size=q_k_size, v_size=v_size,head=head) # 初始化mask(batch,len_k,len_q) mask = torch.randn(size=(batch_size, q_seq_len, k_v_seq_len)) mask = mask.bool() print(\'单头的交叉注意力结果\', cross_atten(x_q, x_k_v, mask).size())

参考

自己(值得纪念+1,终于自己会从头开始写多头注意力机制喽,哈(*≧▽≦)):小菜鸟博士-CSDN博客,手撕Transformer – Day3 – MultiHead Attention-CSDN博客

github资料:YanxinTong/Algorithm-Interview-Prep: AI 算法面试准备 - 手撕一些简单模型