> 技术文档 > 多头注意力

多头注意力

class MultiHeadAttention(nn.Module): def __init__(self, embed_size, heads): super(MultiHeadAttention, self).__init__() self.embed_size = embed_size self.heads = heads self.head_dim = embed_size // heads assert self.head_dim * heads == embed_size, \"Embedding size must be divisible by number of heads\" self.values = nn.Linear(embed_size, embed_size) self.keys = nn.Linear(embed_size, embed_size) self.queries = nn.Linear(embed_size, embed_size) self.fc_out = nn.Linear(embed_size, embed_size) def forward(self, values, keys, query, mask): N = query.shape[0] value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1] # Split the embedding into self.heads different pieces values = values.reshape(N, value_len, self.heads, self.head_dim) keys = keys.reshape(N, key_len, self.heads, self.head_dim) query = query.reshape(N, query_len, self.heads, self.head_dim) values = values.permute(2, 0, 1, 3) # (heads, batch_size, value_len, head_dim) keys = keys.permute(2, 0, 1, 3) # (heads, batch_size, key_len, head_dim) query = query.permute(2, 0, 1, 3) # (heads, batch_size, query_len, head_dim) energy = torch.matmul(query, keys.permute(0, 1, 3, 2)) # (heads, batch_size, query_len, key_len) if mask is not None: energy = energy.masked_fill(mask == 0, float(\"-1e20\")) attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=-1) # Scaled dot-product attention out = torch.matmul(attention, values) # (heads, batch_size, query_len, head_dim) out = out.permute(1, 2, 0, 3).reshape(N, query_len, self.heads * self.head_dim) # (batch_size, query_len, embed_size) out = self.fc_out(out) # Final linear layer return out
  • mask

    • mask 是一个掩码矩阵,它的形状通常是 (batch_size, seq_len) ,其中每个元素的值为 0 或 1。

    • 掩码矩阵的作用是指示哪些位置是需要被遮盖的。在自注意力机制中,掩码通常用于遮挡某些位置的注意力权重,以便模型不能“看到”它们。

  • mask == 0

    • 这部分代码的目的是生成一个与 mask 同形状的布尔矩阵。如果 mask 中的某个位置是 0,那么 mask == 0 的对应位置为 True,否则为 False

    • 例如,如果 mask 是一个 [1, 0, 1] 的张量,那么 mask == 0 会变成 [False, True, False]

  • energy.masked_fill(mask == 0, float(\"-1e20\"))

    • energy 是计算出来的注意力分数矩阵,通常形状是 (heads, batch_size, query_len, key_len),表示每个查询与每个键之间的相似度。

    • masked_fill 是一个 PyTorch 操作,它将 energy 张量中符合 mask == 0 条件的元素替换成指定的值。在这个例子中,被替换的值是 float(\"-1e20\")

    • 由于 -1e20 是一个非常小的负数,实际上相当于负无穷大。这意味着它会极大地降低被掩盖位置的注意力分数。

  • 掩蔽未来信息(在解码器中):在 解码器 中,特别是在生成序列时,我们希望每个时间步只依赖于前面的词,而不能“看到”后面的词(避免“未来泄露”)。这种情况可以通过掩码实现。掩码会将未来的词(即当前词后面的词)对应的注意力分数设置为负无穷,这样在 softmax 操作时,这些位置的权重几乎为 0,从而确保模型只依赖于当前及之前的词。

  • 处理填充符(padding):在处理变长序列时,常常会使用填充符(如 )来使所有输入序列的长度一致。为了防止这些填充符干扰模型的学习,我们可以使用掩码将填充位置的注意力分数设置为负无穷,这样它们在注意力计算中不会被考虑。