Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update multi-chain permutation and permutation unittest #406

Merged
merged 12 commits into from
May 11, 2024
231 changes: 176 additions & 55 deletions openfold/utils/multi_chain_permutation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import random
import torch

from typing import Tuple, List, Dict
from openfold.np import residue_constants as rc

logger = logging.getLogger(__name__)
Expand All @@ -13,6 +13,17 @@ def compute_rmsd(
atom_mask: torch.Tensor = None,
eps: float = 1e-6,
) -> torch.Tensor:
"""
Function to calculate RMSD between predicted and ground truth atom position

Args:
true_atom_pos: a [nres*3] tensor
pred_atom_pos: a [nres*3] tensor
atom_mask: a [1*nres] tensor

Return:
RMSD value between true and predicted atom positions
"""
sq_diff = torch.square(true_atom_pos - pred_atom_pos).sum(dim=-1, keepdim=False)
if atom_mask is not None:
sq_diff = torch.masked_select(sq_diff, atom_mask.to(sq_diff.device))
Expand All @@ -21,19 +32,19 @@ def compute_rmsd(
return torch.sqrt(msd + eps) # prevent sqrt 0


def kabsch_rotation(P, Q):
def kabsch_rotation(P: torch.Tensor, Q: torch.Tensor) -> torch.Tensor:
"""
Calculate the best rotation that minimises the RMSD between P and Q.

The optimal rotation matrix was calculated using Kabsch algorithm:
https://en.wikipedia.org/wiki/Kabsch_algorithm

Args:
P: [N * 3] Nres is the number of atoms and each row corresponds to the atom's x,y,z coordinates
Q: [N * 3] the same dimension as P
P: [N * 3] Nres is the number of atoms and each row corresponds to the atom's x,y,z coordinates
Q: [N * 3] the same dimension as P

return:
A 3*3 rotation matrix
one 3*3 rotation matrix that best aligns the source and target atoms
"""
assert P.shape == torch.Size([Q.shape[0], Q.shape[1]])

Expand All @@ -54,11 +65,20 @@ def get_optimal_transform(
src_atoms: torch.Tensor,
tgt_atoms: torch.Tensor,
mask: torch.Tensor = None,
):
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
src_atoms: predicted CA positions, shape:[num_res,3]
tgt_atoms: ground-truth CA positions, shape:[num_res,3]
mask: a vector of boolean values, shape:[num_res]
A function that obtain the transformation that optimally align
src_atoms with tgt_atoms

Args:
src_atoms: predicted CA positions, shape:[num_res,3]
tgt_atoms: ground-truth CA positions, shape:[num_res,3]
mask: a vector of boolean values, shape:[num_res]

Returns:
a rotation matrix that record the optimal rotation
that will best align selected anchor prediction to selected anchor truth
a matrix records how the atoms should be shifted after applying r i.e. optimal alignment requires 1) rotate 2) shift the positions
"""
dingquanyu marked this conversation as resolved.
Show resolved Hide resolved
assert src_atoms.shape == tgt_atoms.shape, (src_atoms.shape, tgt_atoms.shape)
assert src_atoms.shape[-1] == 3
Expand Down Expand Up @@ -88,7 +108,7 @@ def get_optimal_transform(
return r, x


def get_least_asym_entity_or_longest_length(batch, input_asym_id):
def get_least_asym_entity_or_longest_length(batch: dict, input_asym_id: list) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""
First check how many subunit(s) one sequence has. Select the subunit that is less
common, e.g. if the protein was AABBB then select one of the A as anchor
Expand All @@ -97,15 +117,15 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id):
then choose one of the corresponding subunits as anchor

Args:
batch: in this function batch is the full ground truth features
input_asym_id: A list of asym_ids that are in the cropped input features
batch: in this function batch is the full ground truth features
input_asym_id: A list of asym_ids that are in the cropped input features

Return:
anchor_gt_asym_id: Tensor(int) selected ground truth asym_id
anchor_pred_asym_ids: list(Tensor(int)) a list of all possible pred anchor candidates
anchor_gt_asym_id: Tensor(int) selected ground truth asym_id
anchor_pred_asym_ids: list(Tensor(int)) a list of all possible pred anchor candidates
"""
entity_2_asym_list = get_entity_2_asym_list(batch)
unique_entity_ids = torch.unique(batch["entity_id"])
unique_entity_ids = [i for i in torch.unique(batch["entity_id"]) if i !=0]# if entity_id is 0, that means this entity_id comes from padding
entity_asym_count = {}
entity_length = {}

