
Jointly Embedding Protein Structures and Sequences through Residue Level Alignment

This repository provides the associated code for "Jointly Embedding Protein Structures and Sequences through Residue Level Alignment" by Foster Birnbaum, Saachi Jain, Aleksander Madry, and Amy E. Keating.

Setup

Requirements

The rla_env.yml file specifies the required dependencies.
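A typical conda workflow is sketched below (the environment name rla is an assumption; check the name field in rla_env.yml):

conda env create -f rla_env.yml
conda activate rla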

Model weights and data

Model weights and data are available to download here. The model weights folder should be downloaded to the home directory. The data are provided as a zipped folder containing the train/validation/test data splits as WebDatasets.

Example inference data are provided in the example_data folder.

To package a .pdb file into the WebDataset format that is most easily read by RLA, follow this example.
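Once packaged, the shards can be loaded with the webdataset library and wrapped in a PyTorch dataloader. The sketch below is illustrative only (the shard path and batch size are assumptions; the linked example shows the exact decoding and collation RLA expects):

import torch
import webdataset as wds

# Illustrative shard path; point this at your packaged WebDataset
dataset = wds.WebDataset('example_data/val_shard.tar').decode()
val_loader = torch.utils.data.DataLoader(dataset, batch_size=1)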

Inference

Once the model weights folder is downloaded, the model_dir = /c/example/path line in each example notebook must be changed to point to the weights. Additionally, if the machine running the notebook is offline, the args_dict['arch'] = '/c/example/path' line must be changed to point to ESM-2 as downloaded by the transformers module. If the machine is online, the args_dict['arch'] = '/c/example/path' line should be deleted.
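Concretely, the two notebook edits look like this (both paths are placeholders):

model_dir = '/path/to/model_weights' # Downloaded RLA weights folder

# Offline machines only; delete this line when running online:
args_dict['arch'] = '/path/to/local/esm2' # Local ESM-2 copy saved by transformers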

The core idea behind using RLA is to pass the sequence and structure of a protein through the corresponding tracks of RLA (RLA-ESM and RLA-COORDinator) to generate sequence and structure embeddings, and then to compute the residue-level cosine similarity between the resulting embeddings. The cosine similarities are averaged to produce an RLA score that represents the sequence-structure compatibility of the input protein. The following is an example of generating RLA scores, assuming val_loader is a PyTorch dataloader generated by processing the WebDataset as shown in the example above.

import torch
from torch.cuda.amp import autocast

import src.data_utils as data_utils

## Get sequence and structure embeddings from RLA
def get_seq_and_struct_features(model, tokenizer, batch):
    seq_batch, coords_batch = batch
    seqs = seq_batch['string_sequence']
    text_inp = tokenizer(seqs, return_tensors='pt', padding=True, truncation=True, max_length=1024+2)
    text_inp['position_ids'] = seq_batch['pos_embs'][0]
    text_inp = {k: v.to('cuda') for k, v in text_inp.items()}
    coord_data = data_utils.construct_gnn_inp(coords_batch, device='cuda', half_precision=True)
    gnn_features, text_features, logit_scale = model(text_inp, coord_data) # Get features
    new_text_features, _, new_text_mask = data_utils.postprocess_text_features(
        text_features=text_features, 
        inp_dict=text_inp, 
        tokenizer=tokenizer, 
        placeholder_mask=seq_batch['placeholder_mask'][0])
    return {
        'text': new_text_features, # text feature
        'gnn': gnn_features, # gnn feature
        'seq_mask_with_burn_in': seq_batch['seq_loss_mask'][0], # sequence mask of what's supervised
        'coord_mask_with_burn_in': coords_batch['coords_loss_mask'][0], # coord mask of what's supervised
        'seq_mask_no_burn_in': new_text_mask.bool(), # sequence mask of what's valid (e.g., not padded)
        'coord_mask_no_burn_in': coords_batch['coords'][1], # coord mask of what's valid
    }

all_scores = []
for i, batch in enumerate(val_loader):
    with torch.no_grad():
        with autocast(dtype=torch.float16):
            output_dict = get_seq_and_struct_features(trained_model, tokenizer, batch)
            text_feat = output_dict['text']
            gnn_feat =  output_dict['gnn'][:, :text_feat.shape[1]] # Remove tail padding
            scores = (text_feat.unsqueeze(2) @ gnn_feat.unsqueeze(-1)).squeeze(-1).squeeze(-1) # Per-residue similarity between sequence and structure embeddings
            scores = (scores * output_dict['seq_mask_no_burn_in'].float()).sum(1)/output_dict['seq_mask_no_burn_in'].sum(1) # Calculate RLA score
            all_scores.append(scores.cpu())
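After the loop completes, torch.cat(all_scores) yields a single tensor with one RLA score per protein in the validation set.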

Structural candidate ranking

An example of how to use RLA to rank candidate structures is provided here. The example ranks hundreds of decoy structures for two real structures from the PDB and evaluates the ranking by computing its correlation with the decoys' TM-scores. The data are sourced from Roney and Ovchinnikov, 2022.
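As a minimal sketch of this style of evaluation (rla_scores and tm_scores are hypothetical 1-D arrays with one entry per decoy; the linked notebook shows the full pipeline):

import numpy as np
from scipy.stats import spearmanr

def rank_decoys(rla_scores, tm_scores):
    # rla_scores: hypothetical per-decoy RLA scores computed as above
    # tm_scores: hypothetical per-decoy TM-scores from the benchmark
    rla_scores = np.asarray(rla_scores)
    tm_scores = np.asarray(tm_scores)
    order = np.argsort(-rla_scores) # Highest RLA score first
    rho, _ = spearmanr(rla_scores, tm_scores) # Rank agreement with TM-score
    return order, rho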

Mutation effect prediction

An example of how to use RLA to predict the effect of mutations is provided here. The example predicts the effects of thousands of single and double amino acid substitutions on the stability of single-chain proteins and compares the predictions to experimentally measured values. The data are sourced from Tsuboyama et al., 2023.
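One simple way to frame such a prediction (a sketch only; rla_score is a hypothetical helper wrapping the scoring loop above, and the linked notebook shows the actual procedure):

def mutation_effect(wt_seq, mut_seq, coords, rla_score):
    # rla_score is a hypothetical callable taking a sequence and a structure
    # and returning a scalar RLA score. A negative delta suggests the mutation
    # reduces sequence-structure compatibility relative to the wild type.
    return rla_score(mut_seq, coords) - rla_score(wt_seq, coords)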

Contact prediction

An example of how to use RLA to predict whether two residues in a protein are in contact is provided here.
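As a rough sketch of the underlying idea (not necessarily the exact procedure in the linked notebook), the per-residue embeddings from the two tracks can be compared pairwise to form a contact-map-like similarity matrix:

import torch
import torch.nn.functional as F

def similarity_map(text_feat, gnn_feat):
    # text_feat, gnn_feat: per-residue embeddings for one protein, shape [L, D]
    text_feat = F.normalize(text_feat, dim=-1)
    gnn_feat = F.normalize(gnn_feat, dim=-1)
    return text_feat @ gnn_feat.T # [L, L]; high entries suggest contacts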

Training

To train the model, use the clip_main.py script. For example:

python clip_main.py --config dataset_configs/full_pdb.yaml --training.exp_name experiment_name --model.coordinator_hparams terminator_configs/standard.json 

Funding

This work was supported by NIH awards R01GM129007 and R35GM149227 to A.E.K.