专题一——Mamba源码解读
专题一——Mamba
源码解读
学习时间:2024年5月24日
代码来源:github
一、Mamba
参数确认
- d_model
:表示每个输入token
的嵌入表示维度,即Transformer
中的\(d_{model}\)/S6
算法中的\(D\).
class MambaConfig:
d_model: int # D
n_layers: int
dt_rank: Union[int, str] = 'auto'
d_state: int = 16 # N in paper/comments
expand_factor: int = 2 # E in paper/comments
d_conv: int = 4
dt_min: float = 0.001
dt_max: float = 0.1
dt_init: str = "random" # "random" or "constant"
dt_scale: float = 1.0
dt_init_floor = 1e-4
rms_norm_eps: float = 1e-5
bias: bool = False
conv_bias: bool = True
inner_layernorms: bool = False # apply layernorms to internal activations
pscan: bool = True # use parallel scan mode or sequential mode when training
use_cuda: bool = False # use official CUDA implementation when training (not compatible with (b)float16)
def __post_init__(self):
self.d_inner = self.expand_factor * self.d_model # E*D = ED in comments
if self.dt_rank == 'auto':
self.dt_rank = math.ceil(self.d_model / 16)
二、选择扫描API
三、Mamba
块
MambaBlock
即上图右侧描述的Mamba
块。
1.__init__
函数
输入维度:\((B, L, D)\)
\(B\):batch_size
\(L\):序列长度
\(D\):每个输入token
的嵌入表示维度,即Transformer
中的\(d_{model}\)。
(1)模块入口线性投影层
- 将每个token
的特征向量/嵌入表示由\(d_{model} = D\)维度投影到\(d_{inner} = ED\)维度。
- 一般而言,取\(E = 2\),形成鹅颈瓶结构。
- 此处\(2 * d_{inner}\)表示分别作两次投影,形成两个输出,即对应于图中下方两个绿色块的输出。
# projects block input from D to 2*ED (two branches)
self.in_proj = nn.Linear(config.d_model, 2 * config.d_inner, bias=config.bias)
- 即左侧通道的
Conv
块输出。self.conv1d = nn.Conv1d(in_channels=config.d_inner, out_channels=config.d_inner,
kernel_size=config.d_conv, bias=config.conv_bias,
groups=config.d_inner,
padding=config.d_conv - 1
)
(4)
2.ssm
函数