Expand Down Expand Up @@ -145,19 +165,38 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id):


def greedy_align(
batch,
per_asym_residue_index,
entity_2_asym_list,
pred_ca_pos,
pred_ca_mask,
true_ca_poses,
true_ca_masks,
):
batch: dict,
per_asym_residue_index: dict,
entity_2_asym_list: dict,
pred_ca_pos: torch.Tensor,
pred_ca_mask: torch.Tensor,
true_ca_poses: list,
true_ca_masks: list
) -> List[Tuple[int, int]]:
"""
Implement Algorithm 4 in the Supplementary Information of AlphaFold-Multimer paper:
Evans,R et al., 2022 Protein complex prediction with AlphaFold-Multimer, bioRxiv 2021.10.04.463034; doi: https://doi.org/10.1101/2021.10.04.463034

Args:
batch: a dictionary of ground truth features
per_asym_residue_index: a dictionary recording which residues belong to which asym_id
entity_2_asym_list: a dictionary recording which asym_id(s) belong to which entity_id
pred_ca_pos: predicted positions of c-alpha atoms from the results of model.forward()
pred_ca_mask: a boolean tensor that masks pred_ca_pos
true_ca_poses: a list of tensors, corresponding to the c-alpha positions of the ground truth structure. e.g. If there are 5 chains, this list will have a length of 5
true_ca_masks: a list of tensors, corresponding to the masks of c-alpha positions of the ground truth structure. If there are 5 chains, this list will have a length of 5
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain what relationship (if any) there is between true_ca_masks and pred_ca_mask? Is this an indication of which residues between chains are expected to align?

If you think this is sufficiently defined elsewhere in the multimer codebase, then maybe a simple addition will suffice here.


Return:
A list of tuple(int,int) that provides instructions of how the ground truth chains should be permuted
e.g. if 3 chains in the input model have the same sequences, an example return would be:
[(0,2),(1,1),(2,0)], meaning the 1st chain in the predicted structure should be aligned to the 3rd chain in the ground truth,
and the 2nd chain in the predicted structure is ok to stay with the 2nd chain in the ground truth.

Note: the tuples in the returned list use 0-based indexing but asym_id begins with 1. The reason why tuples in the return are 0-indexed
is that at the stage of loss calculation, the ground truth atom positions: true_ca_poses, are already split up into a list of matrices.
Hence, now this function needs to return tuples that provide the index to select from the list: true_ca_poses, and list index starts from 0.
"""
used = [False for _ in range(len(true_ca_poses))]
used = [False for _ in range(len(true_ca_poses))] # a list the keeps recording whether a ground truth chain has been used or not
align = []
unique_asym_ids = [i for i in torch.unique(batch["asym_id"]) if i != 0]
for cur_asym_id in unique_asym_ids:
Expand Down Expand Up @@ -189,21 +228,38 @@ def greedy_align(
return align


def pad_features(feature_tensor: torch.Tensor, nres_pad: int, pad_dim: int) -> torch.Tensor:
    """
    Pad a feature tensor with zeros appended after the existing values.

    Args:
        feature_tensor: the tensor to pad
        nres_pad: number of zero entries to append along ``pad_dim``
        pad_dim: dimension of ``feature_tensor`` along which to pad

    Returns:
        A new tensor whose size along ``pad_dim`` is larger by ``nres_pad``;
        every other dimension is unchanged.
    """
    # The padding block has the same shape as the input except along pad_dim.
    pad_shape = list(feature_tensor.shape)
    pad_shape[pad_dim] = nres_pad
    # new_zeros already inherits dtype and device from feature_tensor, so no
    # explicit device argument is needed.
    padding_tensor = feature_tensor.new_zeros(pad_shape)
    return torch.cat((feature_tensor, padding_tensor), dim=pad_dim)


def merge_labels(per_asym_residue_index, labels, align, original_nres):
def merge_labels(per_asym_residue_index: Dict[int,List[int]],
labels: List[Dict], align: List[Tuple[int, int]],
original_nres: int) -> Dict[str, torch.Tensor]:
"""
Merge ground truth labels according to the permutation results

labels: list of original ground truth feats
align: list of tuples, each entry specify the corresponding label of the asym.
Args:
per_asym_residue_index: a dictionary recording which residues belong to which aysm_id
labels: list of original ground truth feats e.g. if there're 5 chains, labels will have a length of 5
dingquanyu marked this conversation as resolved.
Show resolved Hide resolved
align: list of tuples, each entry specify the corresponding label of the asym.
original_nres: int, corresponding to the number of residues specified by crop_size in config.py

Returns:
A new dictionary of permuated ground truth features
modified based on UniFold:
https://github.com/dptech-corp/Uni-Fold/blob/b1c89a2cebd4e4ee4c47b4e443f92beeb9138fbb/unifold/losses/chain_align.py#L176C1-L176C1
"""
Expand All @@ -230,13 +286,20 @@ def merge_labels(per_asym_residue_index, labels, align, original_nres):
return outs


def split_ground_truth_labels(gt_features):
def split_ground_truth_labels(gt_features: dict) -> List[Dict]:
"""
dingquanyu marked this conversation as resolved.
Show resolved Hide resolved
Splits ground truth features according to chains

Args:
gt_features: A dictionary within a the PyTorch DataSet iteration, which returns by the upstream DataLoader.iter() method
In the DataLoader pipeline, all tensors belonging to all the ground truth changes are concatenated so it stays the same as monomer data input format/pipeline,
thus, this function is needed to 1) detect the number of chains i.e. unique(asym_id)
2) split the concatenated tensors back to individual ones that correspond to individual asym_ids

Returns:
a list of feature dictionaries with only necessary ground truth features
required to finish multi-chain permutation
a list of feature dictionaries with only necessary ground truth features
required to finish multi-chain permutation, e.g. it will be a list of 5 elements if there
are 5 chains in total.
"""
unique_asym_ids, asym_id_counts = torch.unique(gt_features["asym_id"], sorted=True, return_counts=True)
n_res = gt_features["asym_id"].shape[-1]
Expand All @@ -251,7 +314,16 @@ def split_dim(shape):
return labels


def get_per_asym_residue_index(features):
def get_per_asym_residue_index(features: dict) -> Dict[int, torch.Tensor]:
"""
A function that retrieve which residues belong to which asym_id

Args:
features: a dictionary that contains input features after cropping

Returns:
A dictionary that records which region of the sequence belongs to which asym_id
"""
unique_asym_ids = [i for i in torch.unique(features["asym_id"]) if i != 0]
per_asym_residue_index = {}
for cur_asym_id in unique_asym_ids:
Expand All @@ -261,34 +333,36 @@ def get_per_asym_residue_index(features):
return per_asym_residue_index


def get_entity_2_asym_list(features: dict) -> Dict[int, torch.Tensor]:
    """
    Map every unique entity ID to the unique asym IDs (asym_id) that carry it.

    Args:
        features (dict): A dictionary containing data features, including
            "entity_id" and "asym_id" tensors.

    Returns:
        entity_2_asym_list (dict): A dictionary where keys are unique entity IDs
            (as Python ints) and values are tensors holding the unique asym IDs
            associated with each entity, as returned by torch.unique.
    """
    entity_ids = features["entity_id"]
    asym_ids = features["asym_id"]
    entity_2_asym_list = {}
    for cur_ent_id in torch.unique(entity_ids):
        # Select positions belonging to this entity, then collapse their
        # asym_ids to the distinct values.
        ent_mask = entity_ids == cur_ent_id
        entity_2_asym_list[int(cur_ent_id)] = torch.unique(asym_ids[ent_mask])
    return entity_2_asym_list


def calculate_input_mask(true_ca_masks, anchor_gt_idx, anchor_gt_residue,
asym_mask, pred_ca_mask):
def calculate_input_mask(true_ca_masks: List[torch.Tensor], anchor_gt_idx: torch.Tensor,
anchor_gt_residue: torch.Tensor,
asym_mask: torch.Tensor, pred_ca_mask: torch.Tensor) -> torch.Tensor:
"""
Calculate an input mask for downstream optimal transformation computation

Args:
true_ca_masks (Tensor): ca mask from ground truth.
anchor_gt_idx (Tensor): The index of selected ground truth anchor.
true_ca_masks: list of masks from ground truth chains.
anchor_gt_idx (Tensor): a tensor with one integer in it. The index of selected ground truth anchor.
anchor_gt_residue:a 1D vector tensor of residue indexes that belongs to the selected ground truth anchor
asym_mask (Tensor): Boolean tensor indicating which regions are selected predicted anchor.
pred_ca_mask (Tensor): ca mask from predicted structure.

Expand All @@ -303,11 +377,38 @@ def calculate_input_mask(true_ca_masks, anchor_gt_idx, anchor_gt_residue,
return input_mask


def calculate_optimal_transform(true_ca_poses,
anchor_gt_idx, anchor_gt_residue,
true_ca_masks, pred_ca_mask,
asym_mask,
pred_ca_pos):
def calculate_optimal_transform(true_ca_poses: List[torch.Tensor],
anchor_gt_idx: int, anchor_gt_residue: torch.Tensor,
true_ca_masks: List[torch.Tensor], pred_ca_mask: torch.Tensor,
asym_mask: torch.Tensor,
pred_ca_pos: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

"""
Takes selected anchor ground truth c-alpha positions and
selected predicted anchor c-alpha position then calculate the optimal rotation matrix
to align ground-truth anchor and predicted anchor

Args:
true_ca_poses: a list of tensors, corresponding to the c-alpha positions of the ground truth structure. e.g. If there are 5 chains, this list will have a length of 5
anchor_gt_idx (Tensor): a tensor with one integer in it. The index of selected ground truth anchor.
anchor_gt_residue:a 1D vector tensor of residue indexes that belongs to the selected ground truth anchor
true_ca_masks: list of masks from ground truth chains e.g. it will be length=5 if there are 5 chains in ground truth structure
pred_ca_mask: A boolean tensor corresponds to the mask to mask the predicted features
asym_mask: A boolean tensor that mask out other elements in a tensor if they do not belong to a this asym_id
pred_ca_pos: a [nres*3] tensor of predicted c-alpha atom positions

Process:
1) select an anchor chain from ground truth, denoted by anchor_gt_idx, and
an anchor chain from the predicted structure. Both anchor_gt and anchor_pred have exactly the same sequence
2) obtain the C-alpha positions corresponding to the selected anchor_gt, done by slicing the true_ca_poses according to anchor_gt_residue
3) calculate the optimal transformation that can best align the C-alpha atoms of anchor_pred to those of anchor_gt,
done by Kabsch algorithm: source https://en.wikipedia.org/wiki/Kabsch_algorithm

