多头注意力
class MultiHeadAttention(nn.Module): def __init__(self, embed_size, heads): super(MultiHeadAttention, self).__init__() self.embed_size = embed_size self.heads = heads self.head_dim = embed_size // heads assert self.head_dim * heads == embed_size, \"Embedding size must be divisible by number of heads\" self.values = nn.Linear(embed_size, embed_size) self.keys = nn.Linear(embed_size, embed_size) self.queries = nn.Linear(embed_size, embed_size) self.fc_out = nn.Linear(embed_size, embed_size) def forward(self, values, keys, query, mask): N = query.shape[0] value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1] # Split the embedding into self.heads different pieces values = values.reshape(N, value_len, self.heads, self.head_dim) keys = keys.reshape(N, key_len, self.heads, self.head_dim) query = query.reshape(N, query_len, self.heads, self.head_dim) values = values.permute(2, 0, 1, 3) # (heads, batch_size, value_len, head_dim) keys = keys.permute(2, 0, 1, 3) # (heads, batch_size, key_len, head_dim) query = query.permute(2, 0, 1, 3) # (heads, batch_size, query_len, head_dim) energy = torch.matmul(query, keys.permute(0, 1, 3, 2)) # (heads, batch_size, query_len, key_len) if mask is not None: energy = energy.masked_fill(mask == 0, float(\"-1e20\")) attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=-1) # Scaled dot-product attention out = torch.matmul(attention, values) # (heads, batch_size, query_len, head_dim) out = out.permute(1, 2, 0, 3).reshape(N, query_len, self.heads * self.head_dim) # (batch_size, query_len, embed_size) out = self.fc_out(out) # Final linear layer return out
-
mask
: -
mask == 0
:-
这部分代码的目的是生成一个与
mask
同形状的布尔矩阵。如果mask
中的某个位置是 0,那么mask == 0
的对应位置为True
,否则为False
。 -
例如,如果
mask
是一个[1, 0, 1]
的张量,那么mask == 0
会变成[False, True, False]
。
-
-
energy.masked_fill(mask == 0, float(\"-1e20\"))
:-
energy
是计算出来的注意力分数矩阵,通常形状是(heads, batch_size, query_len, key_len)
,表示每个查询与每个键之间的相似度。 -
masked_fill
是一个 PyTorch 操作,它将energy
张量中符合mask == 0
条件的元素替换成指定的值。在这个例子中,被替换的值是float(\"-1e20\")
。 -
由于
-1e20
是一个非常小的负数,实际上相当于负无穷大。这意味着它会极大地降低被掩盖位置的注意力分数。
-
-
掩蔽未来信息(在解码器中):在 解码器 中,特别是在生成序列时,我们希望每个时间步只依赖于前面的词,而不能“看到”后面的词(避免“未来泄露”)。这种情况可以通过掩码实现。掩码会将未来的词(即当前词后面的词)对应的注意力分数设置为负无穷,这样在 softmax 操作时,这些位置的权重几乎为 0,从而确保模型只依赖于当前及之前的词。
-
处理填充符(padding):在处理变长序列时,常常会使用填充符(如
)来使所有输入序列的长度一致。为了防止这些填充符干扰模型的学习,我们可以使用掩码将填充位置的注意力分数设置为负无穷,这样它们在注意力计算中不会被考虑。