Issue in Manual optimisation, during self.manual_backward call #19810

Open
pranavrao-qure opened this issue Apr 25, 2024 · 1 comment
Labels
bug (Something isn't working), needs triage (Waiting to be triaged by maintainers), ver: 2.0.x

Comments

pranavrao-qure commented Apr 25, 2024

Bug description

I have set automatic_optimization to False and am using self.manual_backward to compute and populate the gradients. The code breaks during the self.manual_backward call, raising "RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn". I have posted code below that replicates the issue.
The issue does not arise when I set args['use_minibatch_clip_loss'] = False, or when I set args['batch_size'] = args['minibatch_size'] = 16. I suspect the issue only arises when I call backward after first running the model under torch.no_grad().
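
For reference, the same RuntimeError can be produced in plain PyTorch (independent of Lightning) by calling backward() on a tensor that was created while gradients were disabled. This is only an illustration of what the error message means, not the exact code path of the report, since in the script below the failing forward pass runs outside the no_grad block:

import torch
from torch import nn

layer = nn.Linear(4, 2)
x = torch.randn(8, 4)

# The forward pass runs under no_grad, so the resulting loss has no grad_fn.
with torch.no_grad():
    loss = layer(x).mean()

# Re-enabling gradients afterwards does not help: the graph was never built.
torch.set_grad_enabled(True)
loss.backward()  # RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn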

What version are you seeing the problem on?

v2.2

How to reproduce the bug

import os
import math

import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader

import lightning.pytorch as pl
from lightning.pytorch import loggers as pl_loggers

class TestModule(nn.Module):
    def __init__(self, in_dim=512, out_dim=16):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.query = nn.Linear(self.in_dim, self.out_dim, bias=True)

    def forward(self, input):
        return self.query(input)

class TestLitModel(pl.LightningModule):
    def __init__(self, args):
        super().__init__()
        self.test_module_obj = TestModule(args['in_dim'], args['out_dim'])
        self.use_minibatch_clip_loss = args['use_minibatch_clip_loss']
        if self.use_minibatch_clip_loss:
            self.batch_size = args['batch_size']
            self.minibatch_size = args['minibatch_size']
            self.accumulate_grad_batches_mb = args['accumulate_grad_batches_mb']
            self.automatic_optimization = False

    def get_mini_batches(self, input):
        num_mb = math.ceil(self.batch_size / self.minibatch_size)
        return torch.chunk(input, num_mb)

    def shared_step(self, input):
        output = self.test_module_obj(input)
        loss = output.mean()
        return loss

    def train_step_minibatch(self, input, batch_idx):
        if self.batch_size > self.minibatch_size:
            mini_batches = self.get_mini_batches(input)
            mb_model_output_list = list()
            # First pass: forward each mini-batch without building an autograd graph.
            with torch.no_grad():
                for mb in mini_batches:
                    mb_model_output_list.append(self.shared_step(mb).detach())
                all_loss = sum(mb_model_output_list)

            # Second pass: re-run each mini-batch with gradients enabled and call manual_backward.
            self.test_module_obj.train()
            self.test_module_obj.requires_grad_(True)
            torch.set_grad_enabled(True)
            assert torch.is_grad_enabled()
            assert all(p.requires_grad for p in self.test_module_obj.parameters())
            for _, mb in enumerate(mini_batches):
                mb_model_output = self.shared_step(mb)
                self.manual_backward(mb_model_output)
        else:
            all_loss = self.shared_step(input)
            self.manual_backward(all_loss)

        # get optimizers and scheduler
        if (batch_idx + 1) % self.accumulate_grad_batches_mb == 0:
            optimizer = self.optimizers()
            if isinstance(optimizer, list):
                optimizer = optimizer[0]
            optimizer.step()
            optimizer.zero_grad()
        return all_loss

    def training_step(self, batch, batch_idx):
        input = batch[0]
        if self.use_minibatch_clip_loss:
            loss = self.train_step_minibatch(input, batch_idx)
        else:
            loss = self.shared_step(input)
        return loss

    def validation_step(self, batch, batch_idx):
        input = batch[0]
        loss = self.shared_step(input)
        return loss

    def configure_optimizers(self):
        self.optimizer = torch.optim.AdamW(list(self.test_module_obj.parameters()), lr=0.0002, weight_decay=0.01)
        return {"optimizer": self.optimizer}


if __name__ == '__main__':
    args = {
        'in_dim': 512,
        'out_dim': 16,
        'train_batch_size': 16,
        'val_batch_size': 64,
        'use_minibatch_clip_loss': True,
        'batch_size': 16,
        'minibatch_size': 4,
        'accumulate_grad_batches_mb': 1,
    }

    x_dummy = torch.randn(512, args['in_dim'])                                                                      # 512 samples, args['in_dim'] features each
    test_data_loader = DataLoader(TensorDataset(x_dummy), batch_size=args['train_batch_size'], shuffle=False)       # Dummy dataset
    test_lit_model = TestLitModel(args)

    # -- LOGGING
    checkpoint_dir = 'test_logs/'
    tb_logger = pl_loggers.TensorBoardLogger(save_dir=os.path.join(checkpoint_dir, "logs"))
    trainer = pl.Trainer(
            logger=tb_logger,
            accelerator='gpu',
            devices=[1],
            strategy='auto',
            precision='16-mixed',
            max_epochs=1,
            accumulate_grad_batches=1,
            num_sanity_val_steps=0,
            inference_mode=False,
        )
    trainer.fit(test_lit_model, test_data_loader)

Error messages and logs

Epoch 0:   0%|                                                                                                                                                                                                                                | 0/32 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/users/pranav.rao/foundational_models/test1.py", line 119, in <module>
    trainer.fit(test_lit_model, test_data_loader)
  File "/home/users/pranav.rao/miniconda3/envs/pytorch_lightning_bug_report/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 544, in fit
    call._call_and_handle_interrupt(
  File "/home/users/pranav.rao/miniconda3/envs/pytorch_lightning_bug_report/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/users/pranav.rao/miniconda3/envs/pytorch_lightning_bug_report/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 580, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/users/pranav.rao/miniconda3/envs/pytorch_lightning_bug_report/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 987, in _run
    results = self._run_stage()
              ^^^^^^^^^^^^^^^^^
  File "/home/users/pranav.rao/miniconda3/envs/pytorch_lightning_bug_report/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1033, in _run_stage
    self.fit_loop.run()
  File "/home/users/pranav.rao/miniconda3/envs/pytorch_lightning_bug_report/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py", line 205, in run
    self.advance()
  File "/home/users/pranav.rao/miniconda3/envs/pytorch_lightning_bug_report/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py", line 363, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/home/users/pranav.rao/miniconda3/envs/pytorch_lightning_bug_report/lib/python3.11/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 140, in run
    self.advance(data_fetcher)
  File "/home/users/pranav.rao/miniconda3/envs/pytorch_lightning_bug_report/lib/python3.11/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 252, in advance
    batch_output = self.manual_optimization.run(kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/users/pranav.rao/miniconda3/envs/pytorch_lightning_bug_report/lib/python3.11/site-packages/lightning/pytorch/loops/optimization/manual.py", line 94, in run
    self.advance(kwargs)
  File "/home/users/pranav.rao/miniconda3/envs/pytorch_lightning_bug_report/lib/python3.11/site-packages/lightning/pytorch/loops/optimization/manual.py", line 114, in advance
    training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/users/pranav.rao/miniconda3/envs/pytorch_lightning_bug_report/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 309, in _call_strategy_hook
    output = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/users/pranav.rao/miniconda3/envs/pytorch_lightning_bug_report/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py", line 391, in training_step
    return self.lightning_module.training_step(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/users/pranav.rao/foundational_models/test1.py", line 74, in training_step
    loss = self.train_step_minibatch(input, batch_idx)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/users/pranav.rao/foundational_models/test1.py", line 57, in train_step_minibatch
    self.manual_backward(mb_model_output)
  File "/home/users/pranav.rao/miniconda3/envs/pytorch_lightning_bug_report/lib/python3.11/site-packages/lightning/pytorch/core/module.py", line 1071, in manual_backward
    self.trainer.strategy.backward(loss, None, *args, **kwargs)
  File "/home/users/pranav.rao/miniconda3/envs/pytorch_lightning_bug_report/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py", line 213, in backward
    self.precision_plugin.backward(closure_loss, self.lightning_module, optimizer, *args, **kwargs)
  File "/home/users/pranav.rao/miniconda3/envs/pytorch_lightning_bug_report/lib/python3.11/site-packages/lightning/pytorch/plugins/precision/precision.py", line 72, in backward
    model.backward(tensor, *args, **kwargs)
  File "/home/users/pranav.rao/miniconda3/envs/pytorch_lightning_bug_report/lib/python3.11/site-packages/lightning/pytorch/core/module.py", line 1090, in backward
    loss.backward(*args, **kwargs)
  File "/home/users/pranav.rao/miniconda3/envs/pytorch_lightning_bug_report/lib/python3.11/site-packages/torch/_tensor.py", line 525, in backward
    torch.autograd.backward(
  File "/home/users/pranav.rao/miniconda3/envs/pytorch_lightning_bug_report/lib/python3.11/site-packages/torch/autograd/__init__.py", line 267, in backward
    _engine_run_backward(
  File "/home/users/pranav.rao/miniconda3/envs/pytorch_lightning_bug_report/lib/python3.11/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Environment

Current environment
  • CUDA:
    - GPU:
    - NVIDIA A100 80GB PCIe
    - NVIDIA A100 80GB PCIe
    - NVIDIA A100 80GB PCIe
    - NVIDIA A100 80GB PCIe
    - available: True
    - version: 12.1
  • Lightning:
    - lightning: 2.2.3
    - lightning-utilities: 0.11.2
    - pytorch-lightning: 2.2.3
    - torch: 2.3.0
    - torchmetrics: 1.3.2
  • Packages:
    - absl-py: 2.1.0
    - aiohttp: 3.9.5
    - aiosignal: 1.3.1
    - attrs: 23.2.0
    - filelock: 3.13.4
    - frozenlist: 1.4.1
    - fsspec: 2024.3.1
    - grpcio: 1.62.2
    - idna: 3.7
    - jinja2: 3.1.3
    - lightning: 2.2.3
    - lightning-utilities: 0.11.2
    - markdown: 3.6
    - markupsafe: 2.1.5
    - mpmath: 1.3.0
    - multidict: 6.0.5
    - networkx: 3.3
    - numpy: 1.26.4
    - nvidia-cublas-cu12: 12.1.3.1
    - nvidia-cuda-cupti-cu12: 12.1.105
    - nvidia-cuda-nvrtc-cu12: 12.1.105
    - nvidia-cuda-runtime-cu12: 12.1.105
    - nvidia-cudnn-cu12: 8.9.2.26
    - nvidia-cufft-cu12: 11.0.2.54
    - nvidia-curand-cu12: 10.3.2.106
    - nvidia-cusolver-cu12: 11.4.5.107
    - nvidia-cusparse-cu12: 12.1.0.106
    - nvidia-nccl-cu12: 2.20.5
    - nvidia-nvjitlink-cu12: 12.4.127
    - nvidia-nvtx-cu12: 12.1.105
    - packaging: 24.0
    - pip: 23.3.1
    - protobuf: 5.26.1
    - pytorch-lightning: 2.2.3
    - pyyaml: 6.0.1
    - setuptools: 68.2.2
    - six: 1.16.0
    - sympy: 1.12
    - tensorboard: 2.16.2
    - tensorboard-data-server: 0.7.2
    - torch: 2.3.0
    - torchmetrics: 1.3.2
    - tqdm: 4.66.2
    - triton: 2.3.0
    - typing-extensions: 4.11.0
    - werkzeug: 3.0.2
    - wheel: 0.41.2
    - yarl: 1.9.4
  • System:
    - OS: Linux
    - architecture:
    - 64bit
    - ELF
    - processor: x86_64
    - python: 3.11.9
    - release: 5.15.0-69-generic
    - version: #76~20.04.1-Ubuntu SMP Mon Mar 20 15:54:19 UTC 2023

More info

I am training a Vision-Language model with a CLIP loss. The batch size I want to use is large, which requires computing the embeddings in mini-batches and then computing the gradients mini-batch by mini-batch, as done in the repo https://github.com/Zasder3/train-CLIP/tree/main (see lines: https://github.com/Zasder3/train-CLIP/blob/79d4c7960072047a9e0d39335ab60dcb150640c3/models/wrapper.py#L64-L109).

The issue arose when I implemented a similar algorithm for my use case and tried to train it. I have isolated the problem as much as I could and produced a simple script that reproduces the same error.
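
For context, a rough sketch of that mini-batch gradient pattern, with placeholder encoder and loss_fn names (illustrative only, not the train-CLIP API): embed all mini-batches without gradients first, then re-encode one mini-batch at a time with gradients enabled and call backward per mini-batch.

import torch

def minibatch_backward(encoder, loss_fn, batch, minibatch_size):
    mini_batches = torch.split(batch, minibatch_size)

    # 1) Embed every mini-batch without building a graph (saves memory on large batches).
    with torch.no_grad():
        cached = [encoder(mb) for mb in mini_batches]

    # 2) Re-encode one mini-batch at a time with gradients enabled, splice its
    #    fresh (grad-tracking) embeddings into the cached set, and backpropagate.
    for i, mb in enumerate(mini_batches):
        emb = encoder(mb)  # runs outside no_grad, so it has a grad_fn
        all_emb = torch.cat(cached[:i] + [emb] + cached[i + 1:])
        loss_fn(all_emb).backward()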

pranavrao-qure added the bug and needs triage labels on Apr 25, 2024
pranavrao-qure (Author) commented

I removed excess code, created a fresh Conda environment with just pytorch-lightning and tensorboard installed, and was able to replicate the same issue even with lightning version 2.2.3. I have edited the issue above to reflect this.
