MaxPool2D memory leakage on device MPS #125217

Open
mehmetozsoy1 opened this issue Apr 30, 2024 · 2 comments
Labels: high priority · module: mps (Related to Apple Metal Performance Shaders framework) · triaged (This issue has been looked at a team member, and triaged and prioritized into an appropriate module)

mehmetozsoy1 commented Apr 30, 2024

🐛 Describe the bug

On the MPS device, allocated memory keeps increasing when the model contains MaxPool2D modules. With large enough tensors, this eventually raises an out-of-memory error.

Here is an example:

import torch
device = "mps:0"

# Four Conv2d layers, each followed by a MaxPool2d layer
torch_model = torch.nn.Sequential(
    torch.nn.Conv2d(32,32,3),
    torch.nn.MaxPool2d((2,2)),
    torch.nn.Conv2d(32,32,3),
    torch.nn.MaxPool2d((2,2)),
    torch.nn.Conv2d(32,32,3),
    torch.nn.MaxPool2d((2,2)),
    torch.nn.Conv2d(32,32,3),
    torch.nn.MaxPool2d((2,2))
)

input = torch.rand(64, 32, 256, 256).to(device)
torch_model.to(device)

for _ in range(10000):
    output = torch_model(input)
    loss = output.sum()
    loss.backward()
    # Memory allocated by the Metal driver; this number grows every iteration
    print("allocated memory: ", torch.mps.driver_allocated_memory())

In the code above, the allocated memory increases on every iteration when the MaxPool2D layers are present, and the process eventually crashes with an out-of-memory error. When I remove the MaxPool2D layers, it runs without a problem: the allocated memory stays constant across iterations.

import torch
device = "mps:0"

# Same model with the MaxPool2d layers commented out
torch_model = torch.nn.Sequential(
    torch.nn.Conv2d(32,32,3),
    # torch.nn.MaxPool2d((2,2)),
    torch.nn.Conv2d(32,32,3),
    # torch.nn.MaxPool2d((2,2)),
    torch.nn.Conv2d(32,32,3),
    # torch.nn.MaxPool2d((2,2)),
    torch.nn.Conv2d(32,32,3)
    # torch.nn.MaxPool2d((2,2))
)

input = torch.rand(64, 32, 256, 256).to(device)
torch_model.to(device)

for _ in range(10000):
    output = torch_model(input)
    loss = output.sum()
    loss.backward()
    # Without MaxPool2d this number stays constant across iterations
    print("allocated memory: ", torch.mps.driver_allocated_memory())

This code runs without a problem.
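To turn the with/without comparison into a pass/fail check instead of reading printed numbers, the sketch below runs the report's model for a fixed number of iterations and flags sustained growth in driver-allocated memory. It assumes PyTorch 2.x with `torch.backends.mps.is_available()`, `torch.mps.synchronize()` and `torch.mps.driver_allocated_memory()`; the helper names `memory_growth`, `looks_like_leak` and `run_repro`, the warm-up length, the iteration count and the 64 MiB threshold are illustrative choices, not part of the original report.

```python
from typing import List, Sequence


def memory_growth(readings: Sequence[int], warmup: int = 2) -> int:
    """Return the net change in bytes after discarding warm-up readings.

    The first iterations may allocate legitimately (caches, autograd
    buffers), so only readings taken after `warmup` iterations are compared.
    """
    steady = list(readings[warmup:])
    if len(steady) < 2:
        raise ValueError("need at least two readings after warm-up")
    return steady[-1] - steady[0]


def looks_like_leak(readings: Sequence[int], warmup: int = 2,
                    threshold_bytes: int = 64 * 1024 * 1024) -> bool:
    """Heuristic: flag a leak if steady-state memory grew past the threshold."""
    return memory_growth(readings, warmup) > threshold_bytes


def run_repro(with_maxpool: bool, iterations: int = 20) -> List[int]:
    """Run the model from the report and record driver-allocated memory.

    torch is imported lazily so the helpers above work without it.
    """
    import torch

    if not torch.backends.mps.is_available():
        raise RuntimeError("this repro needs an MPS device")
    device = "mps"

    # Four Conv2d layers, optionally each followed by MaxPool2d, as in the report
    layers = []
    for _ in range(4):
        layers.append(torch.nn.Conv2d(32, 32, 3))
        if with_maxpool:
            layers.append(torch.nn.MaxPool2d((2, 2)))
    model = torch.nn.Sequential(*layers).to(device)
    x = torch.rand(64, 32, 256, 256, device=device)

    readings = []
    for _ in range(iterations):
        model(x).sum().backward()
        torch.mps.synchronize()  # wait for queued GPU work before sampling
        readings.append(torch.mps.driver_allocated_memory())
    return readings
```

If the behavior described above reproduces, `looks_like_leak(run_repro(True))` would be expected to return `True` and `looks_like_leak(run_repro(False))` to return `False`. The threshold is arbitrary and may need tuning for other input sizes.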

Versions

PyTorch version: 2.3.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.4.1 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: Could not collect
Libc version: N/A

Python version: 3.11.5 (main, Sep 11 2023, 08:31:25) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-14.4.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M1 Pro

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==2.3.0
[pip3] torchaudio==2.3.0
[pip3] torchvision==0.18.0
[conda] numpy 1.26.4 py311he598dae_0
[conda] numpy-base 1.26.4 py311hfbfe69c_0
[conda] pytorch 2.3.0 py3.11_0 pytorch
[conda] torchaudio 2.3.0 py311_cpu pytorch
[conda] torchvision 0.18.0 py311_cpu pytorch

cc @ezyang @gchanan @zou3519 @kadeng @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen

mehmetozsoy1 changed the title "Torch MaxPool2D memory leakage on device MPS" to "MaxPool2D memory leakage on device MPS" on Apr 30, 2024
cpuhrsch added the labels triaged, module: mps, and high priority on Apr 30, 2024
cpuhrsch (Contributor)

Seems high priority if reproducible.

jhavukainen (Collaborator) commented May 1, 2024

Can confirm that this is reproducible. The leak tool didn't show any direct memory leakage; it looks more like the MPS backend is holding onto the allocated memory in a clearly problematic way that we need to investigate. My first guess is that something goes bonkers in caching the MaxPool2D graphs, and it just keeps creating more and more of them until it runs out of memory. Thanks for filing this issue with such a clean repro case; we'll look into this with high priority.
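The first guess above describes a graph cache that never gets a hit. The toy model below is not the MPS backend's code; it only illustrates that failure mode, assuming cached graphs live in a dictionary keyed by a string. If the key accidentally includes something that changes on every call (here a hypothetical per-call counter, standing in for something like a buffer address), every lookup misses and the cache grows without bound, whereas a key built only from shape, dtype and kernel parameters is reused. `GraphCache`, `stable_key`, `unstable_key` and `simulate` are hypothetical names.

```python
from dataclasses import dataclass, field
from itertools import count
from typing import Callable, Dict, Tuple


@dataclass
class GraphCache:
    """Toy stand-in for a compiled-graph cache (hypothetical, not PyTorch's)."""
    graphs: Dict[str, object] = field(default_factory=dict)
    misses: int = 0

    def lookup(self, key: str, build: Callable[[], object]) -> object:
        if key not in self.graphs:
            self.misses += 1
            self.graphs[key] = build()  # each miss keeps one more graph alive
        return self.graphs[key]


_call_id = count()


def stable_key(shape: Tuple[int, ...], dtype: str, kernel: Tuple[int, int]) -> str:
    # Depends only on what determines the graph's structure.
    return f"maxpool2d:{shape}:{dtype}:{kernel}"


def unstable_key(shape: Tuple[int, ...], dtype: str, kernel: Tuple[int, int]) -> str:
    # Bug pattern: a value that differs on every call leaks into the key.
    return f"{stable_key(shape, dtype, kernel)}:{next(_call_id)}"


def simulate(key_fn: Callable[..., str], iterations: int) -> GraphCache:
    cache = GraphCache()
    for _ in range(iterations):
        key = key_fn((64, 32, 254, 254), "float32", (2, 2))
        cache.lookup(key, build=lambda: bytearray(1024))  # placeholder "graph"
    return cache
```

With `stable_key`, 100 iterations leave a single cached entry; with `unstable_key`, they leave 100, which matches the "keeps creating more and more" pattern in the hypothesis.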

jhavukainen self-assigned this May 1, 2024
Projects
None yet
Development

No branches or pull requests

4 participants