Mamba:基于选择状态空间的线性时间序列模型

FreeGuideOnline 最新 2026-06-21

Mamba:基于选择状态空间的线性时间序列模型

1. 为什么我们需要 Mamba

在深度学习序列建模领域,Transformer 凭借自注意力机制几乎统治了所有任务,但它的计算复杂度随序列长度呈平方级增长((O(L^2))),导致处理长序列时资源消耗巨大。传统的循环神经网络(RNN)虽然具有线性复杂度((O(L))),却难以并行训练,且长程依赖能力受限。状态空间模型(State Space Models, SSMs)作为 RNN 的线性版本,兼具推理高效与可并行训练的优点,成为近年来序列建模的新星。但早期 SSM 对输入不具备选择性,在语言生成、DNA 建模等需要上下文感知的任务中表现不佳。

Mamba 正是为了解决这一矛盾而提出的。它引入了一种选择机制,让状态空间模型能够动态地根据当前输入来决定信息的保留或丢弃,同时通过硬件感知的并行算法实现了极快的训练与推理。Mamba 是首个在多个领域(如语言、音频、基因组学)上达到或超越同等规模 Transformer 性能的线性时间序列模型。

2. 状态空间模型基础

2.1 连续时间线性 SSM

状态空间模型起源于控制理论,它可以描述一个动态系统如何将连续输入信号 (u(t)) 映射到输出信号 (y(t))。一个经典的线性时不变 SSM 由状态方程和输出方程组成:

