【CVPR 2025】 炸场新作!MambaVision 横空出世,重新定义视觉主干网络
【CVPR 2025】 炸场新作!MambaVision 横空出世,重新定义视觉主干网络
当 Transformer 的瓶颈愈发明显,谁能扛起视觉模型创新的大旗?CVPR 2025 最新收录论文重磅揭晓答案 ——MambaVision,一款颠覆性的混合型 Mamba-Transformer 视觉主干网络,正在改写视觉领域的技术版图!
MambaVision 并非简单的架构堆砌,而是深度融合 Mamba 的高效序列建模能力与 Transformer 的强表征优势。Mamba 独特的线性复杂度机制,成功解决了传统 Transformer 在长序列处理时计算量爆炸的难题,让模型在处理海量视觉信息时也能 “健步如飞”;Transformer 强大的自注意力机制,则保障了视觉特征的精准捕捉,二者结合堪称 “天作之合”。实验数据更是令人惊叹,在图像分类、目标检测、语义分割等多个主流视觉任务上,MambaVision 均实现了性能跃升,超越一众 SOTA 模型!
不仅如此,MambaVision 还暗藏诸多 “黑科技”。通过精心设计的混合模块,实现了不同架构优势的无缝衔接,让模型既能快速处理全局信息,又能聚焦关键细节;轻量化设计使其在资源受限设备上也能游刃有余,真正做到 “高性能” 与 “低能耗” 兼得。
这一成果的诞生,无疑为视觉领域带来了全新的解题思路。无论是学术界的算法攻坚,还是工业界的落地应用,MambaVision 都展现出巨大潜力。想知道 MambaVision 背后还有哪些硬核技术?它将如何引领视觉技术的下一波浪潮?我们一起来解锁 CVPR 顶会论文中的全部细节,一起见证视觉 AI 的新巅峰!
摘要
- 重新设计Mamba公式,以增强其对视觉特征进行有效建模的能力。
- 在 Mamba 架构的最终层配备多个自注意力块,可以极大地提高模型捕获长距离空间依赖的能力。
诸论
1.注意力机制相对于序列长度的二次复杂度使得Transformer在训练和部署方面的计算成本很高
解释:计算时,序列中每个位置都要跟其他所有位置对比计算相关性,序列长度为n,复杂度是O(n²)。比如,处理100词的序列,要计算10000次相关性;序列增到1000词,就得算1000000次,计算量剧增。这导致训练和部署时,计算资源和时间成本都很高。
文中解决:“Mamba:具有选择性状态空间的线性时间序列建模”提出了新的空间状态模型,实现了线性时间复杂度,效果与transformer持平
2.Mamba 的核心贡献是一种新的选择机制,这种机制可以实现高效且依赖输入的选择操作。
3.Mamba的自回归公式虽然在需要顺序数据处理的任务中有效,但是图像的数据是没有顺序依赖关系的,之前提到的“Mamba:具有选择性状态空间的线性时间序列建模”限制了第一次前向传播中捕获全局上下文的能力
提出观点:视觉任务通常需要理解全局上下文才能对局部区域做出准确预测
4.Vision Mamba:基于双向状态空间模型的高效视觉表征学习”提出用双向SSM修改,但是还是缺乏全局上下文和空间理解问题,虽然双向的可能捕获上下文,但是需要在预测之前处理整个序列,引入显著延迟
5.复杂性会有过拟合风险,并不总是带来更高的准确率,并且现在ViT+CNN的架构仍然优于mamba
解决问题工作:
- 重新设计了Mamba模块,使其更适合视觉任务。
- 提出了一种混合架构,它包含我们提出的公式(即MambaVisionMixer和MLP)以及Transformer模块。
- 研究了不同的集成模式
- 结果表明:
在最终阶段利用多个自注意力块可以显著增强捕捉全局上下文和长距离空间依赖的能力。
使用混合架构提高图像吞吐量
利用基于 CNN 的残差块来快速提取更高分辨率特征的特征。
贡献
- 我们介绍了一种重新设计的视觉友好型 Mamba 模块,它在原始 Mamba 架构的基础上提高了准确性和图像吞吐量。
- 我们对 Mamba 和 Transformer 模块的集成模式进行了系统研究,并证明在最终阶段加入自注意力模块可以显著提高模型捕获全局上下文和长距离空间依赖的能力。
- 这是一种新颖的混合 Mamba Transformer 模型。分层我们介绍了MambaVision,MambaVision 在 ImageNet-1K数据集上实现了 Top-1 和图像吞吐量权衡方面的新 SOTA Pareto 前沿。
相关工作
一、Conv-Based(基于卷积的模型)
- 核心进展:
CNN 长期作为视觉 backbone 的基石,近年通过融合 Transformer 思想实现性能突破。例如:- ConvNeXt:通过扩大宽度、使用更大卷积核和层归一化,缩小了与 Transformer 的性能差距。
- RegNetY:通过设计空间分析实现系统性网络设计。
- EfficientNetV2:利用神经架构搜索和渐进学习优化效率 - 精度权衡。
- 局限性:
CNN 天然缺乏全局感受野,难以捕捉长距离空间依赖,这是其在复杂视觉任务(如语义分割)中的主要瓶颈。
二、Transformer-Based(基于 Transformer 的模型)
- 核心进展:
ViT 引入自注意力机制,通过全局上下文建模提升视觉任务性能,代表性工作包括:- DeiT:通过蒸馏训练在中小规模数据集上提升 ViT 性能。
- Swin Transformer:采用分层架构和滑动窗口自注意力,平衡局部与全局上下文,降低计算复杂度。
- Twins/PVT:通过空间可分离自注意力和分层结构优化效率。
- 局限性:
自注意力的二次复杂度(与序列长度成平方关系)导致计算和内存开销大,限制了模型在高分辨率输入下的效率。
三、Conv-Transformer(卷积 - Transformer 混合模型)
- 核心思路:融合 CNN 的局部特征提取优势与 Transformer 的全局建模能力。
- 代表性工作:
- CoAT/CrossViT:直接结合卷积与自注意力,增强特征学习。
- NextViT:系统性将 CNN 式处理融入 Transformer,优化工业部署效率。
- FasterViT:通过分层注意力设计提升吞吐量,平衡精度与速度。
- 优势:通过混合架构在精度、效率之间取得更好平衡,成为近年研究热点。
四、Mamba-Based(基于 Mamba 的模型)
- Mamba 核心优势:
作为线性复杂度的状态空间模型(SSM),Mamba 在自然语言处理中已证明高效处理长序列的能力,近年被引入视觉领域。 - 现有视觉 Mamba 模型及不足:
- Vim(Vision Mamba):
- 提出双向 SSM 处理 tokens,试图捕捉全局上下文。
- 缺陷:双向处理需等待整个序列输入,引入显著延迟,训练复杂度高,准确率提升有限。
- VMamba:
- 设计 Cross-Scan Module(CSM)通过四向扫描整合上下文,结合深度卷积和分层结构。
- 缺陷:CSM 的感受野受限于扫描路径,全局建模能力仍不足,且结构复杂。
- EfficientVMamba:
- 对高分辨率使用 SSM,低分辨率使用 CNN,但分辨率处理策略与 MambaVision 相反,导致精度和吞吐量落后。
- SiMBA:
- 通过 EinFFT 通道建模提升 Mamba 稳定性,但未解决空间理解的根本局限。
- Vim(Vision Mamba):
- MambaVision 的创新定位:
- 首个混合架构:首次系统性融合 Mamba 与 Transformer,在最后阶段引入自注意力块,弥补 Mamba 单向处理的全局上下文缺陷。
- 效率优化:前两阶段使用 CNN 快速提取高分辨率特征,后两阶段结合 Mamba 的高效序列处理与自注意力的全局建模,实现吞吐量与精度的双提升(如图 1 所示,MambaVision 在 ImageNet 上达到新的 Pareto 最优前沿)。
方法
论文的方法论部分详细阐述了 MambaVision 的架构设计,核心围绕分层混合架构与视觉友好的 Mamba 块改进,旨在解决纯 Mamba 模型在视觉任务中的局限性(如全局上下文不足、空间建模低效)。以下从宏观架构、微观设计、关键创新三方面展开深度解析:
一、宏观架构:分层设计与模块分工
1. 四阶段分层架构(图 2)
MambaVision 采用金字塔式分层结构,将视觉特征提取分为四个阶段,按分辨率递减划分为:
-
Stage 1-2(高分辨率):CNN 主导的快速特征提取
- 输入处理:通过两层 3×3 卷积(stride=2)组成的 stem 模块,将原始图像(H×W×3)转换为 H/4×W/4×C 的嵌入,保留高频空间细节。
- 残差块设计:使用类 ResNet 的残差块(公式 1),包含 GELU 激活和 BN 层,通过短连接缓解梯度消失,高效提取局部纹理和边缘特征。
- 下采样:阶段间通过 3×3 卷积(stride=2)将分辨率减半,逐步降低空间维度、增加通道数,为后续阶段减少计算量。
-
Stage 3-4(低分辨率):Mamba-Transformer 混合模块
- 核心目标:在低分辨率下(如 14×14、7×7),利用 Mamba 的线性复杂度处理长序列,同时通过自注意力补充全局上下文。
- 模块配比:每个阶段包含 N 层,前 N/2 层为 MambaVision Mixer(高效局部 - 长程建模),后 N/2 层为自注意力块(全局依赖捕捉),形成 “先 Mamba 后 Transformer” 的递进式设计(表 5 ablation 证明此模式最优)。
2. 设计逻辑
- 分工协同:
- CNN 阶段(Stage1-2):聚焦高分辨率下的局部特征提取,利用卷积的权值共享和局部连接优势,避免直接对高分辨率输入使用 Mamba/Transformer 的高计算成本。
- 混合阶段(Stage3-4):低分辨率下序列长度缩短(如 7×7 对应 49 tokens),自注意力的二次复杂度(O (N²))可控,同时 Mamba 的线性复杂度(O (N))在长序列处理中保持高效,两者结合平衡全局与局部建模。
- 分辨率递减优势:
随着分辨率降低,特征图语义抽象度提升,此时引入自注意力可更高效地建模全局关系(如物体各部分关联),避免早期阶段因语义信息不足导致的注意力冗余。
二、微观架构:MambaVision Mixer 与自注意力设计
1. 设计动机
原始 Mamba 在视觉任务中存在两大局限:
- 因果约束:Mamba 的因果卷积强制按序列顺序处理,不适合视觉任务中并行的空间建模需求。
- 空间信息损失:纯 SSM(状态空间模型)路径缺乏显式的空间特征提取能力。
MambaVision Mixer 通过对称双分支结构解决上述问题:
2. 数学公式
- 分支 1(SSM 路径):通过选择性扫描(Selective Scan)捕获长程依赖。
- 分支 2(对称路径):通过常规卷积和激活函数保留空间特征。
- 融合策略:直接拼接两分支输出,强制模型学习互补信息。
```pythonclass MambaVisionMixer(nn.Module): def __init__(self, in_channels, hidden_channels): super().__init__() self.in_proj = nn.Linear(in_channels, hidden_channels) self.conv1 = nn.Conv1d(hidden_channels//2, hidden_channels//2, kernel_size=3, padding=1) self.conv2 = nn.Conv1d(hidden_channels//2, hidden_channels//2, kernel_size=3, padding=1) self.silu = nn.SiLU() self.out_proj = nn.Linear(hidden_channels, in_channels) def forward(self, x): # 输入投影并分割为两分支 x = self.in_proj(x) xz = x.chunk(2, dim=-1) # 对称分支(空间特征增强) x2 = self.silu(self.conv2(xz[1].transpose(1, 2)).transpose(1, 2)) # SSM分支(长程依赖建模) x_proj = self.silu(self.conv1(xz[0].transpose(1, 2)).transpose(1, 2)) x1 = selective_scan_fn(x_proj) # 实际中需实现SSM核心逻辑 # 分支融合 x_out = torch.cat([x1, x2], dim=-1) x_out = self.out_proj(x_out) return x_out
4. 关键技术点
- 因果卷积替换:将原始 Mamba 的因果卷积改为常规卷积(
kernel_size=3, padding=1
),允许并行处理空间维度。 - 通道分割:将输入通道均分为两部分(各占
hidden_channels//2
),分别处理后再拼接,平衡计算成本与表达能力。 - 选择性扫描(Selective Scan):通过
selective_scan_fn
动态调整 SSM 参数,过滤无关信息,适应视觉特征的多样性。
二、自注意力模块设计
1. 设计位置与作用
- 位置:仅在 Stage3-4 的后 N/2 层使用自注意力,而非全网络部署。
- 作用:在特征抽象度最高的阶段,显式建模全局像素关联,弥补 Mamba 单向扫描的局限性。
4. 关键技术点
- 窗口化优化:论文未明确提及,但实践中可能采用局部窗口自注意力(如 Swin Transformer)降低计算复杂度。
- 维度匹配:通过
reshape
和permute
实现多头注意力的并行计算,保持通道数不变。
三、MambaVision Layer 整体架构
1. 层结构设计
其中:
Mixer
在前 N/2 层为 MambaVision Mixer,后 N/2 层为自注意力。Norm
通常为层归一化(Layer Normalization)。
四、关键创新点与技术突破
1. 混合架构的理论依据
- Mamba 的局限性破解:
原始 Mamba 的因果处理适合文本等序列任务,但视觉任务需要并行的空间建模。通过前半段 Mamba 的选择性扫描(保留长程依赖建模效率)+ 后半段自注意力(显式全局关联),实现 “线性复杂度打底,全局能力增强” 的双赢。 - CNN-Transformer-Mamba 的三角协同:
- CNN:快速提取底层视觉特征(边缘、纹理),高分辨率下效率远超 Transformer/Mamba。
- Mamba:在中等分辨率(如 28×28→14×14)下,以线性复杂度处理序列,捕获跨区域的长程依赖(如物体局部部件的关联)。
- Transformer:在低分辨率(7×7)下,以可控的二次复杂度(49²=2401)建模全局上下文(如整个物体的结构、多物体空间布局)。
2. 效率与精度的平衡艺术
-
计算复杂度优化:
- CNN 阶段:3×3 卷积的 FLOPs 随分辨率降低呈平方级下降,Stage1-2 的总计算量仅占整体的~30%。
- Mamba 阶段:线性复杂度(O (N))处理中等长度序列(如 14×14=196 tokens),FLOPs 远低于同层自注意力(O (N²))。
- 自注意力阶段:低分辨率下 N 小,O (N²) 可控(如 7×7=49 tokens 时,FLOPs 仅为 224×224 分辨率下的 1/256)。
最终,MambaVision-B 的 GFLOPs 比 MaxViT-B 低 56%,但精度更高(表 1)。
-
吞吐量提升关键:
- CNN 和 Mamba 的卷积 / 线性操作高度优化,适合 GPU 并行计算;自注意力仅在最后阶段少量使用,避免成为瓶颈。
- 单向处理(非双向 SSM):相比 Vim 的双向处理(需等待整个序列输入),MambaVision 的单向 Mamba + 单向自注意力支持流式处理,延迟更低(图 1 中吞吐量显著高于 VMamba/Vim)。
- CNN 和 Mamba 的卷积 / 线性操作高度优化,适合 GPU 并行计算;自注意力仅在最后阶段少量使用,避免成为瓶颈。
3. 消融实验验证设计合理性
-
Mixer 设计验证(表 4):
- 因果卷积→常规卷积(+0.4% Top-1):证明视觉任务中并行处理优于序列依赖。
- 增加对称分支(+1.0% Top-1):验证空间特征补充的必要性。
- 拼接融合(vs 门控融合):拼接使模型直接学习双分支互补,避免门控机制的信息损耗。
-
混合模式验证(表 5):
- 最后 N/2 层自注意力(vs 随机 / 前半段):证明高层特征需要自注意力显式建模全局关系,早期阶段 Mamba 的隐式长程建模已足够。
五、与现有 Mamba 视觉模型的核心区别
对称双分支 Mixer
总结:方法论的核心价值
MambaVision 的方法论通过 “分层分工、模块互补” 的设计哲学,解决了纯 Mamba 模型在视觉任务中的两大痛点:
- 空间建模低效:前阶段 CNN 强化局部特征,Mixer 的对称分支补充空间细节,避免 SSM 的序列处理偏差。
- 全局依赖不足:后阶段自注意力显式捕捉长距离关系,与 Mamba 的隐式长程建模形成 “粗粒度全局 + 细粒度局部” 的特征层次。
这种设计不仅继承了 Mamba 的线性复杂度优势,还通过巧妙引入自注意力(仅在必要时使用),在精度 - 效率权衡上达到新高度,为后续混合视觉架构提供了可复用的设计范式(如 “早期高效特征提取 + 后期全局增强”)。
深度解析:4. 实验设置与 5. 结果分析
一、4. 实验设置(Experiments)
1. 图像分类实验
-
数据集与训练配置:
- 数据集:ImageNet-1K,包含 1000 类,128 万训练图像,5 万验证图像。
- 训练策略:遵循经典视觉模型训练范式(如ConvNeXt、Swin Transformer),训练 300 个 epoch,使用 32 张 NVIDIA A100 GPU,批量大小 4096(隐含在 “标准训练配方” 中)。
- 自注意力优化:阶段 3(14×14 分辨率)使用窗口大小 14,阶段 4(7×7 分辨率)使用窗口大小 7,平衡局部建模效率与全局依赖捕获(见附录 Table S.1,大窗口提升 0.1%-0.3% 准确率,吞吐量几乎不变)。
-
关键技术细节:
- 分层混合架构验证:通过控制阶段 3-4 中 MambaVision Mixer 与自注意力的比例(前 N/2 层 Mixer,后 N/2 层自注意力),验证 “先高效局部建模,后全局增强” 的设计合理性。
- 基线对比:覆盖四大类模型 ——
- Conv-Based:ConvNeXt、ResNet、EfficientNetV2
- Transformer-Based:Swin Transformer、DeiT、Twins
- Conv-Transformer:NextViT、FasterViT、MaxViT
- Mamba-Based:Vim、VMamba、EfficientVMamba
2. 下游任务实验
-
目标检测与实例分割(MS COCO):
- 骨干网络:使用预训练的 MambaVision 作为骨干,接入 Cascade Mask-RCNN 头,采用 ×3 学习率调度(训练 36 个 epoch,LR 先升后降),输入分辨率 1280×800。
- 对比模型:Swin Transformer、ConvNeXt、ResNet 等,确保参数量和 FLOPs 可比(如 MambaVision-T 与 Swin-T 参数均为~86M)。
-
语义分割(ADE20K):
- 骨干网络:预训练模型接入 UperNet 头,使用 8 张 A100 GPU 训练,输入分辨率 512×512,验证多分辨率特征融合能力(MambaVision 的分层特征对密集预测任务至关重要)。
二、5. 结果分析(Results)
1. ImageNet 分类:准确率与吞吐量双 SOTA
-
核心数据(Table 1):
模型 Top-1 Accuracy 吞吐量 (Img/Sec) GFLOPs 参数 (M) ConvNeXt-B 83.8% 1485 15.4 88.6 Swin-B 83.5% 535 15.1 87.9 VMamba-B 83.9% 645 15.4 89.0 MambaVision-B 84.2% 3670 15.0 97.7 - 优势解析:
- 准确率提升:相比纯 CNN(ConvNeXt-B +0.4%)和纯 Transformer(Swin-B +0.7%),混合架构通过 “CNN 快速特征提取 + Mamba 长程建模 + 自注意力全局增强” 实现更高特征表达能力。
- 吞吐量翻倍:MambaVision-B 吞吐量达 3670 Img/Sec,远超 VMamba-B(645)和 Swin-B(535),得益于:
- 前两阶段 CNN 的高效并行计算(3×3 卷积占比高,GPU 优化成熟);
- 后阶段自注意力仅在低分辨率(7×7)使用,计算量为 224×224 分辨率的 1/256。
- 效率碾压:MambaVision-B 的 GFLOPs(15.0)与 VMamba-B(15.4)相近,但准确率高 0.3%,吞吐量高 470%,打破 “准确率 - 吞吐量” 权衡瓶颈(图 1 的 Pareto 前沿)。
- 优势解析:
-
规模化能力:
- 首次在 Mamba-based 模型中扩展至 ImageNet-21K 预训练,MambaVision-L3 在 512 分辨率下 Top-1 达 88.1%,证明大模型 scalability(图 4)。
2. 目标检测与实例分割(MS COCO,Table 2)
-
核心指标(Box AP / Mask AP):
- MambaVision-T vs Swin-T:51.1% / 44.3% vs 50.4% / 43.7%,分别 + 0.7% / +0.6%,显示更强的物体定位与掩码预测能力。
- MambaVision-B vs ConvNeXt-B:52.8% / 45.7% vs 52.7% / 45.6%,以相同 FLOPs(964G)实现微弱但关键超越,验证混合架构在复杂场景的泛化性。
-
关键原因:
- 多尺度特征优势:阶段 3-4 的 MambaVision Mixer 捕获跨区域依赖(如物体部件关联),自注意力补全全局上下文(如物体与背景关系),提升检测头对遮挡、小物体的识别能力。
- 预训练迁移性:ImageNet 预训练的全局上下文建模能力有效迁移至密集预测任务,减少对额外数据增强的依赖。
3. 语义分割(ADE20K,Table 3)
-
核心指标(mIoU):
- MambaVision-T vs Swin-T:46.0% vs 44.5%,+1.5%,在语义复杂场景(如多物体重叠)中优势显著。
- MambaVision-B vs Focal-B:49.1% vs 49.0%,以更低参数(126M vs 126M)和 FLOPs(1342G vs 1354G)实现超越。
-
技术优势:
- 空间 - 序列联合建模:对称分支的卷积保留边缘、纹理等底层空间特征,SSM 分支捕获长程语义关联(如 “天空 - 云朵”“道路 - 车辆”),提升像素级分类精度。
- 分辨率鲁棒性:对高分辨率输入(512×512),Mamba 的线性复杂度避免计算爆炸,而 Transformer 的二次复杂度模型(如 Swin-B)出现性能饱和。
三、实验设计的底层逻辑与验证
1. 控制变量法验证混合架构
- 消融实验(隐含在结果分析中):
- 自注意力位置:仅在最后 N/2 层使用自注意力时,准确率比随机位置高 1.0%(Table 5),证明高层特征更需要全局建模。
- 对称分支必要性:移除对称分支导致 ImageNet 准确率下降 1.8%(Table 4),验证空间特征对视觉任务的不可替代性。
2. 工业级部署导向
- 吞吐量优先:在保证准确率的前提下,通过以下设计提升推理速度:
- CNN 阶段使用 3×3 卷积和 BN,比 Transformer 的 LayerNorm 更适合硬件加速;
- 避免双向 SSM(如 Vim 的双向扫描增加延迟),采用单向 Mamba + 单向自注意力,支持流式处理。
3. 跨任务泛化性证明
- 统一骨干优势:同一 MambaVision 骨干在分类、检测、分割任务均达 SOTA,验证其特征表示的通用性,降低多任务部署成本。
四、潜在局限与未来方向
- SSM 核心逻辑简化:代码中
selective_scan_fn
为模拟实现,实际需优化 SSM 参数动态生成(如论文中的 A、B、C 矩阵离散化),进一步提升长程建模效率。 - 多模态扩展:当前聚焦视觉任务,未来可探索与 NLP、语音的跨模态融合,利用 Mamba 的序列建模优势构建统一架构。
- 数据效率:在中小规模数据集(如 CIFAR-100)上的表现未提及,需验证低数据场景下的泛化能力。
总结:实验与结果的核心价值
MambaVision 通过精心设计的实验,从三个维度证明其创新性:
- 横向对比:在四大类模型中刷新 ImageNet 准确率 - 吞吐量边界,尤其碾压纯 Mamba 模型(如 VMamba-B 准确率 + 0.3%,吞吐量 + 470%)。
- 纵向验证:消融实验与混合模式分析,揭示 “CNN 快速特征提取→Mamba 长程建模→自注意力全局增强” 的逐层优化逻辑。
- 下游迁移:在检测、分割任务中的稳定提升,证明混合架构不仅是分类专用,更是通用视觉骨干的有效范式
这些结果为 “轻量高效骨干网络” 研究提供了新标杆,也为 Mamba 系列模型在工业界的落地奠定了实验基础。
代码解析
MambaVision
MambaVision 是一种混合架构,它结合了 Mamba 的高效序列处理能力和 Transformer 的全局建模能力,特别适合处理视觉任务中的长序列数据。该模型采用了分层设计,在不同阶段使用不同的模块,既能捕捉局部特征,又能处理长距离依赖关系.
class MambaVision(nn.Module): \"\"\" MambaVision, \"\"\" def __init__(self, dim, in_dim, depths, window_size, mlp_ratio, num_heads, drop_path_rate=0.2, in_chans=3, num_classes=1000, qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., layer_scale=None, layer_scale_conv=None, **kwargs): \"\"\" Args: dim: feature size dimension. depths: number of layers in each stage. window_size: window size in each stage. mlp_ratio: MLP ratio. num_heads: number of heads in each stage. drop_path_rate: drop path rate. in_chans: number of input channels. num_classes: number of classes. qkv_bias: bool argument for query, key, value learnable bias. qk_scale: bool argument to scaling query, key. drop_rate: dropout rate. attn_drop_rate: attention dropout rate. norm_layer: normalization layer. layer_scale: layer scaling coefficient. layer_scale_conv: conv layer scaling coefficient. \"\"\" super().__init__() # 计算最后一个阶段的特征维度 num_features = int(dim * 2 ** (len(depths) - 1)) self.num_classes = num_classes # 初始patch嵌入层,将输入图像转换为特征图 self.patch_embed = PatchEmbed(in_chans=in_chans, in_dim=in_dim, dim=dim) # 生成随机深度丢弃路径的概率,线性增加从0到drop_path_rate dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # 初始化多个阶段的网络层 self.levels = nn.ModuleList() for i in range(len(depths)): # 前两个阶段使用卷积结构,后两个阶段使用混合结构 conv = True if (i == 0 or i == 1) else False level = MambaVisionLayer(dim=int(dim * 2 ** i), depth=depths[i], num_heads=num_heads[i], window_size=window_size[i], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, conv=conv, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])], downsample=(i < 3), layer_scale=layer_scale, layer_scale_conv=layer_scale_conv, # 指定使用Transformer块的位置 transformer_blocks=list(range(depths[i]//2+1, depths[i])) if depths[i]%2!=0 else list(range(depths[i]//2, depths[i])), ) self.levels.append(level) # 最后的归一化层 self.norm = nn.BatchNorm2d(num_features) # 自适应平均池化层,将特征图转换为固定大小 self.avgpool = nn.AdaptiveAvgPool2d(1) # 分类头,将特征映射到类别空间 self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity() # 应用权重初始化 self.apply(self._init_weights)
解析:
- 这是 MambaVision 模型的主类,继承自 PyTorch 的 nn.Module。
- 初始化参数定义了模型的各种配置,包括维度、深度、窗口大小等。
- 模型采用分层设计,前两个阶段使用卷积结构,后两个阶段使用混合结构。
- drop_path_rate 参数用于实现随机深度丢弃,这是一种正则化技术,可以提高模型的泛化能力。
- PatchEmbed 是一个自定义模块,用于将输入图像转换为特征图。
- MambaVisionLayer 是模型的核心构建块,根据不同阶段的配置使用不同的结构。
def _init_weights(self, m): if isinstance(m, nn.Linear): # 对线性层使用截断正态分布初始化权重 trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): # 对LayerNorm层初始化偏置为0,权重为1 nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) elif isinstance(m, LayerNorm2d): # 对2D LayerNorm层初始化偏置为0,权重为1 nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) elif isinstance(m, nn.BatchNorm2d): # 对BatchNorm2d层初始化权重为1,偏置为0 nn.init.ones_(m.weight) nn.init.zeros_(m.bias)
解析:
- 这是权重初始化函数,根据不同的层类型应用不同的初始化策略。
- 线性层使用截断正态分布初始化,标准差为 0.02,这有助于防止梯度消失或爆炸。
- 各种归一化层的初始化保持一致性,确保模型训练的稳定性。
@torch.jit.ignore def no_weight_decay_keywords(self): return {\'rpb\'}
解析:
- 这个函数指定了不需要权重衰减的参数名称,rpb 可能代表相对位置偏置 (relative position bias),这是一种在注意力机制中常用的技术。
def forward_features(self, x): x = self.patch_embed(x) for level in self.levels: x = level(x) x = self.norm(x) x = self.avgpool(x) x = torch.flatten(x, 1) return x
解析:
- 这是特征提取的前向传播函数。
- 输入图像首先通过 patch_embed 转换为特征图。
- 然后依次通过各个阶段的网络层。
- 最后通过归一化、池化和展平操作,得到最终的特征向量。
def forward(self, x): x = self.forward_features(x) x = self.head(x) return x
解析:
- 这是完整的前向传播函数。
- 先调用 forward_features 提取特征。
- 然后通过分类头得到最终的分类结果。
def _load_state_dict(self, pretrained, strict: bool = False): _load_checkpoint(self, pretrained, strict=strict)
解析:
- 这是加载预训练权重的函数。
- 使用_load_checkpoint 辅助函数加载预训练模型参数,strict 参数控制是否严格匹配模型结构。
MambaVisionLayer
MambaVisionLayer
是 MambaVision 模型的核心构建块,负责实现模型的分层结构。根据论文设计,该层可配置为纯卷积结构(前两个阶段)或混合结构(后两个阶段),通过窗口化操作平衡局部与全局建模能力。
初始化
class MambaVisionLayer(nn.Module): \"\"\" MambaVision layer\" \"\"\" def __init__(self, dim, depth, num_heads, window_size, conv=False, downsample=True, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., layer_scale=None, layer_scale_conv=None, transformer_blocks = [], ): \"\"\" Args: dim: feature size dimension. depth: number of layers in each stage. window_size: window size in each stage. conv: bool argument for conv stage flag. downsample: bool argument for down-sampling. mlp_ratio: MLP ratio. num_heads: number of heads in each stage. qkv_bias: bool argument for query, key, value learnable bias. qk_scale: bool argument to scaling query, key. drop: dropout rate. attn_drop: attention dropout rate. drop_path: drop path rate. norm_layer: normalization layer. layer_scale: layer scaling coefficient. layer_scale_conv: conv layer scaling coefficient. transformer_blocks: list of transformer blocks. \"\"\" super().__init__() self.conv = conv self.transformer_block = False if conv: self.blocks = nn.ModuleList([ConvBlock(dim=dim, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, layer_scale=layer_scale_conv) for i in range(depth)]) self.transformer_block = False else: self.blocks = nn.ModuleList([Block(dim=dim, counter=i, transformer_blocks=transformer_blocks, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, layer_scale=layer_scale) for i in range(depth)]) self.transformer_block = True self.downsample = None if not downsample else Downsample(dim=dim) self.do_gt = False self.window_size = window_size
-
配置判断(
conv
参数):- 若
conv=True
(前两个阶段),使用纯卷积结构(ConvBlock
),基于残差卷积块实现高效局部特征提取。 - 若
conv=False
(后两个阶段),使用混合结构(Block
),结合 Mamba 和 Transformer 模块,通过transformer_blocks
参数指定哪些层使用 Transformer 自注意力。
- 若
-
模块初始化:
ConvBlock
:基于 CNN 的残差块,包含深度可分离卷积和扩张卷积,用于快速提取局部特征,对应论文中的 CNN 阶段设计。Block
:混合模块,根据层数索引(counter
)和transformer_blocks
决定使用 Mamba 块还是 Transformer 自注意力块,实现局部与全局建模的平衡。
-
下采样设计:
downsample
参数控制是否在层末尾进行下采样,通过Downsample
模块实现(通常为步长为 2 的卷积),用于缩小特征图尺寸并增加通道数。
-
窗口化配置:
window_size
控制窗口化注意力的窗口大小,影响局部与全局信息的平衡,对应论文中 “窗口化注意力机制” 提升效率的设计。
前向传播部分
def forward(self, x): _, _, H, W = x.shape if self.transformer_block: pad_r = (self.window_size - W % self.window_size) % self.window_size pad_b = (self.window_size - H % self.window_size) % self.window_size if pad_r > 0 or pad_b > 0: x = torch.nn.functional.pad(x, (0,pad_r,0,pad_b)) _, _, Hp, Wp = x.shape else: Hp, Wp = H, W x = window_partition(x, self.window_size) for _, blk in enumerate(self.blocks): x = blk(x) if self.transformer_block: x = window_reverse(x, self.window_size, Hp, Wp) if pad_r > 0 or pad_b > 0: x = x[:, :, :H, :W].contiguous() if self.downsample is None: return x return self.downsample(x)
解析:
-
窗口化处理(Transformer 阶段):
- 填充操作:确保特征图尺寸能被
window_size
整除,通过填充(pad_r
和pad_b
)实现。 - 窗口划分:使用
window_partition
将特征图分割成多个不重叠的窗口,每个窗口独立计算注意力,降低计算复杂度。这对应论文中的 “窗口化自注意力” 设计,减少全局注意力的开销。
- 填充操作:确保特征图尺寸能被
-
模块处理流程:
- 依次通过
self.blocks
中的所有块(卷积块或混合块)。 - 若为混合结构,根据
transformer_blocks
列表决定每层使用 Mamba 还是 Transformer,实现论文中 “前半部分 Mamba + 后半部分 Transformer” 的混合策略。
- 依次通过
-
窗口恢复与下采样:
- 窗口合并:使用
window_reverse
将处理后的窗口重新合并为完整的特征图。 - 裁剪填充:去除之前添加的填充,恢复原始尺寸。
- 下采样:若配置了
downsample
,则通过Downsample
模块进行下采样,为下一阶段做准备。
- 窗口合并:使用
与论文的对应关系
conv
参数控制阶段类型,前两阶段用 CNN(ConvBlock
),后两阶段用混合模块(Block
)。window_partition
和 window_reverse
实现窗口划分与合并,降低计算复杂度。transformer_blocks
指定 Transformer 块的位置,实现 “前半 Mamba + 后半 Transformer” 策略。ConvBlock
),包含深度可分离卷积和扩张卷积。Downsample
模块在层末尾实现特征图尺寸缩减和通道数增加。MambaVisionMixer
1. 初始化参数与公式映射
论文公式依据
该公式描述了 MambaVision Mixer 的核心结构:输入经线性层分为两支,一支通过 SSM(Scan
),另一支通过纯卷积(对称分支),最后拼接输出。
代码对应
self.d_model = d_model # 输入维度 C(公式中的 C)self.d_inner = int(self.expand * self.d_model) # 扩展后的维度,对应公式中 Linear(C, C/2) 的输出维度 C/2 × 2(因分支拆分)self.in_proj = nn.Linear(self.d_model, self.d_inner, bias=bias, **factory_kwargs) # 公式中的 Linear(C, C/2×2),拆分后每支维度为 C/2
in_proj
将输入从d_model
(C)投影到d_inner
(C×2,因分支拆分为两支各 C/2),对应公式中线性层将输入分为两支。
2. 状态空间模型(SSM)参数初始化
论文公式依据
代码对应
self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == \"auto\" else dt_rank # 时间步长的秩,控制动态参数的维度self.x_proj = nn.Linear(self.d_inner//2, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs) # 投影到 dt_rank(Δ)、B、C 三个部分self.A_log = nn.Parameter(torch.log(repeat(torch.arange(1, self.d_state + 1), \"n -> d n\", d=self.d_inner//2))) # A矩阵的对数形式,对应公式中的 A(连续时间参数)self.D = nn.Parameter(torch.ones(self.d_inner//2, device=device)) # 衰减参数,论文未显式提及,用于稳定训练
A_log
存储连续时间矩阵 A 的对数,通过 (\\exp(-A_log)) 得到离散化的 (\\overline{A})(公式 3 中的负号对应代码中的-torch.exp(self.A_log)
)。x_proj
将输入分支 x 投影为动态参数 dt(对应 (\\Delta))、B、C,实现输入依赖的选择性扫描。
3. 对称分支与卷积操作
论文公式依据
- 公式 7 中的 (X_2):对称分支使用纯卷积(非 SSM 路径),补偿 SSM 的顺序约束导致的信息丢失。
- 代码中的卷积设计:将因果卷积替换为常规卷积(
padding=\'same\'
),支持并行处理空间信息。
代码对应
self.conv1d_x = nn.Conv1d(..., kernel_size=d_conv, groups=self.d_inner//2, ...) # x分支的卷积(SSM路径)self.conv1d_z = nn.Conv1d(..., kernel_size=d_conv, groups=self.d_inner//2, ...) # z分支的卷积(对称路径)
conv1d_x
和conv1d_z
均为分组卷积(groups=self.d_inner//2
),保持参数效率,对应论文中 “常规卷积替代因果卷积” 的改进。xz.chunk(2, dim=1)
将投影后的特征拆分为 x(SSM 分支)和 z(对称分支),分别经过卷积和激活函数(F.silu
)。
4. 前向传播:SSM 与对称分支融合
论文公式依据
- 公式 7 中的 (X_1):通过
selective_scan_fn
实现 SSM 的选择性扫描,对应论文中的Scan 操作。 - 拼接操作:将 SSM 输出 y 与对称分支 z 拼接,融合序列与空间信息。
代码对应
xz = self.in_proj(hidden_states) # 输入投影(公式中的 Linear 层)xz = rearrange(xz, \"b l d -> b d l\") # 调整维度为 (B, D, L),适配 Conv1D 输入x, z = xz.chunk(2, dim=1) # 拆分为两支,各维度为 (B, C/2, L)# SSM 分支(X1)x = F.silu(self.conv1d_x(x)) # 卷积+激活(公式中的 Conv+σ)x_dbl = self.x_proj(rearrange(x, \"b d l -> (b l) d\")) # 投影到动态参数空间dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) # 拆分参数dt = self.dt_proj(dt) # 调整时间步长(对应公式中的 Δ 离散化)y = selective_scan_fn(x, dt, -torch.exp(self.A_log), B, C, self.D, ...) # 选择性扫描,计算 SSM 输出# 对称分支(X2)z = F.silu(self.conv1d_z(z)) # 卷积+激活(公式中的 Conv+σ)# 拼接与输出y = torch.cat([y, z], dim=1) # 拼接两支(公式中的 Concat)y = rearrange(y, \"b d l -> b l d\") # 恢复维度 (B, L, D)out = self.out_proj(y) # 最终投影(公式中的 Linear 层)
selective_scan_fn
函数实现论文中的选择性扫描机制,通过动态参数 (dt, B, C) 处理序列,对应公式 5 中的全局卷积 (\\overline{K})。- 对称分支 z 不经过 SSM,直接通过卷积和激活,与 SSM 分支形成并行结构,最终通过拼接(
torch.cat
)融合信息。
5. 关键改进与论文一致性
conv1d_x
和 conv1d_z
使用 padding=\'same\'
,支持双向信息流动。z
分支,仅通过卷积和激活,补偿 SSM 的顺序约束。torch.cat([y, z], dim=1)
,确保同时捕获序列依赖(SSM)和空间结构(卷积)。x_proj
和 dt_proj
根据输入动态生成 (dt, B, C),实现输入依赖的特征处理。总结:代码与公式的完整映射
- 线性投影:
in_proj
对应公式中的线性层,拆分输入为两支。 - 卷积操作:
conv1d_x
和conv1d_z
对应公式中的卷积和激活函数((\\sigma))。 - SSM 计算:
selective_scan_fn
实现公式中的状态空间模型和选择性扫描。 - 对称分支:
z
分支不经过 SSM,直接通过卷积,对应论文中的 “对称路径” 设计。
SSM,直接通过卷积和激活,与 SSM 分支形成并行结构,最终通过拼接(torch.cat
)融合信息。
5. 关键改进与论文一致性
conv1d_x
和 conv1d_z
使用 padding=\'same\'
,支持双向信息流动。z
分支,仅通过卷积和激活,补偿 SSM 的顺序约束。torch.cat([y, z], dim=1)
,确保同时捕获序列依赖(SSM)和空间结构(卷积)。x_proj
和 dt_proj
根据输入动态生成 (dt, B, C),实现输入依赖的特征处理。总结:代码与公式的完整映射
- 线性投影:
in_proj
对应公式中的线性层,拆分输入为两支。 - 卷积操作:
conv1d_x
和conv1d_z
对应公式中的卷积和激活函数((\\sigma))。 - SSM 计算:
selective_scan_fn
实现公式中的状态空间模型和选择性扫描。 - 对称分支:
z
分支不经过 SSM,直接通过卷积,对应论文中的 “对称路径” 设计。 - 特征融合:拼接两支并通过
out_proj
输出,对应公式中的拼接和最终线性层。
原论文链接:https://arxiv.org/abs/2407.08083