CUDA out of memory in unexpected cases #24

Open

havu73 opened this issue Mar 3, 2023 · 2 comments


havu73 commented Mar 3, 2023

Hi! Thank you so much for providing this implementation and documentation. I have been looking for a PyTorch implementation of Enformer for a while.

When I run the provided script test_pretrained.py on a GPU with 8 GB of memory, the model behaves as expected. The input sequence dimension is [131072, 4], i.e. the input DNA sequence is 131,072 bp long.

However, when I started writing a script to extract embeddings of some genomic sequences, which, according to the Enformer paper and the default parameters, should be 196,608 bp long, I started getting CUDA out of memory errors. This happens even with a batch size of 1, and also when I increase my memory request to 32 GB (instead of 8 GB as before). If I run the same code to get embeddings (batch size 1 up to a maximum of 8), it runs fine on a CPU with 32 GB of memory.

What confuses me is that test_pretrained.py seems to run fine with just 8 GB and CUDA, but when I try to get embeddings for a single input sequence (which should in theory require less computation than test_pretrained.py), I get this error. I have searched various platforms for a fix without success. Do you have any suggestions about why that is the case?
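
For reference, peak GPU memory at the two context lengths can be compared with standard torch.cuda utilities (a rough sketch; random indices stand in for real sequences):

import torch
from enformer_pytorch import Enformer

device = torch.device("cuda")
model = Enformer.from_pretrained('enformer-official-rough').to(device).eval()

for length in (131_072, 196_608):
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    seq = torch.randint(0, 4, (1, length), device = device)  # random DNA indices
    try:
        with torch.no_grad():
            model(seq)
        print(f"{length} bp: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB peak")
    except torch.cuda.OutOfMemoryError:
        print(f"{length} bp: CUDA out of memory")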

Thank you so much. Below is a snapshot of my code to get embeddings:

import torch
from enformer_pytorch import Enformer, seq_indices_to_one_hot, GenomeIntervalDataset
from torch.utils.data import DataLoader

# load the pretrained model and move it to the GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Enformer.from_pretrained('enformer-official-rough')
model = model.to(device)

# dataset that pulls fixed-length windows from the fasta at the bed intervals
ds = GenomeIntervalDataset(
    bed_file = '/data/gnomad/v2.1.1_hg19/variants/MAF/chr1_maf.bed.gz',  # intervals to embed
    fasta_file = '/data/ucsc/sequence/hg19/chr1.fa.gz',                  # reference sequence
    return_seq_indices = True,                                           # return sequences as integer indices
    shift_augs = None,                                                   # no shift augmentation
    context_length = 196_608,                                            # full Enformer context length
    chr_bed_to_fasta_map = {
        'chr1': 'chr1'
    }
)
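
The snapshot ends before the forward pass; the embedding step looks roughly like the following, using the return_embeddings = True signature from this repo's README:

loader = DataLoader(ds, batch_size = 1)

model.eval()
with torch.no_grad():
    for seq in loader:
        seq = seq.to(device)
        # with return_embeddings = True, the model returns (predictions, embeddings)
        _, embeddings = model(seq, return_embeddings = True)
        break  # a single batch is enough for illustration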

lucidrains (Owner) commented

@havu73 hello! while the network was trained at 196k (196,608 bp), you can still run it at 131k (or even less, provided the length is divisible by the total downsample factor of the conv stem, which is 128), if that is what will fit into memory

it would have a bit less context, but perhaps good to see if you get signal first before scaling up
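
a minimal sketch of what that might look like (same pretrained checkpoint as above; random input just for illustration):

import torch
from enformer_pytorch import Enformer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Enformer.from_pretrained('enformer-official-rough').to(device).eval()

# 131_072 = 1024 * 128, so it satisfies the divisible-by-128 constraint
seq = torch.randint(0, 4, (1, 131_072), device = device)

with torch.no_grad():
    _, embeddings = model(seq, return_embeddings = True)

with the dataset snippet above, this just means passing context_length = 131_072 to GenomeIntervalDataset. note that the transformer attends over 128-bp bins, so attention memory grows roughly quadratically with length: 196,608 bp is 1536 bins versus 1024 bins at 131,072 bp, about (1536/1024)² ≈ 2.25x the attention memory, which is likely why the longer input OOMs while the shorter one fits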

@havu73
Copy link
Author

havu73 commented Mar 6, 2023

Thank you so much @lucidrains, I really appreciate it! This is super helpful. I have been too busy to catch up on this, but once I have tested it out I will close the issue. Thanks again very much!