Returns:
a rotation matrix that record the optimal rotation
that will best align selected anchor prediction to selected anchor truth
a matrix records how the atoms should be shifted after applying r i.e. optimal alignment requires 1) rotate 2) shift the positions
"""
input_mask = calculate_input_mask(true_ca_masks,
anchor_gt_idx,
anchor_gt_residue,
Expand All @@ -326,13 +427,27 @@ def calculate_optimal_transform(true_ca_poses,
return r, x


def compute_permutation_alignment(out, features, ground_truth):
def compute_permutation_alignment(out: Dict[str,torch.Tensor],
features: Dict[str,torch.Tensor],
ground_truth: List[Dict[str, torch.Tensor]]) -> Tuple[List[Tuple[int, int]], Dict[int, List[int]]]:
"""
A class method that first permutate chains in ground truth first
before calculating the loss.
A method that permutes chains in ground truth before calculating the loss
because the mapping between the predicted and ground-truth will become arbitrary.
The model cannot be assumed to predict chains in the same order as the ground truth.
Thus, this function picks the optimal permutation of predicted chains that best matches the ground truth,
by minimising the RMSD i.e. the best permutation of ground truth chains is selected based on which permutation has the lowest RMSD calculation

Details are described in Section 7.3 in the Supplementary of AlphaFold-Multimer paper:
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2

