
Topic 1: Reading the Mamba Source Code

Study date: May 24, 2024
Code source: GitHub

I. Mamba Configuration Parameters

- d_model: the embedding dimension of each input token, i.e. \(d_{model}\) in the Transformer / \(D\) in the S6 algorithm.

import math
from dataclasses import dataclass
from typing import Union

@dataclass
class MambaConfig:
    d_model: int # D
    n_layers: int
    dt_rank: Union[int, str] = 'auto'
    d_state: int = 16 # N in paper/comments
    expand_factor: int = 2 # E in paper/comments
    d_conv: int = 4

    dt_min: float = 0.001
    dt_max: float = 0.1
    dt_init: str = "random" # "random" or "constant"
    dt_scale: float = 1.0
    dt_init_floor: float = 1e-4

    rms_norm_eps: float = 1e-5

    bias: bool = False
    conv_bias: bool = True
    inner_layernorms: bool = False # apply layernorms to internal activations

    pscan: bool = True # use parallel scan mode or sequential mode when training
    use_cuda: bool = False # use official CUDA implementation when training (not compatible with (b)float16)

    def __post_init__(self):
        self.d_inner = self.expand_factor * self.d_model # E*D = ED in comments

        if self.dt_rank == 'auto':
            self.dt_rank = math.ceil(self.d_model / 16)
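
As a quick check of the two derived fields, here is a minimal sketch (a subset of the fields above, with hypothetical sizes) showing what `__post_init__` computes; it assumes the class is a `@dataclass`, as the use of `__post_init__` implies:

```python
import math
from dataclasses import dataclass
from typing import Union

@dataclass
class MambaConfig:
    # minimal subset of the fields above, enough to exercise __post_init__
    d_model: int
    n_layers: int
    dt_rank: Union[int, str] = 'auto'
    d_state: int = 16
    expand_factor: int = 2

    def __post_init__(self):
        self.d_inner = self.expand_factor * self.d_model  # E*D = ED
        if self.dt_rank == 'auto':
            self.dt_rank = math.ceil(self.d_model / 16)

config = MambaConfig(d_model=768, n_layers=12)
print(config.d_inner)  # 2 * 768 = 1536
print(config.dt_rank)  # ceil(768 / 16) = 48
```

So with the default `expand_factor = 2`, the inner width is always twice the model width, and `dt_rank` defaults to roughly `d_model / 16`.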
II. Selective Scan API
III. Mamba


MambaBlock is the Mamba block depicted on the right side of the figure above.
1. The __init__ function
Input shape: \((B, L, D)\)
- \(B\): batch size
- \(L\): sequence length
- \(D\): the embedding dimension of each input token, i.e. \(d_{model}\) in the Transformer
(1) Input linear projection
- Projects each token's feature vector from \(d_{model} = D\) dimensions to \(d_{inner} = ED\) dimensions.
- Typically \(E = 2\), giving an inverted-bottleneck (expand-then-contract) structure.
- The output size \(2 \cdot d_{inner}\) means two separate projections are performed, producing two outputs: these correspond to the two green blocks at the bottom of the figure.

# projects block input from D to 2*ED (two branches)
self.in_proj = nn.Linear(config.d_model, 2 * config.d_inner, bias=config.bias)
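
In the forward pass, the single `(B, L, 2*ED)` output is then split into the two branches. A sketch with hypothetical sizes (the names `x_branch`/`z_branch` are illustrative, not from the source):

```python
import torch
import torch.nn as nn

B, L, D, E = 2, 10, 64, 2  # hypothetical sizes for illustration
d_inner = E * D            # ED = 128

in_proj = nn.Linear(D, 2 * d_inner, bias=False)

x = torch.randn(B, L, D)
xz = in_proj(x)                           # (B, L, 2*ED)
x_branch, z_branch = xz.chunk(2, dim=-1)  # two (B, L, ED) tensors
print(x_branch.shape)  # torch.Size([2, 10, 128])
print(z_branch.shape)  # torch.Size([2, 10, 128])
```

Fusing both branches into one matrix multiply is cheaper than two separate `nn.Linear` layers; the split afterwards is free (a view).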
(2) 1D convolution
- This produces the output of the Conv block in the left branch.
self.conv1d = nn.Conv1d(in_channels=config.d_inner, out_channels=config.d_inner, 
    kernel_size=config.d_conv, bias=config.conv_bias, 
    groups=config.d_inner,       # depthwise: one filter per channel
    padding=config.d_conv - 1    # pad so the output can be truncated to a causal conv
)
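
Because `groups = d_inner`, this is a depthwise convolution (each channel is filtered independently), and `padding = d_conv - 1` followed by truncating the output back to length \(L\) makes it causal: position \(t\) only sees positions \(\le t\). A sketch with hypothetical sizes:

```python
import torch
import torch.nn as nn

B, L, d_inner, d_conv = 2, 10, 128, 4  # hypothetical sizes

conv1d = nn.Conv1d(in_channels=d_inner, out_channels=d_inner,
                   kernel_size=d_conv,
                   groups=d_inner,       # depthwise
                   padding=d_conv - 1)

x = torch.randn(B, L, d_inner)
x = x.transpose(1, 2)        # Conv1d expects (B, channels, L)
y = conv1d(x)[:, :, :L]      # trim the extra right-side outputs -> causal
print(y.shape)               # torch.Size([2, 128, 10])
```

Without the `[:, :, :L]` truncation the padded output would have length \(L + d_{conv} - 1\) and would leak future positions into the past.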
(3) The \(\Delta\) weight matrix

2. The ssm function

IV. Residual Block (ResidualBlock)