揭秘何凯明Masked Autoencoders(MAE):计算机视觉中的可扩展自监督学习新星_深度学习 何凯明 mask任务
揭秘 Masked Autoencoders:计算机视觉中的可扩展自监督学习新星
在自然语言处理(NLP)领域,自监督学习早已通过 BERT 和 GPT 等模型证明了其强大的可扩展性和泛化能力。这些模型的核心思想简单却优雅:通过掩盖输入的一部分(例如单词),让模型学会预测被移除的内容,从而捕获语义和上下文信息。那么,这种“掩码预测”的思想能否在计算机视觉中同样大放异彩呢?答案是肯定的!由 Facebook AI Research (FAIR) 团队提出的 Masked Autoencoders (MAE),为我们在视觉领域提供了一种简单、高效且可扩展的自监督学习方法。本文将深入剖析 MAE 的设计理念、技术细节及其独特之处,带你理解它为何能在 Vision Transformer (ViT) 的基础上掀起新的研究浪潮。
下文中图片来自于原论文:https://arxiv.org/pdf/2111.06377
背景:从 NLP 到视觉的自监督挑战
在 NLP 中,BERT 的成功源于掩码语言建模(Masked Language Modeling, MLM):随机掩盖句子中的部分单词(通常是 15%),然后训练模型预测这些缺失的单词。这种方法之所以有效,是因为语言是高度语义化的信号,单词之间存在丰富的上下文依赖。然而,图像与语言的本质差异使得直接迁移这种思想到计算机视觉并非易事:
- 架构差异:传统视觉任务依赖卷积神经网络(CNN),其基于规则网格的操作难以自然融入类似 BERT 的掩码标记(mask token)或位置嵌入(positional embedding)。直到 ViT 的出现,图像被分割为固定大小的 patch 并通过 Transformer 处理,才为掩码预测提供了舞台。
- 信息密度差异:语言是人类生成的密集语义信号,而图像是自然信号,充满了空间冗余。例如,图像中缺失的一个 patch 可能通过邻近区域的低级统计信息(如颜色或纹理)轻松推断出来,这使得掩码任务变得琐碎,难以诱导模型学习高级语义。
- 解码目标差异:在 BERT 中,解码器预测的是语义丰富的单词,而在视觉中,重建的目标通常是像素,这是一个较低层次的表示,可能无法直接推动模型学习适用于分类或检测等任务的高级特征。
针对这些挑战,MAE 提出了一种创新的解决方案,不仅克服了上述问题,还显著提高了训练效率和模型性能。
MAE 的核心设计
MAE 的核心思想可以概括为:随机掩盖输入图像的大部分 patch,然后通过一个不对称的编码器-解码器架构重建缺失的像素。以下是其关键设计细节:
1. 掩码策略:高比例随机掩盖
MAE 采用了极高的掩盖比例,例如 75%,这与 BERT 的 15% 形成了鲜明对比。为什么要这么做?因为图像的空间冗余使得低比例掩盖(如 20%-50%)不足以构成有意义的挑战。通过掩盖大部分 patch,MAE 迫使模型超越低级统计信息,学习全局的、语义化的表示。例如,重建一只猫的图像时,模型需要理解猫的整体形状和结构,而不仅仅是填充局部纹理。
掩盖过程很简单:将图像分成非重叠的 patch(遵循 ViT 的做法),然后随机选择一部分保留(例如 25%),其余移除。这种均匀随机的采样避免了中心偏差,同时最大化任务难度。
2. 不对称架构:轻量解码器与高效编码器
MAE 的架构如图 1 所示,具有显著的不对称性:
- 编码器:基于 ViT,仅处理未被掩盖的可见 patch(不含 mask token)。例如,若掩盖 75%,编码器只需处理 25% 的输入。这大大减少了计算量和内存需求,使得训练大模型(如 ViT-Huge)变得可行。编码器将这些 patch 通过线性投影和位置嵌入转化为潜在表示,再经过 Transformer 层处理。
- 解码器:负责重建完整图像,输入包括编码器的输出(可见 patch 的潜在表示)和 mask token(表示缺失 patch 的占位符)。解码器为所有 token 添加位置嵌入,然后通过一系列 Transformer 层预测每个掩盖 patch 的像素值。值得注意的是,解码器设计得非常轻量(例如计算量仅为编码器的 10%),仅在预训练中使用,推理时被丢弃。
这种不对称设计是 MAE 的效率之源:编码器专注于少量输入,解码器则处理完整集合,但由于其轻量特性,总体计算开销大幅降低(训练速度提升 3 倍以上)。
3. 重建目标:像素空间中的 MSE 损失
MAE 的解码器直接预测掩盖 patch 的像素值,使用均方误差(MSE)作为损失函数,且仅计算掩盖区域的损失(类似于 BERT)。此外,MAE 还探索了归一化像素值的变体,即对每个 patch 的像素计算均值和标准差后进行归一化。这种方法在实验中提升了表示质量,可能因为它增强了局部对比度。
相比之下,类似 BEiT 的方法预测离散 token(例如通过 DALL-E 的 dVAE 生成),而 MAE 的像素重建更为简单直接,避免了额外的预训练阶段和复杂计算。
实现细节与效率优化
MAE 的实现无需复杂的稀疏操作,具有高度实用性:
- 掩码生成:对所有 patch 生成 token 后,随机打乱列表并移除末尾部分(根据掩盖比例),得到编码器的输入。
- 解码准备:将编码器的输出与 mask token 拼接,然后逆打乱(unshuffle),确保 token 与原始位置对齐。
- 前向传播:编码器处理少量可见 patch,解码器处理完整 token 集合,最后输出重建图像。
这种基于洗牌的实现简单高效,适用于标准硬件(如 TPU),无需专门优化。
实验结果与洞见
MAE 在 ImageNet-1K 上的表现令人瞩目:
- 性能:使用 ViT-Huge 模型,MAE 实现了 87.8% 的 top-1 精度,仅依赖 ImageNet-1K 数据,超越了所有先前方法。
- 迁移学习:在 COCO 目标检测、ADE20K 语义分割等任务中,MAE 预训练显著优于监督预训练,且随着模型规模增大,收益更明显。
- 掩盖比例:75% 的掩盖比例在微调和线性探查中表现最佳,凸显了高掩盖任务的挑战性和有效性(见图 5)。
- 数据增强:MAE 对数据增强依赖极低,甚至仅用中心裁剪也能取得不错效果,这与依赖强增强的对比学习(如 SimCLR)形成鲜明对比。
与现有方法的对比
- 对比 BERT:BERT 的掩盖比例较低(15%),而 MAE 的高掩盖比例适应了图像的冗余特性;BERT 的解码器简单(MLP),而 MAE 的解码器需处理像素重建,设计更关键。
- 对比 BEiT:BEiT 预测 token,依赖额外预训练的 dVAE,而 MAE 直接重建像素,更简单高效,且性能更优。
- 对比对比学习:MoCo v3 等对比学习方法在线性探查中表现更好,但在微调和迁移任务中,MAE 的非线性特征更具优势(见图 9)。
为什么 MAE 有效?
MAE 的成功源于其对图像特性与任务设计的深刻理解:
- 减少冗余:高掩盖比例消除了空间冗余,迫使模型学习全局语义。
- 效率与规模:不对称架构降低了计算成本,使大规模 ViT 模型的预训练成为可能。
- 语义推理:重建任务要求模型理解场景的整体结构(见图 4),这隐式地诱导了丰富的潜在表示。
未来展望
MAE 为视觉自监督学习开辟了新方向。它表明,简单算法结合可扩展设计能在计算机视觉中实现类似 NLP 的突破。未来的研究可以探索:
- 将 MAE 应用于视频或多模态数据。
- 优化解码器设计,进一步提升表示质量。
- 结合高级网络结构(如 ConvNeXt),推动性能极限。
结语
Masked Autoencoders 证明了“掩码预测”不仅适用于语言,也能在视觉领域大放异彩。通过高比例掩盖、不对称架构和像素重建,MAE 在效率与性能上实现了双赢。对于深度学习研究者来说,MAE 不仅是一个值得尝试的工具,更是一个启发性的范例:简单往往蕴含深刻的力量。你准备好用 MAE 开启你的视觉预训练之旅了吗?
代码实现
下面将提供一个基于 Python 和 PyTorch 的 Masked Autoencoders (MAE) 的训练代码和应用代码示例,面向熟悉 Transformer、BERT 和 ViT 的深度学习研究者。会尽量简化代码,同时保留 MAE 的核心思想,并详细解释每个部分的功能和设计理念。代码将基于 ViT 架构,使用 ImageNet 数据风格的图像输入。
MAE 训练代码与应用代码
前提条件
- 环境:PyTorch 2.x,torchvision,numpy。
- 数据:假设使用 ImageNet 数据(或其他图像数据集),输入为 224x224 图像。
- 目标:实现 MAE 的预训练和下游任务微调。
训练代码:MAE 预训练
以下是 MAE 预训练的完整代码,包含编码器、解码器和掩码逻辑。
import torchimport torch.nn as nnimport torch.optim as optimfrom torchvision import datasets, transformsfrom torch.utils.data import DataLoaderimport numpy as npimport uuid# 定义 ViT 风格的 Patch Embeddingclass PatchEmbed(nn.Module): def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): super().__init__() self.num_patches = (img_size // patch_size) ** 2 self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) def forward(self, x): # 输入: [B, C, H, W] -> 输出: [B, num_patches, embed_dim] x = self.proj(x).flatten(2).transpose(1, 2) return x# 定义 Transformer 编码器块class TransformerEncoder(nn.Module): def __init__(self, embed_dim=768, num_heads=12, mlp_ratio=4., drop=0.): super().__init__() self.norm1 = nn.LayerNorm(embed_dim) self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=drop) self.norm2 = nn.LayerNorm(embed_dim) mlp_hidden_dim = int(embed_dim * mlp_ratio) self.mlp = nn.Sequential( nn.Linear(embed_dim, mlp_hidden_dim), nn.GELU(), nn.Linear(mlp_hidden_dim, embed_dim), nn.Dropout(drop) ) def forward(self, x): x = x + self.attn(self.norm1(x), self.norm1(x), self.norm1(x))[0] x = x + self.mlp(self.norm2(x)) return x# MAE 模型class MAE(nn.Module): def __init__(self, img_size=224, patch_size=16, embed_dim=768, encoder_depth=12, decoder_depth=4, num_heads=12): super().__init__() self.patch_embed = PatchEmbed(img_size, patch_size, 3, embed_dim) self.num_patches = self.patch_embed.num_patches self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim)) self.encoder = nn.ModuleList([TransformerEncoder(embed_dim, num_heads) for _ in range(encoder_depth)]) self.decoder_embed = nn.Linear(embed_dim, embed_dim, bias=True) self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) self.decoder_pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, embed_dim)) self.decoder = nn.ModuleList([TransformerEncoder(embed_dim, num_heads) for _ in range(decoder_depth)]) self.decoder_pred = nn.Linear(embed_dim, patch_size * patch_size * 3) # 重建像素 def random_masking(self, x, mask_ratio=0.75): B, N, D = x.shape # [batch_size, num_patches, embed_dim] keep_num = int(N * (1 - mask_ratio)) idx = torch.rand(B, N).argsort(dim=1) # 随机打乱索引 keep_idx = idx[:, :keep_num] # 保留的 patch 索引 mask_idx = idx[:, keep_num:] # 掩盖的 patch 索引 x_masked = torch.gather(x, dim=1, index=keep_idx.unsqueeze(-1).expand(-1, -1, D)) return x_masked, keep_idx, mask_idx def forward_encoder(self, x): x = self.patch_embed(x) + self.pos_embed # 添加位置嵌入 x_masked, keep_idx, mask_idx = self.random_masking(x) # 随机掩盖 for block in self.encoder: x_masked = block(x_masked) return x_masked, keep_idx, mask_idx def forward_decoder(self, x, keep_idx, mask_idx): B = x.shape[0] x = self.decoder_embed(x) # 调整编码器输出维度 full_x = torch.zeros(B, self.num_patches, x.shape[-1], device=x.device) # 将编码器输出放回原始位置 full_x.scatter_(1, keep_idx.unsqueeze(-1).expand(-1, -1, x.shape[-1]), x) # 填充 mask token mask_tokens = self.mask_token.expand(B, self.num_patches - x.shape[1], -1) full_x.scatter_(1, mask_idx.unsqueeze(-1).expand(-1, -1, x.shape[-1]), mask_tokens) full_x = full_x + self.decoder_pos_embed # 添加解码器位置嵌入 for block in self.decoder: full_x = block(full_x) full_x = self.decoder_pred(full_x) # 预测像素 return full_x def forward(self, x): # 编码器 latent, keep_idx, mask_idx = self.forward_encoder(x) # 解码器 recon = self.forward_decoder(latent, keep_idx, mask_idx) return recon, mask_idx# 数据预处理transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])# 假设使用 CIFAR-10 替代 ImageNettrain_dataset = datasets.CIFAR10(root=\'./data\', train=True, download=True, transform=transform)train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)# 训练函数def train_mae(): device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\") model = MAE().to(device) optimizer = optim.AdamW(model.parameters(), lr=1.5e-4, weight_decay=0.05) num_epochs = 10 # 简化演示,实际需更多 epoch for epoch in range(num_epochs): model.train() total_loss = 0 for batch_idx, (images, _) in enumerate(train_loader): images = images.to(device) optimizer.zero_grad() recon, mask_idx = model(images) # 计算掩盖区域的 MSE 损失 B, N, _ = recon.shape target = model.patch_embed(images).view(B, N, -1) loss = ((recon - target) ** 2).mean(dim=-1) # 按 patch 计算 MSE mask_loss = loss.gather(1, mask_idx).mean() # 只计算掩盖区域损失 mask_loss.backward() optimizer.step() total_loss += mask_loss.item() print(f\"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(train_loader):.4f}\") torch.save(model.state_dict(), \"mae_pretrained.pth\")if __name__ == \"__main__\": train_mae()
代码解释:MAE 预训练
-
PatchEmbed:
- 功能:将输入图像分割为 16x16 的 patch,并通过卷积投影到嵌入维度(768)。
- 细节:输入 [B, 3, 224, 224] -> 输出 [B, 196, 768](196 = 224² / 16²)。
- 意义:这是 ViT 的基础,将图像转为序列形式。
-
TransformerEncoder:
- 功能:标准的 Transformer 块,包含多头自注意力(MultiheadAttention)和 MLP。
- 细节:MAE 的编码器和解码器都复用此模块,但深度和宽度不同(编码器深,解码器浅)。
- 意义:处理 patch 间的关系,捕获全局信息。
-
MAE 模型:
- 初始化:
patch_embed
:将图像转为 patch 嵌入。pos_embed
:编码器位置嵌入。mask_token
:解码器中表示掩盖 patch 的可学习参数。decoder_pred
:将解码器输出转为像素值(16x16x3)。
- random_masking:
- 功能:随机掩盖 75% 的 patch,保留 25%。
- 细节:通过打乱索引并分割实现无替换采样,返回保留的 patch 和掩盖索引。
- 意义:高掩盖比例增加任务难度,推动语义学习。
- forward_encoder:
- 功能:仅处理可见 patch。
- 细节:输入完整 patch 嵌入,输出掩盖后的潜在表示。
- 意义:减少计算量,仅 25% patch 进入编码器。
- forward_decoder:
- 功能:重建完整图像。
- 细节:将编码器输出和 mask token 组合,添加位置嵌入后通过解码器预测像素。
- 意义:轻量解码器处理完整 token 集合,实现高效重建。
- forward:
- 功能:完整前向传播,返回重建结果和掩盖索引。
- 初始化:
-
训练逻辑:
- 损失函数:仅计算掩盖 patch 的 MSE,与 BERT 只计算掩码 token 损失类似。
- 优化器:AdamW,学习率 1.5e-4,遵循 ViT 的线性缩放规则。
- 细节:每轮迭代随机掩盖,计算损失并更新模型。
应用代码:下游任务微调
以下是基于预训练 MAE 的微调代码,用于图像分类任务。
import torchimport torch.nn as nnfrom torchvision import datasets, transformsfrom torch.utils.data import DataLoader# 微调模型:仅使用编码器class MAEForClassification(nn.Module): def __init__(self, pretrained_mae, num_classes=10): super().__init__() self.patch_embed = pretrained_mae.patch_embed self.pos_embed = pretrained_mae.pos_embed self.encoder = pretrained_mae.encoder self.cls_token = nn.Parameter(torch.zeros(1, 1, 768)) self.head = nn.Linear(768, num_classes) def forward(self, x): B = x.shape[0] x = self.patch_embed(x) + self.pos_embed # [B, 196, 768] cls_tokens = self.cls_token.expand(B, -1, -1) # [B, 1, 768] x = torch.cat((cls_tokens, x), dim=1) # [B, 197, 768] for block in self.encoder: x = block(x) x = x[:, 0] # 取 cls token 输出 x = self.head(x) return x# 数据预处理transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])train_dataset = datasets.CIFAR10(root=\'./data\', train=True, download=True, transform=transform)train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)test_dataset = datasets.CIFAR10(root=\'./data\', train=False, download=True, transform=transform)test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)# 微调函数def finetune_mae(): device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\") # 加载预训练模型 pretrained_mae = MAE() pretrained_mae.load_state_dict(torch.load(\"mae_pretrained.pth\")) model = MAEForClassification(pretrained_mae).to(device) optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05) criterion = nn.CrossEntropyLoss() num_epochs = 5 # 训练 for epoch in range(num_epochs): model.train() total_loss = 0 for images, labels in train_loader: images, labels = images.to(device), labels.to(device) optimizer.zero_grad() outputs = model(images) loss = criterion(outputs, labels) loss.backward() optimizer.step() total_loss += loss.item() print(f\"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(train_loader):.4f}\") # 测试 model.eval() correct = 0 total = 0 with torch.no_grad(): for images, labels in test_loader: images, labels = images.to(device), labels.to(device) outputs = model(images) _, predicted = torch.max(outputs, 1) total += labels.size(0) correct += (predicted == labels).sum().item() print(f\"Accuracy: {100 * correct / total:.2f}%\")if __name__ == \"__main__\": finetune_mae()
代码解释:MAE 微调
-
MAEForClassification:
- 功能:基于预训练的 MAE 编码器构建分类模型。
- 细节:
- 复用预训练的
patch_embed
和encoder
。 - 添加
cls_token
(类似 BERT),用于提取全局特征。 head
是一个线性层,将 cls token 输出映射到类别数。
- 复用预训练的
- 意义:丢弃解码器,仅用编码器进行特征提取。
-
前向传播:
- 流程:图像 -> patch 嵌入 -> 添加 cls token -> 编码器 -> 分类头。
- 细节:完整处理所有 patch(无掩盖),输出分类 logits。
- 意义:利用预训练的表示进行下游任务。
-
微调逻辑:
- 损失函数:交叉熵损失,用于分类任务。
- 优化器:AdamW,学习率较高(1e-3),因为微调需要调整参数。
- 测试:计算分类精度,验证预训练效果。
代码特点与注意事项
-
训练效率:
- MAE 的编码器仅处理 25% patch,显著降低计算量。
- 解码器轻量设计进一步优化预训练。
-
可扩展性:
- 可通过增加
encoder_depth
或embed_dim
扩展模型容量。 - 支持更大的图像(如 448x448)或更小的 patch(如 14x14)。
- 可通过增加
-
局限性:
- 示例使用 CIFAR-10,实际应用需 ImageNet 数据和更长训练时间(800-1600 epochs)。
- 未实现数据增强(可添加 RandomResizedCrop)。
-
与论文一致性:
- 掩盖比例 75%、不对称架构、像素重建目标均与论文一致。
- 简化了部分细节(如位置嵌入用固定值而非正弦函数)。
如何运行
- 安装依赖:
pip install torch torchvision numpy
- 运行预训练:
python mae_train.py
- 运行微调:
python mae_finetune.py
总结
MAE 的训练代码实现了高比例掩盖和不对称架构的核心思想,通过预训练捕获强大的视觉表示。应用代码展示了如何将预训练编码器用于下游任务,体现了 MAE 的实用性。对于深度学习研究者,这是一个可直接上手的基础实现,可以进一步优化或扩展以适配具体需求。希望这能帮助你深入理解 MAE 的实际操作!
问题解释
问题 1:如何将预训练编码器用于下游任务?为什么只有编码器,不用解码器?
如何将预训练编码器用于下游任务
在 MAE 的预训练完成后,编码器(Encoder)被提取出来用于下游任务(如图像分类、目标检测等)。具体步骤如下:
-
加载预训练权重:
- 从预训练模型中加载编码器的参数(例如
patch_embed
,pos_embed
, 和encoder
的权重),丢弃解码器部分。
- 从预训练模型中加载编码器的参数(例如
-
调整输入处理:
- 在预训练中,编码器只处理部分未掩盖的 patch(例如 25%),但在下游任务中,输入是完整的图像(所有 patch)。因此,需要调整前向传播逻辑,让编码器处理所有 patch,而不是掩盖后的子集。
-
添加任务特定头(Head):
- 根据下游任务需求,在编码器输出后添加一个任务特定的层。例如:
- 分类任务:添加一个线性层(
nn.Linear
),将编码器输出的特征映射到类别数。 - 检测任务:结合特征金字塔网络(FPN)等结构,输出边界框和类别预测。
- 分类任务:添加一个线性层(
- 根据下游任务需求,在编码器输出后添加一个任务特定的层。例如:
-
微调(Fine-tuning):
- 使用下游任务的数据集(如 ImageNet 分类数据),对编码器和任务头进行端到端微调。微调时通常使用较小的学习率(如 1e-3),以保留预训练学到的表示,同时适配新任务。
在提供的微调代码中,MAEForClassification
类实现了上述过程:
- 复用预训练的
patch_embed
和encoder
。 - 添加
cls_token
(类似 BERT),用于提取全局特征。 - 通过
head
将特征映射到分类类别。
为什么只有编码器,不用解码器?
MAE 的设计目标是学习通用的视觉表示,而不是生成图像。解码器在预训练中的作用是辅助任务(像素重建),而下游任务通常需要特征提取而非图像重建。原因如下:
-
预训练目标:
- MAE 的预训练任务是重建被掩盖的 patch,这是一个自监督的“借口任务”(pretext task)。通过这个任务,编码器被迫学习图像的语义和结构信息,形成强大的潜在表示。
- 解码器仅用于计算重建损失,帮助编码器优化,但它本身不直接生成下游任务所需的特征。
-
下游任务需求:
- 下游任务(如分类、检测)需要的是特征表示,而不是像素级重建。例如,分类任务需要判断图像内容是“猫”还是“狗”,而不需要生成图像。
- 编码器输出的潜在表示(latent representation)已经包含了丰富的语义信息,可以直接用于这些任务。
-
效率与通用性:
- 解码器是轻量设计的(计算量远低于编码器),其作用局限于预训练阶段。保留解码器会增加不必要的计算开销,且对下游任务无直接帮助。
- 只使用编码器符合自监督学习的常见模式(如 BERT 只用编码器部分),保持模型的通用性和灵活性。
因此,在下游任务中,解码器被丢弃,只保留编码器进行特征提取和任务适配。这也是 MAE 的不对称架构带来的效率优势之一。
问题 2:预训练中 forward_encoder
的作用是什么?为什么只有掩码部分经过 Encoder,其他部分不经过?
代码分析
让我们逐行分析 forward_encoder
的实现:
def forward_encoder(self, x): x = self.patch_embed(x) + self.pos_embed # 添加位置嵌入 x_masked, keep_idx, mask_idx = self.random_masking(x) # 随机掩盖 for block in self.encoder: x_masked = block(x_masked) return x_masked, keep_idx, mask_idx
-
x = self.patch_embed(x) + self.pos_embed
:- 功能:将输入图像(例如 [B, 3, 224, 224])通过
PatchEmbed
转为 patch 嵌入([B, 196, 768],其中 196 = 224² / 16²),并添加位置嵌入(pos_embed
)。 - 输出:
x
是所有 patch 的完整嵌入表示,维度为 [batch_size, num_patches, embed_dim]。 - 意义:这是 ViT 的标准步骤,将图像转为 Transformer 可处理的序列形式。
- 功能:将输入图像(例如 [B, 3, 224, 224])通过
-
x_masked, keep_idx, mask_idx = self.random_masking(x)
:- 功能:对
x
进行随机掩盖,默认掩盖比例为 75%。 - 实现:
random_masking
首先生成随机索引,打乱所有 patch 的顺序。- 保留前 25% 的 patch(
keep_idx
),掩盖剩余 75%(mask_idx
)。 x_masked
是保留的 patch 子集,维度变为 [B, 49, 768](49 = 196 * 0.25)。
- 输出:
x_masked
:未掩盖的 patch(25%)。keep_idx
:保留 patch 的索引。mask_idx
:掩盖 patch 的索引。
- 意义:通过高比例掩盖,创建具有挑战性的自监督任务,避免模型仅依赖局部冗余信息。
- 功能:对
-
for block in self.encoder: x_masked = block(x_masked)
:- 功能:将
x_masked
(未掩盖的 patch)输入到 Transformer 编码器块中,逐层处理。 - 细节:编码器由多个 Transformer 块组成(例如 12 层),每个块包含多头自注意力和 MLP。
- 输出:
x_masked
经过编码器后仍是 [B, 49, 768],但包含了更丰富的潜在表示。 - 意义:编码器只处理未掩盖的 patch,学习其语义和上下文信息。
- 功能:将
-
return x_masked, keep_idx, mask_idx
:- 功能:返回编码器处理后的特征(
x_masked
)以及掩盖相关索引。 - 意义:这些返回值用于后续解码器重建完整图像。
- 功能:返回编码器处理后的特征(
为什么只有掩码部分经过 Encoder,其他部分不经过?
这里需要澄清一个误解:经过编码器的是未掩盖的 patch(visible patches),而不是掩盖的部分(masked patches)。具体原因如下:
-
MAE 的不对称设计:
- MAE 的核心创新是编码器只处理未掩盖的 patch(例如 25%),而掩盖的 patch(75%)由解码器通过
mask_token
表示。 - 在
forward_encoder
中,x_masked
是未掩盖的子集,只有这部分被送入编码器。这减少了计算量(只需处理 25% 的数据),提高了训练效率。
- MAE 的核心创新是编码器只处理未掩盖的 patch(例如 25%),而掩盖的 patch(75%)由解码器通过
-
掩盖部分的作用:
- 掩盖的 patch(由
mask_idx
表示)在编码器阶段被移除,不参与计算。它们在解码器阶段通过mask_token
重新引入,用于重建任务。 - 这种设计迫使编码器专注于未掩盖 patch 的表示,而解码器负责推断掩盖区域的内容,形成一个完整的自监督循环。
- 掩盖的 patch(由
-
与下游任务的区别:
- 在预训练中,编码器只看到部分 patch(25%),这是为了学习鲁棒的表示。
- 在下游任务中,编码器处理所有 patch(100%),因为任务目标是利用完整图像信息进行预测,而不是重建。
工作流程总结
- 输入:完整图像 -> 所有 patch 的嵌入(196 个)。
- 掩盖:随机保留 25%(49 个),掩盖 75%(147 个)。
- 编码器:只处理 49 个未掩盖 patch,输出潜在表示。
- 解码器:接收编码器输出 + mask token,重建所有 196 个 patch 的像素。
这种机制确保了编码器在预训练中学会从部分信息推断全局语义,而下游任务则利用完整输入提取特征。
综合解答
-
如何用编码器:
- 预训练后,编码器捕获了通用视觉表示,通过微调适配下游任务(如分类)。只需加载编码器权重,添加任务头即可。
-
为什么不用解码器:
- 解码器仅用于预训练的重建任务,下游任务需要特征而非像素生成,因此丢弃解码器。
-
forward_encoder
的作用:- 将图像转为 patch 嵌入,随机掩盖 75%,只让未掩盖的 25% 通过编码器,输出潜在表示和索引,用于后续重建。
- 掩盖部分不经过编码器,而是由解码器通过 mask token 处理。
这种设计既高效(编码器只处理少量 patch),又有效(高掩盖比例诱导语义学习),是 MAE 的核心优势。
后记
2025年3月26日14点10分于上海,在grok 3大模型辅助下完成。