RetroBridge关键代码
RetroBridge
关键代码
一、变换函数——\(\bar{Q}_t\)的计算
1.参数设定:
- batch_size = 64
- 各分子的最大原子个数(n_atoms
)为65
- 原子类型数目(x_dim
):17
种原子,one-hot
编码(包含dummy node
这一种类)。
- 边类型(edge_dim
):5
种,one-hot
编码(包含no edge
这一种类)。
2.源代码
def get_Qt_bar(self, alpha_bar_t, X_T, E_T, y_T, node_mask, device):
"""
alpha_bar_t: (bs, 1)
X_T: (bs, n, dx)
E_T: (bs, n, n, de)
y_T: (bs, dy)
Returns transition matrices for X, E, and y
"""
alpha_bar_t = alpha_bar_t.unsqueeze(1) # (bs, 1, 1)
alpha_bar_t = alpha_bar_t.to(device)
q_x_1 = alpha_bar_t * torch.eye(self.X_classes, device=device) # (bs, dx, dx)
q_x_2 = (1 - alpha_bar_t).unsqueeze(-1) * torch.ones_like(X_T).unsqueeze(-1) * X_T.unsqueeze(-2) # (bs, n, dx, dx)
q_x = q_x_1.unsqueeze(1) + q_x_2
q_x[~node_mask] = torch.eye(q_x.shape[-1], device=device)
q_e_1 = alpha_bar_t * torch.eye(self.E_classes, device=device) # (bs, de, de)
q_e_2 = (1 - alpha_bar_t).unsqueeze(-1).unsqueeze(-1) * torch.ones_like(E_T).unsqueeze(-1) * E_T.unsqueeze(-2) # (bs, n, n, de, de)
q_e = q_e_1.unsqueeze(1).unsqueeze(1) + q_e_2
diag = torch.eye(E_T.shape[1], dtype=torch.bool).unsqueeze(0).expand(E_T.shape[0], -1, -1)
q_e[diag] = torch.eye(q_e.shape[-1], device=device)
edge_mask = node_mask[:, None, :] & node_mask[:, :, None]
q_e[~edge_mask] = torch.eye(q_e.shape[-1], device=device)
return utils.PlaceHolder(X=q_x, E=q_e, y=y_T)
(1)输入含义
-
alpha_bar_t
- 意义:采样所得的
timestep
- 每个
data point
对应一个timestep
- 经归一化位于
[0, 1]
区间-
shape: [batch_size, 1]
-
X_T
- 意义:原子类型
-
shape: (batch_size, n_atoms, x_dim)
-
E_T
- 意义:边类型
-
shape: (batch_size, n_atoms, n_atoms, edge_dim)
(2)获取系数\(\bar{\alpha}_{t}\)
shape: [bs, 1]-> [bs, 1, 1]
。(3)原子特征变换矩阵\(\bar{Q}_{t}\) - 计算\(\bar{\alpha}_{t}I_{K}\)
上述代码中,
torch.eye(self.X_classes)
即生成了形状为[x_dim, x_dim]
的单位矩阵。- 计算\((1 - \bar{\alpha}_{t})y1_{K}^T\)
X_T.shape = [batch_size, n_atoms, x_dim]
,故\(y1_K^T\)的计算代码即为torch.ones_like(X_T).unsqueeze(-1) * X_T.unsqueeze(-2)
,形状变化:[batch_size, n_atoms, x_dim, 1] * [batch_size, n_atoms, 1, x_dim] -> [batch_size, n_atoms, x_dim, x_dim]
。- 叠加\(\bar{Q}_t = \bar{\alpha}_{t}I_{K} + (1 - \bar{\alpha}_{t})y1_{K}^T\)
边变换矩阵计算方法类似,不再赘述。
二、扩散Markov Bridge
过程
1.总函数training_step
- 调用process_and_forward
,完成逆合成预测。
- 预测:dummy node
的变化情况,如有类似leaving group
的原子应该准确预测出。
- 预测:化学键变化情况(化学键类型独热编码)。
- 根据self.loss_type
,计算损失类型(交叉熵损失、变分下界损失)。
def training_step(self, data, i):
reactants, product, pred, node_mask, noisy_data, _ = self.process_and_forward(data)
if self.loss_type == 'vlb':
return self.compute_training_VLB(
reactants=reactants,
pred=pred,
node_mask=node_mask,
noisy_data=noisy_data,
i=i,
)
else:
return self.compute_training_CE_loss_and_metrics(reactants=reactants, pred=pred, i=i)
2.to_dense_batch
函数
- 形状统一为一个batch
内的max_atom_num
,有利于并行运算。
- 作用:批量稀疏特征\(x \in R^{(N_1 + \cdots + N_B) \times F}\) -> 批量稠密特征\(x \in R^{B \times max_{1 \leq i \leq B}\{N_i\} \times F}\).
- 返回值:out, node_mask
,其中out
为批量稠密特征,mask
表示为了统一规格而被扩容,但是没有实际意义的原子序号。
3.encode_no_edge
函数
- 作用一:单独使用一个通道,用1, 0
分别对有边、无边的信息进行编码。
- 作用二:去除self loop
,即E[diag] = 0
。
def encode_no_edge(E):
assert len(E.shape) == 4
if E.shape[-1] == 0:
return E
no_edge = torch.sum(E, dim=3) == 0
first_elt = E[:, :, :, 0]
first_elt[no_edge] = 1
E[:, :, :, 0] = first_elt
diag = torch.eye(E.shape[1], dtype=torch.bool).unsqueeze(0).expand(E.shape[0], -1, -1)
E[diag] = 0
return E
4.to_dense
函数
- 首先调用to_dense_batch
,统一形状。
- 这里的“批量化”图,实际上是把batch_size
个独立的图,统一为一个大型森林,将所有结点的编号统一到一个森林中。比如batch
中,第0
个样本的原子编号为0 ~ 23
,第1
个样本的原子编号为24 ~ 71
。
- 这里edge_index, edge_attr
同样是批量化稀疏数据。edge_index
形状为[2, total_atom_num_in_batch]
,其中edge_index[0][i], edge_index[1][i]
分别表示第i
条有向边的起始原子、终止原子。
- 然后调用to_dense_adj
,作用与to_dense_batch
类似,也是批次数据的稠密化。
- 最后调用自定义的encode_no_edge
函数,对同一个产物分子内部,两原子之间不存在边的情况进行编码,即得到了稠密化的批量数据,包括:原子类型、边矩阵、结点掩码。
def to_dense(x, edge_index, edge_attr, batch, explicitly_encode_no_edge=True):
X, node_mask = to_dense_batch(x=x, batch=batch)
edge_index, edge_attr = torch_geometric.utils.remove_self_loops(edge_index, edge_attr)
max_num_nodes = X.size(1)
E = to_dense_adj(edge_index=edge_index, batch=batch, edge_attr=edge_attr, max_num_nodes=max_num_nodes)
if explicitly_encode_no_edge:
E = encode_no_edge(E)
return PlaceHolder(X=X, E=E), node_mask
5.反应物掩码函数reactants.mask
- 作用:掩码运算
- 怀疑先作encode_no_edge
,又进行掩码运算,单键、无化学键被混淆了。4.9有空一定要检查一番。
def mask(self, node_mask, collapse=False):
x_mask = node_mask.unsqueeze(-1) # bs, n, 1
e_mask1 = x_mask.unsqueeze(2) # bs, n, 1, 1
e_mask2 = x_mask.unsqueeze(1) # bs, 1, n, 1
if collapse:
self.X = torch.argmax(self.X, dim=-1)
self.E = torch.argmax(self.E, dim=-1)
self.X[node_mask == 0] = - 1
self.E[(e_mask1 * e_mask2).squeeze(-1) == 0] = - 1
else:
self.X = self.X * x_mask
self.E = self.E * e_mask1 * e_mask2
assert torch.allclose(self.E, torch.transpose(self.E, 1, 2))
return self
6.噪声应用
- 针对每个样本,独立、随机采样时间步,t_int
,范围为[lowest_t, self.T]
。
- t_float, s_float
将时间步归一化,便于进行noise schedule
。
-
def apply_noise(self, X, E, y, X_T, E_T, y_T, node_mask):
# Sample a timestep t.
# When evaluating, the loss for t=0 is computed separately
lowest_t = 0 if self.training else 1
t_int = torch.randint(lowest_t, self.T + 1, size=(X.size(0), 1), device=X.device).float() # (bs, 1)
s_int = t_int - 1
t_float = t_int / self.T
s_float = s_int / self.T
# beta_t and alpha_s_bar are used for denoising/loss computation
beta_t = self.noise_schedule(t_normalized=t_float) # (bs, 1)
alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float) # (bs, 1)
alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float) # (bs, 1)
Qtb = self.transition_model.get_Qt_bar(
alpha_bar_t=alpha_t_bar,
X_T=X_T,
E_T=E_T,
y_T=y_T,
node_mask=node_mask,
device=self.device,
) # (bs, n, dx_in, dx_out), (bs, n, n, de_in, de_out)
assert (len(Qtb.X.shape) == 4 and len(Qtb.E.shape) == 5)
assert (abs(Qtb.X.sum(dim=3) - 1.) < 1e-4).all(), Qtb.X.sum(dim=3) - 1
assert (abs(Qtb.E.sum(dim=4) - 1.) < 1e-4).all()
probX = (X.unsqueeze(-2) @ Qtb.X).squeeze(-2) # (bs, n, dx_out)
probE = (E.unsqueeze(-2) @ Qtb.E).squeeze(-2) # (bs, n, n, de_out)
sampled_t = diffusion_utils.sample_discrete_features(probX=probX, probE=probE, node_mask=node_mask)
X_t = F.one_hot(sampled_t.X, num_classes=self.Xdim_output)
E_t = F.one_hot(sampled_t.E, num_classes=self.Edim_output)
assert (X.shape == X_t.shape) and (E.shape == E_t.shape)
z_t = utils.PlaceHolder(X=X_t, E=E_t, y=y).type_as(X_t).mask(node_mask)
noisy_data = {
't_int': t_int,
't': t_float,
'beta_t': beta_t,
'alpha_s_bar': alpha_s_bar,
'alpha_t_bar': alpha_t_bar,
'X_t': z_t.X,
'E_t': z_t.E,
'y_t': z_t.y,
'node_mask': node_mask
}
return noisy_data
7.关键处理函数process_and_forward
- 反应物、产物规格统一、稠密化
- 首先调用to_dense
函数,将合并为一个大tensor
存储的张量重新分解为[batch_size, max_atom_num, atom_type_num]
形状的张量。
- 调用mask
,进行掩码处理。
- 噪声应用
- 这里X, E, y
对应于时间步\(t = 0\),即Markov Bridge
的起点;X_T, E_T, y_T
对应于时间步\(t = T\),即Markov Bridge
的终点。
- 根据论文定义,\(t = 0\)时恰好对应于产物分子数据,\(t = T\)恰好对应于反应物分子数据。
- 该函数随机对时间步\(T\)、噪声\(\epsilon\)进行采样,每个样本该轮的\(T\)分别随机采样获得,互不相关;获得添加噪声后的数据。
def process_and_forward(self, data):
# Getting graphs of reactants (target) and product (context)
reactants, r_node_mask = utils.to_dense(data.x, data.edge_index, data.edge_attr, data.batch)
reactants = reactants.mask(r_node_mask)
product, p_node_mask = utils.to_dense(data.p_x, data.p_edge_index, data.p_edge_attr, data.batch)
product = product.mask(p_node_mask)
assert torch.allclose(r_node_mask, p_node_mask)
node_mask = r_node_mask
# Getting noisy data
# Note that here products and reactants are swapped
noisy_data = self.apply_noise(
X=product.X, E=product.E, y=product.y,
X_T=reactants.X, E_T=reactants.E, y_T=reactants.y,
node_mask=node_mask,
)
# Computing extra features + context and making predictions
context = product.clone() if self.use_context else None
extra_data = self.compute_extra_data(noisy_data, context=context)
pred = self.forward(noisy_data, extra_data, node_mask)
# Masking unchanged part
if self.fix_product_nodes:
fixed_nodes = (product.X[..., -1] == 0).unsqueeze(-1)
modifiable_nodes = (product.X[..., -1] == 1).unsqueeze(-1)
assert torch.all(fixed_nodes | modifiable_nodes)
pred.X = pred.X * modifiable_nodes + product.X * fixed_nodes
pred.X = pred.X * node_mask.unsqueeze(-1)
return reactants, product, pred, node_mask, noisy_data, context