
Question in offload.py: Moving activation to CPU does NOT reduce GPU memory. #948

Open
swordfate opened this issue Mar 4, 2022 · 14 comments
Labels
bug Something isn't working help wanted Extra attention is needed offload_model triaged

Comments

@swordfate

swordfate commented Mar 4, 2022

I used my cuda_active_bytes function to measure GPU memory before and after the following code line, and I found that moving the activations to CPU does NOT reduce GPU memory.

self._activations[index] = tuple([a.cpu() for a in list(self._activations[index])])

def cuda_active_bytes():
    torch.cuda.synchronize()
    stats = torch.cuda.memory_stats()
    current_active_byte = stats['active_bytes.all.current']
    return current_active_byte

So are all the activations generated by the forward pass actually still in GPU memory? If so, I think the code line above is redundant.

@swordfate swordfate changed the title Question: Moving activation to CPU does NOT reduce GPU memory. Question in offload.py: Moving activation to CPU does NOT reduce GPU memory. Mar 4, 2022
@flying-x
Contributor

flying-x commented Mar 4, 2022

@anj-s

@flying-x
Contributor

flying-x commented Mar 4, 2022

PyTorch has a caching memory allocator for GPU memory, so you need to call torch.cuda.empty_cache() to see the memory reduction; otherwise the freed memory is kept cached for the next allocation. Another way is to use the memory allocator's counters, see https://pytorch.org/docs/stable/generated/torch.cuda.memory_allocated.html#torch.cuda.memory_allocated
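
For example, a rough sketch (untested, arbitrary tensor sizes) of how the allocator counters behave:

import torch

t = torch.randn(1024, 1024, device='cuda') # ~4 MB held by a live tensor
print(torch.cuda.memory_allocated())       # non-zero: bytes used by live tensors
print(torch.cuda.memory_reserved())        # >= allocated: bytes cached by the allocator

del t                                      # the tensor is freed...
print(torch.cuda.memory_allocated())       # ...so allocated drops back down
print(torch.cuda.memory_reserved())        # but the cached blocks are kept for reuse

torch.cuda.empty_cache()                   # return cached blocks to the driver
print(torch.cuda.memory_reserved())        # now reserved drops too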

@swordfate
Author

swordfate commented Mar 7, 2022

PyTorch has a caching memory allocator for GPU memory, so you need to call torch.cuda.empty_cache() to see the memory reduction; otherwise the freed memory is kept cached for the next allocation. Another way is to use the memory allocator's counters, see https://pytorch.org/docs/stable/generated/torch.cuda.memory_allocated.html#torch.cuda.memory_allocated

Thank you for the patient reply, but I tried both torch.cuda.empty_cache() and torch.cuda.memory_allocated() as follows, and the GPU memory is still NOT reduced. I think maybe this is because tensor.cpu() just copies the tensor to CPU memory (this is different from module.cpu(), and this tensor is not the output tensor), so the tensor still exists in the computation graph on the GPU? Am I right?

def cuda_active_bytes():
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    stats = torch.cuda.memory_stats()
    current_active_byte = stats['active_bytes.all.current']
    print('torch.cuda.memory_allocated:', torch.cuda.memory_allocated())
    return current_active_byte

@min-xu-ai
Contributor

Can you share small test code that shows what you are seeing? I.e., test without offload, just pure PyTorch code.

@anj-s
Contributor

anj-s commented Mar 7, 2022

@swordfate Can you print out the device of the activations after the move to confirm?

@swordfate
Author

@min-xu-ai @anj-s Here is a small example for checking GPU memory allocated and tensor location.

import torch

def print_cuda_active_bytes():
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    # stats = torch.cuda.memory_stats()
    # print('active_bytes:', stats['active_bytes.all.current'])
    print('memory_allocated:', torch.cuda.memory_allocated())

x = torch.randn((2, 3), device='cuda', requires_grad=True) # leaf node
w = torch.randn((3, 4), device='cuda', requires_grad=True) # leaf node
l = torch.matmul(x, w) # intermediate node
y = torch.matmul(l, l.T) # output node

# Check GPU memory
print('Before moving...,', end=' ')
print_cuda_active_bytes() # Before moving..., memory_allocated: 2048
x = x.cpu() # try to move leaf node to cpu
print('After moving x to cpu,', end=' ')
print_cuda_active_bytes() # After moving x to cpu, memory_allocated: 2048
w = w.cpu() # try to move leaf node to cpu
print('After moving w to cpu,', end=' ')
print_cuda_active_bytes() # After moving w to cpu, memory_allocated: 2048
l = l.cpu() # try to move intermediate node to cpu
print('After moving l to cpu,', end=' ')
print_cuda_active_bytes() # After moving l to cpu, memory_allocated: 2048
y = y.cpu() # try to move the output node to cpu
print('After moving y to cpu,', end=' ')
print_cuda_active_bytes() # After moving y to cpu, memory_allocated: 1536

# Check tensor location
a = torch.randn((2, 3), device='cuda', requires_grad=True)
b = torch.randn((3, 4), device='cuda', requires_grad=True)
c = torch.matmul(a, b)
d = torch.matmul(c, c.T)
e = c.cpu()
print(e.is_cuda) # False
print(c.is_cuda) # True

@min-xu-ai
Contributor

Just for testing, what if you put the computation under torch.no_grad()? I think the autograd engine keeps the memory for the backward compute.
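
Something along these lines (a rough, untested sketch, reusing the print_cuda_active_bytes helper from your example):

x = torch.randn((2, 3), device='cuda')
w = torch.randn((3, 4), device='cuda')

with torch.no_grad():  # no autograd graph, so nothing is saved for backward
    l = torch.matmul(x, w)
    y = torch.matmul(l, l.T)

l = l.cpu()            # with no graph holding a reference,
y = y.cpu()            # the GPU copies can actually be freed
print_cuda_active_bytes()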

@swordfate
Author

swordfate commented Mar 7, 2022

Just for testing, what if you put the computation under torch.no_grad()? I think the autograd engine keeps the memory for the backward compute.

Yes. After I set requires_grad=False, it prints:

# Check GPU memory usage
Before moving..., memory_allocated: 2048
After moving x to cpu, memory_allocated: 1536
After moving w to cpu, memory_allocated: 1024
After moving l to cpu, memory_allocated: 512
After moving y to cpu, memory_allocated: 0

# Check tensor location
a = torch.randn((2, 3), device='cuda', requires_grad=False)
b = torch.randn((3, 4), device='cuda', requires_grad=False)
c = torch.matmul(a, b)
d = torch.matmul(c, c.T)
e = c.cpu()
print(e.is_cuda) # False
print(c.is_cuda) # True, this shows `c` is still on the GPU.

So this implementation cannot reduce GPU memory when moving activations to CPU via tensor.cpu() (the line #524 code) because the activations are still kept on the GPU for the backward computation?

@min-xu-ai
Contributor

In this case, you can use checkpoint_wrapper and offload the activation to CPU with that wrapper. That way, the tensor is moved back to the GPU only during the backward pass.
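
Roughly like this (a sketch, not tested; the layer and sizes are just placeholders):

import torch
import torch.nn as nn
from fairscale.nn.checkpoint import checkpoint_wrapper

# Wrap a layer so its activations are not kept on the GPU during forward;
# with offload_to_cpu=True, the saved tensors are moved to CPU and brought
# back to the GPU only when backward needs them.
layer = checkpoint_wrapper(nn.Sequential(nn.Linear(1024, 1024), nn.ReLU()),
                           offload_to_cpu=True).cuda()

x = torch.randn(8, 1024, device='cuda', requires_grad=True)
loss = layer(x).sum()
loss.backward()  # activations are restored to the GPU here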

@swordfate
Author

swordfate commented Mar 7, 2022

In this case, you can use checkpoint_wrapper and offload the activation to CPU with that wrapper. That way, the tensor is moved back to the GPU only during the backward pass.

Thanks for pointing me to the solution; I will dive into it later.

@anj-s
Contributor

anj-s commented Mar 14, 2022

Ideally we do want the intermediate activations to be offloaded to the CPU. Marking this as a bug that we should fix.
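
For reference, newer PyTorch (1.10+) also has torch.autograd.graph.save_on_cpu, which offloads the tensors saved for backward to CPU independently of fairscale; a minimal sketch:

import torch

a = torch.randn(1024, 1024, device='cuda', requires_grad=True)
b = torch.randn(1024, 1024, device='cuda', requires_grad=True)

# Tensors saved for backward inside this context live in (pinned) CPU memory
# and are copied back to the GPU only when backward runs.
with torch.autograd.graph.save_on_cpu(pin_memory=True):
    c = torch.matmul(a, b)
    loss = c.sum()

loss.backward()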

@anj-s anj-s added bug Something isn't working help wanted Extra attention is needed triaged labels Mar 14, 2022
@Alex-Songs

@swordfate
Hello, when I use fairscale 0.4.6, the same warning is reported: warnings.warn("None of the inputs have requires_grad=True. Gradients will be None"). Could you please help me solve it? I only use the activation checkpoint for the middle layer of the Transformer.

@Alex-Songs

@min-xu-ai @flying-x @anj-s @swordfate Hello, when I use fairscale 0.4.6, the same warning is reported: warnings.warn("None of the inputs have requires_grad=True. Gradients will be None"). Could you please help me solve it? I only use the activation checkpoint for the middle layer of the Transformer.

@min-xu-ai
Contributor

Do you have a small test case? Have you tried the PyTorch version of FSDP?
