算法面试准备 - 手撕系列第六期 - 多头注意力机制(包括Self_atten和Cross_atten)_手撕多头注意力
算法面试准备 - 手撕系列第六期 - 多头注意力机制(包括Self_atten和Cross_atten)
目录
- 算法面试准备 - 手撕系列第六期 - 多头注意力机制(包括Self_atten和Cross_atten)
- 多头注意力机制原理
-
- 多头注意力机制原理图像
- 背景介绍
- 原理解析
-
- 1. 输入与嵌入
- 2. 多头注意力的计算流程
-
- (1) 线性变换
- (2) 注意力计算
- (3) 拼接与线性变换
- 总结公式
- 优缺点分析
-
- 优点
- 缺点
- 多头注意力机制代码
-
- 第一步,引入相关的库函数
- 第二步,初始化Multi_atten作为一个类
- 第三步 测试代码 -分为自注意力和交叉注意力两种注意力
- 参考
多头注意力机制原理
多头注意力机制原理图像
多头注意力机制原理图
背景介绍
多头注意力机制(Multi-Head Attention)是 Transformer 架构的核心模块之一,用于捕获输入序列中不同位置的复杂依赖关系。通过多个注意力头,它能够从不同的表示子空间中提取信息,从而提高模型的表达能力。
原理解析
多头注意力机制的核心思想是并行计算多个注意力机制(头),然后将它们的输出连接起来,进一步线性变换得到最终结果。
1. 输入与嵌入
如果为自注意力机制则输入为一个qkv统一源的矩阵 ( X q k v ∈ R l e n ( q k v ) × d \\ X_{qkv} \\in \\mathbb{R}^{len(qkv) \\times d} Xqkv∈Rlen(qkv)×d),交叉注意力机制需要输入两个,kv的源矩阵和q的源矩阵( X q ∈ R l e n q × d \\ X_q \\in \\mathbb{R}^{lenq \\times d} Xq∈Rlenq×d, X k v ∈ R l e n ( k v ) × d \\ X_{kv} \\in \\mathbb{R}^{len(kv) \\times d} Xkv∈Rlen(kv)×d),其中:
- 其中len (n) 是输入对应序列的长度(例如,句子中的词数量)。
- 其中(d) 是输入向量的维度。
输入被映射到查询(Query)、键(Key)和值(Value)矩阵。
2. 多头注意力的计算流程
(1) 线性变换
对于每个注意力头,使用独立的线性变换得到查询、键和值,为生成查询(Query)、键(Key)和值(Value),通过可学习的权重矩阵进行线性变换,如果为自注意力机制则KQV的计算公式为:
Q h = X q k v W Q h , K h = X q k v W K h , V h = X q k v W V h Q_h = X_{qkv}W_Q^h, \\quad K_h = X_{qkv}W_K^h, \\quad V_h = X_{qkv}W_V^h Qh=XqkvWQh,Kh=XqkvWKh,Vh=XqkvWVh
如果是交叉注意力机制则计算公式为:
Q h = X qW Q h , K h = X k v W K h , V h = X k v W V h Q_h = X_{q}W_Q^h, \\quad K_h = X_{kv}W_K^h, \\quad V_h = X_{kv}W_V^h Qh=XqWQh,Kh=XkvWKh,Vh=XkvWVh
其中:
- W Q h , W K h , W V h ∈ R d × d k W_Q^h, W_K^h, W_V^h \\in \\mathbb{R}^{d \\times d_k} WQh,WKh,WVh∈Rd×dk 是第 h h h 个头的可学习权重矩阵。
- d k = d h d_k = \\frac{d}{h} dk=hd 是每个头的向量维度, h h h 是头的数量。
(2) 注意力计算
每个头独立计算自注意力(Scaled Dot-Product Attention):
Attention h ( Q h , K h , V h ) = softmax ( Q h K h ⊤ d k )V h \\text{Attention}_h(Q_h, K_h, V_h) = \\text{softmax}\\left(\\frac{Q_h K_h^\\top}{\\sqrt{d_k}}\\right) V_h Attentionh(Qh,Kh,Vh)=softmax(dkQhKh⊤)Vh
(3) 拼接与线性变换
将所有头的输出拼接在一起,并通过一个线性层进行变换:
MultiHead ( Q , K , V ) = Concat ( head 1 , head 2 , … , head h ) W O \\text{MultiHead}(Q, K, V) = \\text{Concat}(\\text{head}_1, \\text{head}_2, \\dots, \\text{head}_h) W_O MultiHead(Q,K,V)=Concat(head1,head2,…,headh)WO
其中:
- head i = Attention i ( Q i , K i , V i ) \\text{head}_i = \\text{Attention}_i(Q_i, K_i, V_i) headi=Attentioni(Qi,Ki,Vi) 是第 i i i 个头的输出。
- W O ∈ R d × d W_O \\in \\mathbb{R}^{d \\times d} WO∈Rd×d 是输出层的权重矩阵。
总结公式
多头注意力机制的完整公式为:
MultiHead ( Q , K , V ) = Concat ( head 1 , head 2 , … , head h ) W O \\text{MultiHead}(Q, K, V) = \\text{Concat}(\\text{head}_1, \\text{head}_2, \\dots, \\text{head}_h) W_O MultiHead(Q,K,V)=Concat(head1,head2,…,headh)WO
其中:
head i = softmax ( Q i K i ⊤ d k )V i \\text{head}_i = \\text{softmax}\\left(\\frac{Q_i K_i^\\top}{\\sqrt{d_k}}\\right) V_i headi=softmax(dkQiKi⊤)Vi
优缺点分析
优点
- 多样性:多个注意力头可以从不同的子空间学习特征。
- 捕获长距离依赖:能够建模序列中任意位置之间的关系。
- 提升表示能力:比单头注意力机制具有更高的表达能力。
缺点
- 计算开销高:计算多个头的注意力会增加计算量和显存开销。
- 实现复杂性:需要对多个头进行并行计算和拼接。
多头注意力机制代码
以下是基于 PyTorch 实现多头注意力机制的代码:
第一步,引入相关的库函数
# 该模块实现的是多头注意力机制,和单头不一样的点# 1. 需要把头提取出来,2. 需要对mask进行expand\'\'\'# Part1 引入相关的库函数\'\'\'import torchfrom torch import nnimport math
第二步,初始化Multi_atten作为一个类
\'\'\'# Part 2 设计一个多头注意力的类\'\'\'class Multi_atten(nn.Module): def __init__(self,emd_size,q_k_size,v_size,head): super(Multi_atten,self).__init__() # 输入的x为(batch_size,seq_len,emd_size) # 第一步初始化三个全连接矩阵和头的数量 self.head=head # 初始化是head的倍数,便于提取 self.Wk=nn.Linear(emd_size,q_k_size*head) self.Wq=nn.Linear(emd_size,q_k_size*head) self.Wv=nn.Linear(emd_size,v_size*head) # 初始化Softmax函数 self.softmax=nn.Softmax(dim=-1) # 剩下的等会看看 def forward(self,x_q,x_k_v,mask): # 首先得到kvq q=self.Wq(x_q) # (batch_size,q_seq_len,q_size*head) k=self.Wk(x_k_v) v=self.Wv(x_k_v) # 其次是把头分出来得到多头的kvq q=q.reshape(q.size()[0],q.size()[1],self.head,-1).transpose(1,2) # (batch_size,head,q_seq_len,q_size) k = k.reshape(k.size()[0], k.size()[1], self.head, -1).transpose(1,2) v = v.reshape(q.size()[0], v.size()[1], self.head, -1).transpose(1,2) # 把k进行转置 k=k.transpose(2,3) # (batch_size,head,k_seq_len,q_size)# 进行mask(batch,seq_q,seq_k) if mask is not None: mask.unsqueeze(1).expand(-1,self.head,-1,-1) q_k.masked_fill(mask,1e-9)# mask 后进行softmax得到注意力值 q_k=self.softmax(torch.matmul(q,k)/math.sqrt(k.size()[2])) # 和v相乘 atten=torch.matmul(q_k,v) # (batch_size,head,k_seq_len,k_v_size) # 将其进行返回原来的尺寸 atten.transpose(1,2) # (batch_size,k_seq_len,head,k_v_size) atten=atten.reshape(atten.size()[0],atten.size()[1],-1) # (batch_size, k_seq_len, head*k_v_size) return atten
第三步 测试代码 -分为自注意力和交叉注意力两种注意力
if __name__ == \'__main__\': # 类别1 单头的自注意力机制 # 初始化输入x(batch_size,seq_len,emding) batch_size = 1 # 批量也就是句子的数量 emd_size = 128 # 一个token嵌入的维度 seq_len = 5 # kqv源的token长度 q_k_size = emd_size//8 # q和k的嵌入维度 v_size = emd_size//8 # v的嵌入维度 x = torch.rand(size=(batch_size, seq_len, emd_size), dtype=torch.float) self_atten = Multi_atten(emd_size=emd_size, q_k_size=q_k_size, v_size=v_size,head=8) # 初始化mask(batch,len_k,len_q) mask = torch.randn(size=(batch_size, seq_len, seq_len)) mask = mask.bool() print(\'单头的自注意力结果\', self_atten(x, x, mask).size()) # 类别2 单头的交叉注意力机制 # 初始化输入x(batch_size,seq_len,emding) batch_size = 1 # 批量也就是句子的数量 emd_size = 128 # 一个token嵌入的维度 q_seq_len = 5 # q源的token长度 q_k_size = emd_size//8 # q和k的嵌入维度/head k_v_seq_len = 7 # k_v源的token长度 v_size = emd_size//8 # v的嵌入维度/head head=8 # 头的数量 x_q = torch.rand(size=(batch_size, q_seq_len, emd_size), dtype=torch.float) x_k_v = torch.rand(size=(batch_size, k_v_seq_len, emd_size), dtype=torch.float) cross_atten = Multi_atten(emd_size=emd_size, q_k_size=q_k_size, v_size=v_size,head=head) # 初始化mask(batch,len_k,len_q) mask = torch.randn(size=(batch_size, q_seq_len, k_v_seq_len)) mask = mask.bool() print(\'单头的交叉注意力结果\', cross_atten(x_q, x_k_v, mask).size())
参考
自己(值得纪念+1,终于自己会从头开始写多头注意力机制喽,哈(*≧▽≦)):小菜鸟博士-CSDN博客,手撕Transformer – Day3 – MultiHead Attention-CSDN博客
github资料:YanxinTong/Algorithm-Interview-Prep: AI 算法面试准备 - 手撕一些简单模型