【Torch】nn.Embedding算法详解
1. 定义
nn.Embedding
是 PyTorch 中的 查表式嵌入层(lookup‐table),用于将离散的整数索引(如词 ID、实体 ID、离散特征类别等)映射到一个连续的、可训练的低维向量空间。它通过维护一个形状为 (num_embeddings, embedding_dim)
的权重矩阵,实现高效的“索引 → 向量”转换。
2. 输入与输出
-
输入
- 类型:整型张量(
torch.long
或torch.int64
),必须是 LongTensor,其他类型会报错。 - 形状:任意形状
(*, L)
,其中最内层长度L
常为序列长度,前面的*
可以是 batch 及其他维度。 - 取值范围:
0 ≤ index < num_embeddings
;超出范围会抛出IndexError
。
- 类型:整型张量(
-
输出
- 类型:浮点型张量(与权重相同的
dtype
,默认为torch.float32
)。 - 形状:
(*, L, embedding_dim)
;就是在输入张量后追加一个维度embedding_dim
。 - 语义:若输入某位置的值为
j
,则该位置对应输出就是权重矩阵的第j
行。
- 类型:浮点型张量(与权重相同的
3. 底层原理
-
查表操作 vs. One-hot 乘法
- 直观上,Embedding 相当于:
output=one_hot(input) × W \\text{output} = \\text{one\\_hot}(input) \\;\\times\\; Woutput=one_hot(input)×W
其中W
是(num_embeddings×embedding_dim)
的权重矩阵。 - 为避免显式构造稀疏的 one-hot 张量,PyTorch 直接根据索引做“取行”操作,效率更高、内存更省。
- 直观上,Embedding 相当于:
-
梯度更新
- 稠密模式(默认):整个
W
都有梯度缓冲,优化器根据梯度更新所有行。 - 稀疏模式(
sparse=True
):仅对被索引过的行计算和存储梯度,可配合optim.SparseAdam
高效更新,适合超大字典(百万级以上)但每次只访问少量索引的场景。
- 稠密模式(默认):整个
-
范数裁剪
- 若指定
max_norm
,每次前向都会对输出向量(即对应的行)做范数裁剪,保证其 L-norm_type
范数不超过max_norm
,有助于防止某些频繁访问的词向量过大。
- 若指定
-
权重初始化
- 默认初始化使用均匀分布:
Wi,j∼U(−1num_embeddings, 1num_embeddings) W_{i,j} \\sim \\mathcal{U}\\Bigl(-\\sqrt{\\tfrac{1}{\\text{num\\_embeddings}}},\\;\\sqrt{\\tfrac{1}{\\text{num\\_embeddings}}}\\Bigr)Wi,j∼U(−num_embeddings1,num_embeddings1) - 可以通过
_weight
参数传入外部预训练权重(如 Word2Vec、GloVe 等)。
- 默认初始化使用均匀分布:
4. 构造函数参数详解
num_embeddings
int
embedding_dim
int
padding_idx
int
或 None
None
。指定该索引对应行始终输出全零,并且该行的梯度永远为 0,适合做序列填充。max_norm
float
或 None
None
。若设为数值,每次前向时对取出的向量做范数裁剪(L-norm_type
≤ max_norm
)。norm_type
float
,默认 2max_norm
配合使用时定义范数类型,如 1-范数、2-范数等。scale_grad_by_freq
bool
,默认 FalseTrue
,在反向传播阶段按照索引在 batch 中出现的频次对梯度做缩放(出现越多,梯度越小),有助于高频词的梯度平滑。sparse
bool
,默认 FalseTrue
,开启稀疏更新,仅对被访问行生成梯度;必须配合 optim.SparseAdam
使用,不支持常规稠密优化器。_weight
Tensor
或 None
(num_embeddings, embedding_dim)
)作为权重初始化,否则随机初始化。5. 使用示例
import torchimport torch.nn as nn# 1. 参数设定vocab_size = 10000 # 词表大小embed_dim = 300 # 嵌入维度# 2. 创建 Embedding 层embedding = nn.Embedding( num_embeddings=vocab_size, embedding_dim=embed_dim, padding_idx=0, # 将 0 作为填充索引,输出全 0 max_norm=5.0, # 向量范数不超过 5 norm_type=2.0, scale_grad_by_freq=True, sparse=False)# 3. 构造输入# batch_size=2, seq_len=6input_ids = torch.tensor([ [ 1, 234, 56, 789, 0, 23], [123, 4, 567, 8, 9, 0],], dtype=torch.long)# 4. 前向计算# 输出 shape = [2, 6, 300]output = embedding(input_ids)print(output.shape) # -> torch.Size([2, 6, 300])
加载并冻结预训练权重
import numpy as np# 假设有预训练权重 pre_trained.npy,shape=(10000,300)weights = torch.from_numpy(np.load(\"pre_trained.npy\"))embed_pre = nn.Embedding( num_embeddings=vocab_size, embedding_dim=embed_dim, _weight=weights)# 冻结所有权重embed_pre.weight.requires_grad = False
6. 注意事项
- 类型与范围
- 输入必须为 LongTensor,且所有索引满足
0 ≤ index < num_embeddings
。
- 输入必须为 LongTensor,且所有索引满足
- Padding 与 Mask
- 仅指定
padding_idx
会返回零向量,但上游网络(如 RNN、Transformer)还需显式 mask,避免无效位置影响注意力或累积状态。
- 仅指定
- 性能考量
max_norm
每次前向都做范数计算和裁剪,若不需要可关闭以提升速度。
- 稀疏更新限制
sparse=True
可节省内存,但只支持SparseAdam
,且在 GPU 上效率有时不如稠密模式。
- EmbeddingBag
- 对于可变长度序列的 sum/mean/power-mean 汇聚,可使用
nn.EmbeddingBag
,避免中间张量开销。
- 对于可变长度序列的 sum/mean/power-mean 汇聚,可使用
- 分布式与大词表
- 在分布式训练时,可将嵌入表切分到多个进程上(
torch.nn.parallel.DistributedDataParallel
+torch.nn.Embedding
支持参数分布式)。 - 超大词表(千万级)时,可考虑动态加载、分布式哈希表或专用库(如 DeepSpeed 的嵌入稀疏优化)。
- 在分布式训练时,可将嵌入表切分到多个进程上(