
PyroModelGuideWarmup fails on GPU - probably need to be manually run before trainer.fit() #2616

Open
vitkl opened this issue Mar 18, 2024 · 4 comments
vitkl commented Mar 18, 2024

PyroModelGuideWarmup fails on GPU, probably because Callback.setup() is called in the accelerator environment in the latest PyTorch Lightning.
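
In other words, by the time the warmup callback runs, the batch has already been moved to the GPU while the module's parameters are still on the CPU. A minimal, standalone illustration of that error class (plain PyTorch, not scvi-tools code; requires a CUDA device):

import torch

# The warmup batch lives on cuda:0, but the layer's weight and bias are still
# on the CPU because the module has not been moved to the device yet.
linear = torch.nn.Linear(100, 1)            # parameters on CPU
x = torch.randn(128, 100, device="cuda:0")  # batch already on the GPU
linear(x)  # RuntimeError: Expected all tensors to be on the same device ...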

This test fails on GPU:

pytest tests/model/test_pyro.py::test_pyro_bayesian_regression_low_level --accelerator 'gpu'
(cell2state_cuda118_torch22) vk7@farm22-gpu0203:.../software/tests/scvi-tools$ pytest tests/model/test_pyro.py::test_pyro_bayesian_regression_low_level --accelerator 'gpu'
=================================================================== test session starts ===================================================================
platform linux -- Python 3.10.13, pytest-8.1.1, pluggy-1.4.0
rootdir: .../software/tests/scvi-tools
configfile: pyproject.toml
plugins: cov-4.1.0, anyio-4.3.0
collected 1 item                                                                                                                                          

tests/model/test_pyro.py F                                                                                                                          [100%]

======================================================================== FAILURES =========================================================================
_________________________________________________________ test_pyro_bayesian_regression_low_level _________________________________________________________

self = BayesianRegressionPyroModel(
  (linear): PyroLinear(in_features=100, out_features=1, bias=True)
)
x = tensor([[ 6., 25.,  3.,  ..., 10., 22., 13.],
        [14.,  3., 14.,  ...,  0.,  6., 14.],
        [19.,  0.,  0.,  ....0.,  8.],
        [ 0.,  9.,  2.,  ..., 14.,  6.,  0.],
        [ 0.,  0.,  0.,  ..., 13.,  9., 12.]], device='cuda:0')
y = tensor([[0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
 ...      [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0]], device='cuda:0')
ind_x = tensor([ 29, 272, 379, 251, 149, 339, 147, 137, 197, 275, 139, 323, 365, 322,
        362,  59,  99, 281, 397,  31,  7... 301,
         92, 378, 221, 280, 349,  46,  83, 222,  48, 180, 279, 395,  53,  87,
        386,   7], device='cuda:0')

    def forward(self, x, y, ind_x):
        obs_plate = self.create_plates(x, y, ind_x)
    
        sigma = pyro.sample("sigma", dist.Exponential(self.one))
    
>       mean = self.linear(x).squeeze(-1)

tests/model/test_pyro.py:98: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pyro/nn/module.py:450: in __call__
    result = super().__call__(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/torch/nn/modules/module.py:1511: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/torch/nn/modules/module.py:1520: in _call_impl
    return forward_call(*args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = PyroLinear(in_features=100, out_features=1, bias=True)
input = tensor([[ 6., 25.,  3.,  ..., 10., 22., 13.],
        [14.,  3., 14.,  ...,  0.,  6., 14.],
        [19.,  0.,  0.,  ....0.,  8.],
        [ 0.,  9.,  2.,  ..., 14.,  6.,  0.],
        [ 0.,  0.,  0.,  ..., 13.,  9., 12.]], device='cuda:0')

    def forward(self, input: Tensor) -> Tensor:
>       return F.linear(input, self.weight, self.bias)
E       RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)

.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/torch/nn/modules/linear.py:116: RuntimeError

The above exception was the direct cause of the following exception:

accelerator = 'gpu', devices = 'auto'

    def test_pyro_bayesian_regression_low_level(
        accelerator: str,
        devices: list | str | int,
    ):
        adata = synthetic_iid()
        adata_manager = _create_indices_adata_manager(adata)
        train_dl = AnnDataLoader(adata_manager, shuffle=True, batch_size=128)
        pyro.clear_param_store()
        model = BayesianRegressionModule(in_features=adata.shape[1], out_features=1)
        plan = LowLevelPyroTrainingPlan(model)
        plan.n_obs_training = len(train_dl.indices)
        trainer = Trainer(
            accelerator=accelerator,
            devices=devices,
            max_epochs=2,
            callbacks=[PyroModelGuideWarmup(train_dl)],
        )
>       trainer.fit(plan, train_dl)

tests/model/test_pyro.py:203: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
scvi/train/_trainer.py:219: in fit
    super().fit(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:544: in fit
    call._call_and_handle_interrupt(
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:44: in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:580: in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:950: in _run
    call._call_setup_hook(self)  # allow user to setup lightning_module in accelerator environment
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:93: in _call_setup_hook
    _call_callback_hooks(trainer, "setup", stage=fn)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:208: in _call_callback_hooks
    fn(trainer, trainer.lightning_module, *args, **kwargs)
scvi/model/base/_pyromixin.py:72: in setup
    pyro_guide(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pyro/nn/module.py:450: in __call__
    result = super().__call__(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/torch/nn/modules/module.py:1511: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/torch/nn/modules/module.py:1520: in _call_impl
    return forward_call(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pyro/infer/autoguide/guides.py:510: in forward
    self._setup_prototype(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pyro/infer/autoguide/guides.py:460: in _setup_prototype
    super()._setup_prototype(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pyro/infer/autoguide/guides.py:157: in _setup_prototype
    self.prototype_trace = poutine.block(poutine.trace(model).get_trace)(
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pyro/poutine/messenger.py:32: in _context_wrap
    return fn(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py:216: in get_trace
    self(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py:198: in __call__
    raise exc from e
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py:191: in __call__
    ret = self.fn(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pyro/poutine/messenger.py:32: in _context_wrap
    return fn(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pyro/poutine/messenger.py:32: in _context_wrap
    return fn(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pyro/nn/module.py:450: in __call__
    result = super().__call__(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/torch/nn/modules/module.py:1511: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/torch/nn/modules/module.py:1520: in _call_impl
    return forward_call(*args, **kwargs)
tests/model/test_pyro.py:98: in forward
    mean = self.linear(x).squeeze(-1)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pyro/nn/module.py:450: in __call__
    result = super().__call__(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/torch/nn/modules/module.py:1511: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/torch/nn/modules/module.py:1520: in _call_impl
    return forward_call(*args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = PyroLinear(in_features=100, out_features=1, bias=True)
input = tensor([[ 6., 25.,  3.,  ..., 10., 22., 13.],
        [14.,  3., 14.,  ...,  0.,  6., 14.],
        [19.,  0.,  0.,  ....0.,  8.],
        [ 0.,  9.,  2.,  ..., 14.,  6.,  0.],
        [ 0.,  0.,  0.,  ..., 13.,  9., 12.]], device='cuda:0')

    def forward(self, input: Tensor) -> Tensor:
>       return F.linear(input, self.weight, self.bias)
E       RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)
E            Trace Shapes:        
E             Param Sites:        
E            Sample Sites:        
E               sigma dist |      
E                    value |      
E       linear.weight dist | 1 100
E                    value | 1 100
E         linear.bias dist | 1    
E                    value | 1

.../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/torch/nn/modules/linear.py:116: RuntimeError
------------------------------------------------------------------ Captured stderr setup ------------------------------------------------------------------
Seed set to 0
------------------------------------------------------------------ Captured stderr call -------------------------------------------------------------------
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100-SXM4-80GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
-------------------------------------------------------------------- Captured log call --------------------------------------------------------------------
WARNING  jax._src.xla_bridge:xla_bridge.py:742 An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
==================================================================== warnings summary =====================================================================
../../../../../..//software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/lightning_utilities/core/imports.py:14
  /nfs/team283/vk7/software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/lightning_utilities/core/imports.py:14: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html

../../../../../..//software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/lightning/fabric/__init__.py:40
  .../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/lightning/fabric/__init__.py:40: Deprecated call to `pkg_resources.declare_namespace('lightning.fabric')`.
  Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages

../../../../../..//software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pkg_resources/__init__.py:2350
../../../../../..//software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pkg_resources/__init__.py:2350
  .../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/pkg_resources/__init__.py:2350: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('lightning')`.
  Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
    declare_namespace(parent)

../../../../../..//software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/lightning/pytorch/__init__.py:37
  .../software/miniconda3farm5/envs/cell2state_cuda118_torch22/lib/python3.10/site-packages/lightning/pytorch/__init__.py:37: Deprecated call to `pkg_resources.declare_namespace('lightning.pytorch')`.
  Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================================================================= short test summary info =================================================================
FAILED tests/model/test_pyro.py::test_pyro_bayesian_regression_low_level - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1...
============================================================== 1 failed, 5 warnings in 2.00s ==============================================================

Versions:

scvi 1.1.2
lightning 2.1.4
torch 2.2.1+cu118

vitkl added the bug label Mar 18, 2024
vitkl commented Mar 18, 2024

This can probably be addressed with the following modification to the TrainRunner, or alternatively in the model.train() method:

class TrainRunner:

    def __call__(self):
        # other code .....
        
        from copy import copy
        dl = copy(self.data_splitter)
        dl.setup()
        dl = dl.train_dataloader()
        PyroModelGuideWarmup(dl).setup(
            self.trainer, self.training_plan, stage="fit"
        )

        self.trainer.fit(
            self.training_plan, self.data_splitter, ckpt_path=self.ckpt_path
        )
        # other code .....

At this stage, self.data_splitter.setup() has not been called yet, and PyTorch Lightning expects to call it itself later. So we need to copy self.data_splitter, call setup() on the copy, and build the dataloader needed for this callback from that copy.

Using model.train() is probably better, because the model would then have all of its parameters created before Lightning sees it.
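
As a minimal, self-contained sketch of what that buys (plain Pyro, not scvi-tools code; the toy model only loosely mirrors the test's BayesianRegressionPyroModel): a single guide call creates every variational parameter, after which the model/guide pair can be moved to the target device and handed to the trainer as a complete set.

import pyro
import pyro.distributions as dist
import torch
from pyro.infer.autoguide import AutoNormal
from pyro.nn import PyroModule, PyroSample


class TinyBayesianRegression(PyroModule):
    # Toy stand-in for the test's BayesianRegressionPyroModel.
    def __init__(self, in_features: int):
        super().__init__()
        self.linear = PyroModule[torch.nn.Linear](in_features, 1)
        self.linear.weight = PyroSample(
            dist.Normal(0.0, 1.0).expand([1, in_features]).to_event(2)
        )
        self.linear.bias = PyroSample(dist.Normal(0.0, 10.0).expand([1]).to_event(1))

    def forward(self, x, y=None):
        sigma = pyro.sample("sigma", dist.Exponential(x.new_ones(())))
        mean = self.linear(x).squeeze(-1)
        with pyro.plate("obs", x.shape[0]):
            return pyro.sample("y", dist.Normal(mean, sigma), obs=y)


pyro.clear_param_store()
model = TinyBayesianRegression(in_features=100)
guide = AutoNormal(model)

# One warm-up pass on a single batch: this is what running the guide once in
# model.train() would do. All variational parameters now exist, so nothing is
# left to be lazily created inside a Lightning callback on the wrong device.
x_batch = torch.randn(16, 100)
y_batch = torch.randn(16)
guide(x_batch, y_batch)
print(sorted(pyro.get_param_store().keys()))

In scvi-tools the warm-up batch would of course come from the training dataloader, as in the TrainRunner sketch above.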

vitkl commented Mar 20, 2024

It would be great to hear what you think, @martinkim0. I can add the proposed changes.

martinkim0 commented
Hey sorry, took a look at this and forgot to respond. I think it makes sense to add the fixes to train instead of the TrainRunner since this will be specific to Pyro models. Happy to take a PR if you'd like to take a stab at it!

vitkl commented Mar 25, 2024

Sounds good! Later this week I will make a PR for this issue, as well as for another issue with the second GuideWarmup callback (pyro doesn't track deterministic variables initialised after setup).

I think we need to get rid of both pyro GuideWarmup callbacks and just run the guide once in model.train(). This would break how people use them now, but IMO it is a better solution.
