
RetroBridge Key Code

I. The transition function: computing \(\bar{Q}_t\)

1. Parameter settings (illustrated by the sketch after this list):
- batch_size = 64
- maximum number of atoms per molecule (n_atoms): 65
- number of atom types (x_dim): 17, one-hot encoded (one class is reserved for the dummy node)
- number of edge types (edge_dim): 5, one-hot encoded (one class is reserved for "no edge")
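
A minimal sketch of what tensors with these shapes look like; the random labels here are placeholders, not real data:

import torch
import torch.nn.functional as F

# Placeholder sizes from the settings above
bs, n_atoms, x_dim, edge_dim = 64, 65, 17, 5

# One-hot atom types; one class is the dummy node
atom_labels = torch.randint(0, x_dim, (bs, n_atoms))
X = F.one_hot(atom_labels, num_classes=x_dim).float()    # (bs, n, dx)

# One-hot edge types; one class is "no edge"
edge_labels = torch.randint(0, edge_dim, (bs, n_atoms, n_atoms))
E = F.one_hot(edge_labels, num_classes=edge_dim).float() # (bs, n, n, de)

print(X.shape, E.shape)  # torch.Size([64, 65, 17]) torch.Size([64, 65, 65, 5])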
2. Source code

def get_Qt_bar(self, alpha_bar_t, X_T, E_T, y_T, node_mask, device):
    """
    alpha_bar_t: (bs, 1)
    X_T: (bs, n, dx)
    E_T: (bs, n, n, de)
    y_T: (bs, dy)

    Returns transition matrices for X, E, and y
    """
    alpha_bar_t = alpha_bar_t.unsqueeze(1)  # (bs, 1, 1)
    alpha_bar_t = alpha_bar_t.to(device)

    q_x_1 = alpha_bar_t * torch.eye(self.X_classes, device=device)  # (bs, dx, dx)
    q_x_2 = (1 - alpha_bar_t).unsqueeze(-1) * torch.ones_like(X_T).unsqueeze(-1) * X_T.unsqueeze(-2)  # (bs, n, dx, dx)
    q_x = q_x_1.unsqueeze(1) + q_x_2
    q_x[~node_mask] = torch.eye(q_x.shape[-1], device=device)  # padded atoms keep an identity transition

    q_e_1 = alpha_bar_t * torch.eye(self.E_classes, device=device)  # (bs, de, de)
    q_e_2 = (1 - alpha_bar_t).unsqueeze(-1).unsqueeze(-1) * torch.ones_like(E_T).unsqueeze(-1) * E_T.unsqueeze(-2)  # (bs, n, n, de, de)
    q_e = q_e_1.unsqueeze(1).unsqueeze(1) + q_e_2

    # Self-loop entries keep an identity transition
    diag = torch.eye(E_T.shape[1], dtype=torch.bool).unsqueeze(0).expand(E_T.shape[0], -1, -1)
    q_e[diag] = torch.eye(q_e.shape[-1], device=device)

    # Pairs involving a padded atom keep an identity transition as well
    edge_mask = node_mask[:, None, :] & node_mask[:, :, None]
    q_e[~edge_mask] = torch.eye(q_e.shape[-1], device=device)

    return utils.PlaceHolder(X=q_x, E=q_e, y=y_T)
3. Walkthrough
(1) Meaning of the inputs
- alpha_bar_t
  - meaning: the noise-schedule coefficient \(\bar{\alpha}_{t}\) evaluated at the sampled timestep
  - one timestep is sampled per data point
  - normalized to lie in [0, 1]
  - shape: [batch_size, 1]
- X_T
  - meaning: atom types
  - shape: (batch_size, n_atoms, x_dim)
- E_T
  - meaning: edge types
  - shape: (batch_size, n_atoms, n_atoms, edge_dim)
(2) Preparing the coefficient \(\bar{\alpha}_{t}\)
alpha_bar_t = alpha_bar_t.unsqueeze(1)  # (bs, 1, 1)
alpha_bar_t = alpha_bar_t.to(device)
shape: [bs, 1] -> [bs, 1, 1]
(3) The atom-feature transition matrix \(\bar{Q}_{t}\)
- Computing \(\bar{\alpha}_{t}I_{K}\)
q_x_1 = alpha_bar_t * torch.eye(self.X_classes, device=device)  # (bs, dx, dx)
In the code above, torch.eye(self.X_classes) creates an identity matrix of shape [x_dim, x_dim]; broadcasting it against alpha_bar_t of shape (bs, 1, 1) yields a (bs, dx, dx) tensor.
- Computing \((1 - \bar{\alpha}_{t})1_{K}y^{T}\)
q_x_2 = (1 - alpha_bar_t).unsqueeze(-1) * torch.ones_like(X_T).unsqueeze(-1) * X_T.unsqueeze(-2)  
Since X_T.shape = [batch_size, n_atoms, x_dim], the outer product \(1_{K}y^{T}\) is computed as torch.ones_like(X_T).unsqueeze(-1) * X_T.unsqueeze(-2); the shapes go [batch_size, n_atoms, x_dim, 1] * [batch_size, n_atoms, 1, x_dim] -> [batch_size, n_atoms, x_dim, x_dim]
- Summing the two terms: \(\bar{Q}_t = \bar{\alpha}_{t}I_{K} + (1 - \bar{\alpha}_{t})1_{K}y^{T}\)
q_x = q_x_1.unsqueeze(1) + q_x_2
q_x[~node_mask] = torch.eye(q_x.shape[-1], device=device)
The edge transition matrix is computed analogously, so the details are not repeated here.
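
A quick standalone check, with a made-up class count, that each row of \(\bar{Q}_t = \bar{\alpha}_{t}I_{K} + (1 - \bar{\alpha}_{t})1_{K}y^{T}\) is a valid categorical distribution, and of how it acts on a one-hot state:

import torch

K = 5                                   # made-up number of classes
alpha_bar = torch.tensor(0.7)
y = torch.zeros(K); y[2] = 1.0          # one-hot endpoint state y

# Q_bar = alpha_bar * I_K + (1 - alpha_bar) * 1_K y^T
Q_bar = alpha_bar * torch.eye(K) + (1 - alpha_bar) * torch.ones(K, 1) * y.unsqueeze(0)

print(Q_bar.sum(dim=1))   # tensor([1., 1., 1., 1., 1.]) -- row-stochastic
x0 = torch.zeros(K); x0[0] = 1.0
print(x0 @ Q_bar)         # prob. 0.7 of staying in class 0, 0.3 of jumping to class 2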

II. The Markov bridge diffusion process

1. The top-level function training_step
- Calls process_and_forward to run the retrosynthesis prediction.
- Predicts how the dummy nodes change; atoms that belong to something like a leaving group should be predicted accurately.
- Predicts how the chemical bonds change (bond types are one-hot encoded).
- Computes the loss according to self.loss_type (cross-entropy or variational lower bound); a sketch of a masked cross-entropy follows the code below.

def training_step(self, data, i):
    reactants, product, pred, node_mask, noisy_data, _ = self.process_and_forward(data)
    if self.loss_type == 'vlb':
        return self.compute_training_VLB(
            reactants=reactants,
            pred=pred,
            node_mask=node_mask,
            noisy_data=noisy_data,
            i=i,
        )
    else:
        return self.compute_training_CE_loss_and_metrics(reactants=reactants, pred=pred, i=i)
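
compute_training_CE_loss_and_metrics itself is not shown above. As a rough idea of what a masked cross-entropy over the node predictions could look like, here is my own simplification (the helper name and reduction are assumptions, not the repository's implementation):

import torch
import torch.nn.functional as F

def masked_node_ce(pred_X, true_X, node_mask):
    # pred_X: (bs, n, dx) logits; true_X: (bs, n, dx) one-hot; node_mask: (bs, n)
    target = true_X.argmax(dim=-1)                                            # class indices
    loss = F.cross_entropy(pred_X.transpose(1, 2), target, reduction='none')  # (bs, n)
    return (loss * node_mask).sum() / node_mask.sum()                         # average over real atoms only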

2. The to_dense_batch function
- Unifies shapes to the max_atom_num within a batch, which makes parallel computation possible.
- Purpose: batched sparse features \(x \in R^{(N_1 + \cdots + N_B) \times F}\) -> batched dense features \(x \in R^{B \times \max_{1 \leq i \leq B}\{N_i\} \times F}\); see the usage example after this list.
- Returns out, node_mask: out holds the dense batched features, and node_mask marks the slots that were added only to pad every graph to the same size and hold no real atoms.
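
A small usage example of torch_geometric's to_dense_batch with two toy graphs:

import torch
from torch_geometric.utils import to_dense_batch

# Two graphs in one batch: graph 0 has 2 nodes, graph 1 has 3 nodes
x = torch.arange(5, dtype=torch.float).view(5, 1)  # (N_1 + N_2, F) = (5, 1)
batch = torch.tensor([0, 0, 1, 1, 1])              # graph id of each node

out, node_mask = to_dense_batch(x=x, batch=batch)
print(out.shape)   # torch.Size([2, 3, 1]) -> (B, max N_i, F)
print(node_mask)   # tensor([[ True,  True, False], [ True,  True,  True]])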

3. The encode_no_edge function
- Role 1: dedicates one channel (channel 0) to edge existence: pairs without an edge get a 1 in this channel, pairs with an edge keep a 0 there.
- Role 2: removes self-loops, i.e. E[diag] = 0.

def encode_no_edge(E):
    assert len(E.shape) == 4
    if E.shape[-1] == 0:
        return E

    # Pairs whose one-hot edge vector is all zeros carry no edge:
    # flag them with a 1 in channel 0
    no_edge = torch.sum(E, dim=3) == 0
    first_elt = E[:, :, :, 0]
    first_elt[no_edge] = 1
    E[:, :, :, 0] = first_elt
    # Remove self-loops
    diag = torch.eye(E.shape[1], dtype=torch.bool).unsqueeze(0).expand(E.shape[0], -1, -1)
    E[diag] = 0
    return E

4. The to_dense function
- First calls to_dense_batch to unify shapes.
- The "batched" graph here really merges the batch_size independent graphs into one large forest, renumbering all nodes consecutively across it. For example, sample 0 in the batch might own atoms 0-23 and sample 1 atoms 24-71.
- edge_index and edge_attr are batched sparse data in the same sense. edge_index has shape [2, total_edge_num_in_batch], where edge_index[0][i] and edge_index[1][i] are the source and target atoms of the i-th directed edge; a concrete example follows the code below.
- Then calls to_dense_adj, which densifies the batch in the same spirit as to_dense_batch.
- Finally calls the custom encode_no_edge function to explicitly encode the absence of an edge between two atoms of the same product molecule, yielding the dense batched data: atom types, the edge tensor, and the node mask.

def to_dense(x, edge_index, edge_attr, batch, explicitly_encode_no_edge=True):
    X, node_mask = to_dense_batch(x=x, batch=batch)
    edge_index, edge_attr = torch_geometric.utils.remove_self_loops(edge_index, edge_attr)
    max_num_nodes = X.size(1)
    E = to_dense_adj(edge_index=edge_index, batch=batch, edge_attr=edge_attr, max_num_nodes=max_num_nodes)
    if explicitly_encode_no_edge:
        E = encode_no_edge(E)

    return PlaceHolder(X=X, E=E), node_mask
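
To make the batched sparse format concrete, here is a toy batch pushed through to_dense_adj; the edge_index columns are directed edges written with the global (forest-wide) node ids described above:

import torch
from torch_geometric.utils import to_dense_adj

# Graph 0 owns nodes 0-1, graph 1 owns nodes 2-4; one bond per graph,
# stored in both directions, with a 2-dim one-hot edge type
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 0, 3, 2]])
edge_attr = torch.tensor([[1., 0.], [1., 0.], [0., 1.], [0., 1.]])
batch = torch.tensor([0, 0, 1, 1, 1])

E = to_dense_adj(edge_index=edge_index, batch=batch, edge_attr=edge_attr, max_num_nodes=3)
print(E.shape)  # torch.Size([2, 3, 3, 2]) -> (bs, n, n, de)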

5. The reactant masking function reactants.mask
- Role: masking arithmetic.
- Suspicion: applying encode_no_edge first and the mask afterwards might conflate single bonds with the absence of a bond; must check this on 4.9 when there is time. A small probe follows the code below.

def mask(self, node_mask, collapse=False):
    x_mask = node_mask.unsqueeze(-1)          # bs, n, 1
    e_mask1 = x_mask.unsqueeze(2)             # bs, n, 1, 1
    e_mask2 = x_mask.unsqueeze(1)             # bs, 1, n, 1

    if collapse:
        self.X = torch.argmax(self.X, dim=-1)
        self.E = torch.argmax(self.E, dim=-1)

        self.X[node_mask == 0] = - 1
        self.E[(e_mask1 * e_mask2).squeeze(-1) == 0] = - 1
    else:
        self.X = self.X * x_mask
        self.E = self.E * e_mask1 * e_mask2
        assert torch.allclose(self.E, torch.transpose(self.E, 1, 2))

    return self
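
A tiny standalone probe of how encode_no_edge and the mask arithmetic interact (simplified tensors, not repository code; the self-loop handling of encode_no_edge is omitted):

import torch

# Made-up batch: 3 real atoms, 1 padding slot, de = 3 edge classes
de, n = 3, 4
E = torch.zeros(1, n, n, de)
E[0, 0, 1, 1] = 1; E[0, 1, 0, 1] = 1      # one bond between atoms 0 and 1 (class 1)
E[..., 0] = (E.sum(dim=-1) == 0).float()  # encode_no_edge: channel 0 flags "no edge"

node_mask = torch.tensor([[True, True, True, False]])
e_mask = node_mask[:, :, None, None] & node_mask[:, None, :, None]
E_masked = E * e_mask                     # the edge branch of mask()

print(E_masked[0, 0, 2])  # real pair, no bond -> tensor([1., 0., 0.])
print(E_masked[0, 0, 3])  # padded pair        -> tensor([0., 0., 0.])

At least in this toy case, a masked-out padded pair (all zeros) stays distinct from an explicit "no edge" (a 1 in channel 0).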

6. Applying noise
- For each sample, a timestep t_int is drawn independently and uniformly at random from [lowest_t, self.T].
- t_float and s_float normalize the timesteps so the noise schedule can be evaluated.
- sample_discrete_features then draws one categorical sample per node and per edge from probX and probE; a simplified sketch follows the code.

def apply_noise(self, X, E, y, X_T, E_T, y_T, node_mask):
    # Sample a timestep t.
    # When evaluating, the loss for t=0 is computed separately
    lowest_t = 0 if self.training else 1
    t_int = torch.randint(lowest_t, self.T + 1, size=(X.size(0), 1), device=X.device).float()  # (bs, 1)
    s_int = t_int - 1

    t_float = t_int / self.T
    s_float = s_int / self.T

    # beta_t and alpha_s_bar are used for denoising/loss computation
    beta_t = self.noise_schedule(t_normalized=t_float)  # (bs, 1)
    alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float)  # (bs, 1)
    alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float)  # (bs, 1)

    Qtb = self.transition_model.get_Qt_bar(
        alpha_bar_t=alpha_t_bar,
        X_T=X_T,
        E_T=E_T,
        y_T=y_T,
        node_mask=node_mask,
        device=self.device,
    )  # (bs, n, dx_in, dx_out), (bs, n, n, de_in, de_out)

    assert (len(Qtb.X.shape) == 4 and len(Qtb.E.shape) == 5)
    assert (abs(Qtb.X.sum(dim=3) - 1.) < 1e-4).all(), Qtb.X.sum(dim=3) - 1
    assert (abs(Qtb.E.sum(dim=4) - 1.) < 1e-4).all()

    probX = (X.unsqueeze(-2) @ Qtb.X).squeeze(-2)  # (bs, n, dx_out)
    probE = (E.unsqueeze(-2) @ Qtb.E).squeeze(-2)  # (bs, n, n, de_out)

    sampled_t = diffusion_utils.sample_discrete_features(probX=probX, probE=probE, node_mask=node_mask)

    X_t = F.one_hot(sampled_t.X, num_classes=self.Xdim_output)
    E_t = F.one_hot(sampled_t.E, num_classes=self.Edim_output)
    assert (X.shape == X_t.shape) and (E.shape == E_t.shape)

    z_t = utils.PlaceHolder(X=X_t, E=E_t, y=y).type_as(X_t).mask(node_mask)

    noisy_data = {
        't_int': t_int,
        't': t_float,
        'beta_t': beta_t,
        'alpha_s_bar': alpha_s_bar,
        'alpha_t_bar': alpha_t_bar,
        'X_t': z_t.X,
        'E_t': z_t.E,
        'y_t': z_t.y,
        'node_mask': node_mask
    }
    return noisy_data
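
sample_discrete_features is not reproduced here; conceptually it draws one categorical sample per node (and per edge) from probX and probE. A minimal sketch of the node half, as my own simplification rather than the repository's code:

import torch

def sample_node_types(probX, node_mask):
    # probX: (bs, n, dx), each row a categorical distribution over atom types
    bs, n, dx = probX.shape
    probX = probX.clone()
    probX[~node_mask] = 1.0 / dx   # give padded slots a valid (uniform) distribution
    samples = torch.multinomial(probX.reshape(-1, dx), num_samples=1)  # one draw per node
    return samples.view(bs, n)     # class indices, later one-hot encoded via F.one_hot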

7. The key processing function process_and_forward
- Unifying and densifying reactants and products
  - First calls the to_dense function to unpack the tensors, stored concatenated as one large tensor, back into shape [batch_size, max_atom_num, atom_type_num].
  - Calls mask to apply the node mask.
- Applying noise
  - Here X, E, y correspond to timestep \(t = 0\), i.e. the start of the Markov bridge; X_T, E_T, y_T correspond to timestep \(t = T\), i.e. its end.
  - By the paper's definition, \(t = 0\) corresponds exactly to the product molecule and \(t = T\) exactly to the reactant molecules.
  - The function samples the timestep \(t\) randomly and independently for each sample in the batch, then draws the corresponding noised graph.
- Masking the unchanged part
  - If self.fix_product_nodes is set, atoms already present in the product are clamped to their known types, so only dummy nodes can be rewritten; a toy illustration follows the code.

def process_and_forward(self, data):
    # Getting graphs of reactants (target) and product (context)
    reactants, r_node_mask = utils.to_dense(data.x, data.edge_index, data.edge_attr, data.batch)
    reactants = reactants.mask(r_node_mask)

    product, p_node_mask = utils.to_dense(data.p_x, data.p_edge_index, data.p_edge_attr, data.batch)
    product = product.mask(p_node_mask)

    assert torch.allclose(r_node_mask, p_node_mask)
    node_mask = r_node_mask

    # Getting noisy data
    # Note that here products and reactants are swapped
    noisy_data = self.apply_noise(
        X=product.X, E=product.E, y=product.y,
        X_T=reactants.X, E_T=reactants.E, y_T=reactants.y,
        node_mask=node_mask,
    )

    # Computing extra features + context and making predictions
    context = product.clone() if self.use_context else None
    extra_data = self.compute_extra_data(noisy_data, context=context)
    pred = self.forward(noisy_data, extra_data, node_mask)

    # Masking unchanged part
    if self.fix_product_nodes:
        fixed_nodes = (product.X[..., -1] == 0).unsqueeze(-1)
        modifiable_nodes = (product.X[..., -1] == 1).unsqueeze(-1)
        assert torch.all(fixed_nodes | modifiable_nodes)
        pred.X = pred.X * modifiable_nodes + product.X * fixed_nodes
        pred.X = pred.X * node_mask.unsqueeze(-1)

    return reactants, product, pred, node_mask, noisy_data, context
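
A toy illustration of the fixed/modifiable split (assuming, as in the settings of section I, that the last atom-type channel is the dummy-node class):

import torch

product_X = torch.tensor([[[0., 1., 0.],      # real product atom -> fixed
                           [0., 0., 1.]]])    # dummy node        -> modifiable
pred_X = torch.tensor([[[0.2, 0.7, 0.1],
                        [0.6, 0.3, 0.1]]])

fixed_nodes = (product_X[..., -1] == 0).unsqueeze(-1)       # (bs, n, 1)
modifiable_nodes = (product_X[..., -1] == 1).unsqueeze(-1)
out = pred_X * modifiable_nodes + product_X * fixed_nodes
print(out)  # row 0 is copied from the product, row 1 is taken from the prediction

Only dummy nodes, i.e. potential leaving-group atoms, may be rewritten by the model; atoms already present in the product keep their known types.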