> 技术文档 > 【Torch】nn.Embedding算法详解

【Torch】nn.Embedding算法详解


1. 定义

nn.Embedding 是 PyTorch 中的 查表式嵌入层(lookup‐table),用于将离散的整数索引(如词 ID、实体 ID、离散特征类别等)映射到一个连续的、可训练的低维向量空间。它通过维护一个形状为 (num_embeddings, embedding_dim)权重矩阵,实现高效的“索引 → 向量”转换。

2. 输入与输出

  • 输入

    • 类型:整型张量(torch.longtorch.int64),必须是 LongTensor,其他类型会报错。
    • 形状:任意形状 (*, L),其中最内层长度 L 常为序列长度,前面的 * 可以是 batch 及其他维度。
    • 取值范围0 ≤ index < num_embeddings;超出范围会抛出 IndexError
  • 输出

    • 类型:浮点型张量(与权重相同的 dtype,默认为 torch.float32)。
    • 形状(*, L, embedding_dim);就是在输入张量后追加一个维度 embedding_dim
    • 语义:若输入某位置的值为 j,则该位置对应输出就是权重矩阵的第 j 行。

3. 底层原理

  1. 查表操作 vs. One-hot 乘法

    • 直观上,Embedding 相当于:
      output=one_hot(input)  ×  W \\text{output} = \\text{one\\_hot}(input) \\;\\times\\; Woutput=one_hot(input)×W
      其中 W(num_embeddings×embedding_dim) 的权重矩阵。
    • 为避免显式构造稀疏的 one-hot 张量,PyTorch 直接根据索引做“取行”操作,效率更高、内存更省。
  2. 梯度更新

    • 稠密模式(默认):整个 W 都有梯度缓冲,优化器根据梯度更新所有行。
    • 稀疏模式sparse=True):仅对被索引过的行计算和存储梯度,可配合 optim.SparseAdam 高效更新,适合超大字典(百万级以上)但每次只访问少量索引的场景。
  3. 范数裁剪

    • 若指定 max_norm,每次前向都会对输出向量(即对应的行)做范数裁剪,保证其 L-norm_type 范数不超过 max_norm,有助于防止某些频繁访问的词向量过大。
  4. 权重初始化

    • 默认初始化使用均匀分布:
      Wi,j∼U(−1num_embeddings,  1num_embeddings) W_{i,j} \\sim \\mathcal{U}\\Bigl(-\\sqrt{\\tfrac{1}{\\text{num\\_embeddings}}},\\;\\sqrt{\\tfrac{1}{\\text{num\\_embeddings}}}\\Bigr)Wi,jU(num_embeddings1,num_embeddings1)
    • 可以通过 _weight 参数传入外部预训练权重(如 Word2Vec、GloVe 等)。

4. 构造函数参数详解

参数 类型及默认 说明 num_embeddings int 必填。嵌入表行数,等于类别总数(最大索引 + 1)。 embedding_dim int 必填。每个向量的维度。 padding_idx intNone 默认 None。指定该索引对应行始终输出全零,并且该行的梯度永远为 0,适合做序列填充。 max_norm floatNone 默认 None。若设为数值,每次前向时对取出的向量做范数裁剪(L-norm_typemax_norm)。 norm_type float,默认 2 与 max_norm 配合使用时定义范数类型,如 1-范数、2-范数等。 scale_grad_by_freq bool,默认 False 若为 True,在反向传播阶段按照索引在 batch 中出现的频次对梯度做缩放(出现越多,梯度越小),有助于高频词的梯度平滑。 sparse bool,默认 False 若为 True,开启稀疏更新,仅对被访问行生成梯度;必须配合 optim.SparseAdam 使用,不支持常规稠密优化器。 _weight TensorNone 若提供,则用此张量(形状应为 (num_embeddings, embedding_dim))作为权重初始化,否则随机初始化。

5. 使用示例

import torchimport torch.nn as nn# 1. 参数设定vocab_size = 10000 # 词表大小embed_dim = 300 # 嵌入维度# 2. 创建 Embedding 层embedding = nn.Embedding( num_embeddings=vocab_size, embedding_dim=embed_dim, padding_idx=0, # 将 0 作为填充索引,输出全 0 max_norm=5.0, # 向量范数不超过 5 norm_type=2.0, scale_grad_by_freq=True, sparse=False)# 3. 构造输入# batch_size=2, seq_len=6input_ids = torch.tensor([ [ 1, 234, 56, 789, 0, 23], [123, 4, 567, 8, 9, 0],], dtype=torch.long)# 4. 前向计算# 输出 shape = [2, 6, 300]output = embedding(input_ids)print(output.shape) # -> torch.Size([2, 6, 300])

加载并冻结预训练权重

import numpy as np# 假设有预训练权重 pre_trained.npy,shape=(10000,300)weights = torch.from_numpy(np.load(\"pre_trained.npy\"))embed_pre = nn.Embedding( num_embeddings=vocab_size, embedding_dim=embed_dim, _weight=weights)# 冻结所有权重embed_pre.weight.requires_grad = False

6. 注意事项

  1. 类型与范围
    • 输入必须为 LongTensor,且所有索引满足 0 ≤ index < num_embeddings
  2. Padding 与 Mask
    • 仅指定 padding_idx 会返回零向量,但上游网络(如 RNN、Transformer)还需显式 mask,避免无效位置影响注意力或累积状态。
  3. 性能考量
    • max_norm 每次前向都做范数计算和裁剪,若不需要可关闭以提升速度。
  4. 稀疏更新限制
    • sparse=True 可节省内存,但只支持 SparseAdam,且在 GPU 上效率有时不如稠密模式。
  5. EmbeddingBag
    • 对于可变长度序列的 sum/mean/power-mean 汇聚,可使用 nn.EmbeddingBag,避免中间张量开销。
  6. 分布式与大词表
    • 在分布式训练时,可将嵌入表切分到多个进程上(torch.nn.parallel.DistributedDataParallel + torch.nn.Embedding 支持参数分布式)。
    • 超大词表(千万级)时,可考虑动态加载、分布式哈希表或专用库(如 DeepSpeed 的嵌入稀疏优化)。