Args:
out: a dictionary of output tensors from model.forward()
features: a dictionary of feature tensors that are used as input for model.forward()
ground_truth: a list of dictionaries of features corresponding to chains in ground truth structure e.g. it will be a length of 5 if there are 5 chains in ground truth structure

Returns:
a list of tuple(int,int) that instructs how ground truth chains should be permutated
a dictionary recording which residues belong to which aysm_id
"""
unique_asym_ids = set(torch.unique(features['asym_id']).tolist())
unique_asym_ids.discard(0) # Remove padding asym_id
Expand Down Expand Up @@ -397,13 +512,19 @@ def compute_permutation_alignment(out, features, ground_truth):
return best_align, per_asym_residue_index


def multi_chain_permutation_align(out, features, ground_truth):
"""Compute multi-chain permutation alignment.
def multi_chain_permutation_align(out: Dict[str, torch.Tensor],
features: Dict[str, torch.Tensor],
ground_truth: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
"""
Compute multi-chain permutation alignment.

Args:
out: The output of model.forward()
features: Input features
ground_truth: Ground truth features
out: a dictionary of output tensors from model.forward()
features: a dictionary of feature tensors that are used as input for model.forward()
ground_truth: a list of dictionaries of features corresponding to chains in ground truth structure e.g. it will be a length of 5 if there are 5 chains in ground truth structure

Returns:
features: a dictionary with updated ground truth feature tensors, ready for downstream loss calculations.
"""

labels = split_ground_truth_labels(ground_truth)
Expand Down