[ \begin{aligned} x'(t) &= \boldsymbol{A} x(t) + \boldsymbol{B} u(t) \ y(t) &= \boldsymbol{C} x(t) + \boldsymbol{D} u(t) \end{aligned} ]

  • (x(t) \in \mathbb{R}^N) 是隐藏状态(记忆)。
  • (u(t) \in \mathbb{R}) 是输入信号(标量)。
  • (y(t) \in \mathbb{R}) 是输出。
  • (\boldsymbol{A} \in \mathbb{R}^{N \times N}) 是状态转移矩阵,控制状态的演化。
  • (\boldsymbol{B} \in \mathbb{R}^{N \times 1}),(\boldsymbol{C} \in \mathbb{R}^{1 \times N}),(\boldsymbol{D} \in \mathbb{R}) 是参数矩阵。通常省略 (\boldsymbol{D}) 或将其视为跳跃连接。

2.2 离散化处理

因为我们要处理的是离散的时间序列(如文本 token),必须将连续系统离散化。常用方法如零阶保持(ZOH)给出离散化后的参数:

[ \begin{aligned} \overline{\boldsymbol{A}} &= \exp(\Delta \boldsymbol{A}) \ \overline{\boldsymbol{B}} &= (\Delta \boldsymbol{A})^{-1}(\exp(\Delta \boldsymbol{A}) - \boldsymbol{I}) \cdot \Delta \boldsymbol{B} \end{aligned} ]

其中 (\Delta) 是采样步长。离散 SSM 的形式变为:

[ \begin{aligned} x_k &= \overline{\boldsymbol{A}} x_{k-1} + \overline{\boldsymbol{B}} u_k \ y_k &= \boldsymbol{C} x_k \end{aligned} ]

这个结构看起来很像 RNN,但它也可以写成卷积形式,实现高效的并行训练。

2.3 卷积视角与 HiPPO 条件

SSM 可以表示为一个全局卷积 (y = u * \overline{\boldsymbol{K}}),其中卷积核 (\overline{\boldsymbol{K}} = (\boldsymbol{C}\overline{\boldsymbol{B}}, \boldsymbol{C}\overline{\boldsymbol{A}}\overline{\boldsymbol{B}}, \dots))。这使得训练时可以利用 FFT 实现并行加速。为了让这个卷积核能捕获长程依赖,矩阵 (\boldsymbol{A}) 需要特殊的初始化。HiPPO(High-order Polynomial Projection Operators)理论提供了一种初始化方式,确保模型以最优方式记忆输入历史。S4(Structured State Space)模型就是基于这种思想,将 (\boldsymbol{A}) 初始化为 HiPPO 矩阵,并利用对角加低秩分解实现高效计算。

3. 从 S4 到 Mamba:选择机制的核心

3.1 线性时不变性(LTI)的局限

S4 及之前的 SSM 都是线性时不变(LTI)系统:参数 (\overline{\boldsymbol{A}}, \overline{\boldsymbol{B}}, \boldsymbol{C}, \Delta) 在整个序列上保持固定,与输入内容无关。这意味着模型对每个 token 执行完全相同的过滤操作,无法 “选择性” 地关注或忽略某些信息。例如,在复制任务或选择性复制任务中,LTI 模型无法根据上下文变化调整记忆行为,表现极差。

Mamba 的关键创新是让 SSM 的某些参数变为输入的函数,使其成为时变系统。具体来说,Mamba 让 (\overline{\boldsymbol{B}}, \boldsymbol{C}, \Delta) 都依赖于当前输入 (u_k):

[ \begin{aligned} \boldsymbol{B}_k &= \text{Linear}_B(u_k) \ \boldsymbol{C}_k &= \text{Linear}C(u_k) \ \Delta_k &= \text{softplus}(\text{Linear}\Delta(u_k)) \end{aligned} ]

这样,每一步的离散参数 (\overline{\boldsymbol{A}}_k, \overline{\boldsymbol{B}}_k) 都会动态变化,模型可以根据当前 token 的重要性调整状态记忆的更新强度。同时,矩阵 (\boldsymbol{A}) 本身可以不随输入变化(保持 HiPPO 结构),从而保留长程记忆的骨架。

3.2 选择机制的直观理解

选择机制使 SSM 能够执行上下文感知的信息过滤:

  • 当遇到关键 token 时,模型可以放大 (\overline{\boldsymbol{B}}_k),将更多新信息写入状态。
  • 当需要忽略噪声时,模型可以减小 (\overline{\boldsymbol{B}}_k),或者调整 (\Delta_k) 来缩短有效时间窗口。
  • 输出参数 (\boldsymbol{C}_k) 决定了从状态中读取哪部分记忆以生成输出。

本质上,Mamba 的每个时刻都相当于一个 “小型门控循环单元”,但没有显式的激活函数和复杂的门结构,而是通过参数化的离散化过程实现选择性。这使得模型在保留线性推理复杂度的同时,获得了类似 RNN 门控机制(如 LSTM、GRU)的内容自适应能力。

4. Mamba 架构细节

4.1 块结构:H3 Block 的演进

Mamba 的整体架构基于 H3(Hungry Hungry Hippos)模型,将 SSM 与门控 MLP 结合,形成可堆叠的 Mamba Block。

每个 Mamba Block 的输入 (z) 沿着两个路径处理:

  1. 主路径:经过一个投影层(Linear),然后进入 SSM 核心(选择性 SSM),再通过 SiLU 和非线性投影。
  2. 门控路径:输入直接通过另一个投影层和 SiLU 激活,作为门信号与主路径输出进行逐元素乘法。
  3. 最后加上残差连接和归一化。

这种设计与门控 Transformer 或 gated CNN 有异曲同工之妙,但核心组件替换为了线性时间的 SSM,显著降低了计算量。

4.2 选择 SSM 层的具体计算

对于输入序列 (u \in \mathbb{R}^{B \times L \times D})(B: batch size, L: 序列长度, D: 特征维度),Mamba 的 SSM 核心层按如下步骤工作:

  1. 输入投影:将 (u) 线性投影到 (D) 维空间,得到 (u')。
  2. 计算时变参数
    • (B_k = \text{Linear}_B(u'_k)),(C_k = \text{Linear}_C(u'k)),(\Delta_k = \text{softplus}(\text{Linear}\Delta(u'k) + \text{bias}\Delta))。
  3. 离散化:利用 (\Delta_k) 和固定的 (\boldsymbol{A}) 通过 ZOH 计算 (\overline{\boldsymbol{A}}_k, \overline{\boldsymbol{B}}_k)。
  4. 时变 SSM 运算:由于参数不再恒定,无法直接表示为普通卷积。Mamba 采用一种硬件感知的并行扫描算法(见第5节)高效计算输出 (y)。
  5. 输出投影:将输出线性变换回原始维度,并与门控信号相乘。

4.3 扩展到多头机制与多维度

Mamba 没有显式的“多头”结构。相反,它利用多输入多输出(MIMO)SSM:状态维度 (N) 对于每个特征通道是独立的。即,将 (D) 个特征分成多个独立的 SSM 实例,每个实例拥有自己的状态矩阵 (\boldsymbol{A})(通常共享 (\boldsymbol{A}) 或每个特征独立一份),但时变参数 (\Delta, B, C) 根据输入独立产生。这种设计类似于深度卷积中的分组思想,在保持参数高效的同时,允许不同通道具有不同的选择性行为。

5. 硬件感知的并行算法

5.1 时序并行扫描(Parallel Scan)

因为 Mamba 是时变系统,传统的 FFT 卷积不再适用。但 SSM 的递归形式本质上是一次前缀和操作(或扫描操作)。对于线性时变系统,我们可以使用并行扫描(Blelloch scan)在 (O(\log L)) 步内计算长度为 (L) 的序列输出,且每一步都是可并行的。

具体来说,SSM 的递推关系可以写为:

[ x_k = \overline{\boldsymbol{A}}k x{k-1} + \overline{\boldsymbol{B}}_k u_k ]

若定义二元操作符 (\bullet) 组合两个连续步骤的状态变换,则可以将整体变换表达为一个扫描问题。通过将序列划分为树状结构,在同一树级别内并行计算,总复杂度仍为 (O(L)),但充分利用了 GPU 的并行性。

5.2 内核融合与内存优化

Mamba 还在实现层面进行了极致的硬件优化:

  • 内核融合:将投影、离散化、扫描等多个步骤合并为单个 CUDA 内核,避免反复读写 GPU 显存。
  • 扩展状态共享:设计允许在序列长度维度上并行计算,同时保持状态维度的快速访问。
  • 内存访问模式:利用 SRAM(共享内存)存储中间状态和扫描所需的临时变量,大幅减少 HBM(高带宽内存)的访问次数。

这些优化的直接效果是:在训练长序列时,Mamba 的吞吐量数倍于同等大小的 Transformer,且推理延迟几乎与序列长度无关,实现了真正的线性时间。

6. 与 Transformer 和 RNN 的对比

模型类型 训练复杂度 推理复杂度 长程依赖能力 并行训练 选择性
Transformer (O(L^2)) (O(L)) 强(全局注意力) 强(自注意力)
LSTM/GRU (O(L)) (O(L)) 中(门控)
S4 (LTI SSM) (O(L \log L)) 或 (O(L)) (O(L)) 强(HiPPO)
Mamba (O(L)) (O(L)) 强(HiPPO + 选择)

Mamba 在复杂度、并行性和选择性之间达到了更优的平衡。它既不像 RNN 那样必须串行训练,也不像 Transformer 那样付出平方代价;同时通过选择机制弥补了传统 SSM 在上下文依赖任务上的不足。

7. 训练与实际应用

7.1 预训练与缩放

Mamba 遵循与 GPT 类似的自回归语言建模预训练方式。实验表明,在同等计算量下,Mamba 在 The Pile 等数据集上的困惑度优于 Pythia 等 Transformer 基线,且模型参数从 2.8B 扩展到更大规模时,性能曲线保持稳定。在 DNA 序列、音频波形等长序列数据上,Mamba 的优势更为突出:它可以处理百万长度级别的序列而不会内存溢出。

7.2 应用场景示例

  • 大语言模型:Mamba 可作为基础骨干网络构建纯 SSM 对话模型,实现与 Transformer 相当的语言生成质量。
  • 基因组学:DNA 序列长度通常很长(>10k bp),Mamba 能直接对原始序列建模,准确率优于卷积和 LSTM 模型。
  • 音频处理:原始音频波形是极长的一维序列,Mamba 在语音分类、生成等任务上表现出色,推理速度极快。
  • 强化学习:在需要长期记忆的决策任务中,Mamba 可以替代 LSTM 提供稳定的状态追踪。

8. 总结与未来展望

Mamba 通过将选择机制融入状态空间模型,成功打破了线性时不变 SSM 的表达力瓶颈,同时保留了线性复杂度和硬件高效的并行性。它是序列建模领域一次重要的架构创新,为长序列任务提供了一种更具竞争力的基础模块。

未来方向包括:

  • 更强的长程记忆:结合更先进的记忆初始化理论,进一步延长有效上下文窗口。
  • 混合架构:将 Mamba 的局部 SSM 与全局注意力模块结合,实现局部感知与全局交互的折中。
  • 高效微调与推理:探索 Mamba 在低资源环境下的量化、剪枝和适配器方法。
  • 理论分析:深入理解选择 SSM 的表达能力上限与泛化边界。

无论是作为 Transformer 的替代方案,还是作为长序列学习的专用工具,Mamba 及其变体(如 Mamba-2)都正在成为现代深度学习栈中不可或缺的组成部分。