
leaky_integrator_box not working with multiple GPUs #378

Open
gaurvigoyal opened this issue Sep 13, 2023 · 9 comments
gaurvigoyal commented Sep 13, 2023

Describe the bug
The function leaky_integrator_box does not seem to be set up for multi-GPU training in native PyTorch, at least in DataParallel mode.

To Reproduce
Steps to reproduce the behavior:
I imagine it can only be replicated on a system with multiple GPUs.

  1. Take any model set up in torch.
  2. Add the following lines to your code (where you set up the device); see the fuller sketch after these steps:

        if torch.cuda.device_count() > 1:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
            model = torch.nn.DataParallel(model)

  3. Run the model.
  4. See error: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!
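
For context, here is a minimal, untested sketch of the kind of setup I mean (the Linear/LIBoxCell layers and tensor shapes are just placeholders, not my actual model):

import torch
import norse.torch as norse

# Placeholder network: any Norse model with a stateful layer will do;
# LIBoxCell is the layer whose state ends up on the wrong device.
model = norse.SequentialState(
    torch.nn.Linear(32, 16),
    norse.LIBoxCell(),
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = torch.nn.DataParallel(model)
model = model.to(device)

# A (time, batch, features)-shaped input; DataParallel splits dim 0 by default.
x = torch.randn(40, 8, 32, device=device)
out, state = model(x)  # on a machine with 2+ GPUs, this is where I see the RuntimeError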

Expected behavior
The different samples should go forward and backward on their assigned GPUs.

Desktop (please complete the following information):

  • OS: Docker image nvidia/cuda:11.5.2-cudnn8-devel-ubuntu20.04

Additional context
Full error:
Original Traceback (most recent call last):
File "/usr/local/lib/python3.8/dist-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
output = module(*input, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/projects/edpr_hpe/movenet.pytorch/lib/models/movenet_stencil.py", line 348, in forward
x = self.backbone(x) # n,24,48,48
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/projects/edpr_hpe/movenet.pytorch/lib/models/movenet_stencil.py", line 198, in forward
f1, state = self.features1(timestep, state)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/projects/edpr_hpe/norse/norse/torch/module/sequential.py", line 108, in forward
input_tensor, s = module(input_tensor, state[index])
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/projects/edpr_hpe/norse/norse/torch/module/receptive_field.py", line 144, in forward
return self.neurons(x_repeated, state)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/projects/edpr_hpe/norse/norse/torch/module/snn.py", line 88, in forward
return self.activation(input_tensor, state, self.p, self.dt)
File "/projects/edpr_hpe/norse/norse/torch/functional/leaky_integrator_box.py", line 89, in li_box_feed_forward_step
dv = dt * p.tau_mem_inv * ((p.v_leak - state.v) + input_tensor)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!

Please let me know if any more information is needed.

Is there a more appropriate (but still relatively simple) way to adapt Norse code for multi-GPU training?

gaurvigoyal (Author):

@Jegp


gaurvigoyal commented Sep 13, 2023

Right, so DataParallel assumes that the first dimension is the batch dimension, but with Norse the first dimension is the timestep. This causes the samples to be scattered across the GPUs in an incorrect way. I have now tried DataParallel(dim=1), and that gives a very different error:

torch.jit._trace.TracingCheckError: Tracing failed sanity checks!
ERROR: Graphs differed across invocations!
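
To illustrate what I mean about the scattering (a rough sketch; the scatter step essentially chunks the input along the chosen dimension):

import torch

x = torch.randn(40, 8, 1, 192, 192)  # (time, batch, C, H, W), as Norse expects

# DataParallel's default scatter splits dim 0, i.e. the time axis here:
time_chunks = torch.chunk(x, 2, dim=0)   # two half-length sequences -> wrong
# With DataParallel(model, dim=1) it splits the batch axis instead:
batch_chunks = torch.chunk(x, 2, dim=1)  # two half-size batches -> what we want
print([c.shape for c in time_chunks], [c.shape for c in batch_chunks])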


Jegp commented Sep 13, 2023

Thanks for reporting this @gaurvigoyal! OK, so are you saying that the parallelization is not an issue? Because I can actually imagine it breaking down if the tensor is sent to multiple devices after initialization.

Regarding the tracing, is it correct that this only happens during backprop? One potential problem could be that the autodiff graph isn't properly "cleared" between timesteps. Could you share the code you use to optimize the model? Happy to take it offline as well.
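
For reference, this is roughly what I mean by "clearing" the graph between steps. It's only a sketch with made-up layer sizes and data, assuming the sequential state is the usual list of per-layer states (None for stateless layers):

import torch
import norse.torch as norse

# Hypothetical minimal model and data, just to show where the detach would go
model = norse.SequentialState(torch.nn.Linear(4, 4), norse.LIBoxCell())
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
data = [(torch.randn(2, 4), torch.randn(2, 4)) for _ in range(3)]

state = None
for x, y in data:
    out, state = model(x, state)
    loss = torch.nn.functional.mse_loss(out, y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # Detach every tensor in the state so the next step starts a fresh autodiff graph
    state = [
        type(s)(*(t.detach() for t in s)) if s is not None else None
        for s in state
    ]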

gaurvigoyal (Author):

Hey @Jegp, thanks for responding so quickly. Have you tried DataParallel training or any other multi-GPU setup with Norse yet? With dim=1, PyTorch now scatters the data along the batch dimension, so that problem seems to be solved for the data, but I guess it doesn't work for the model, which leads to the issue with the trace.

I am working on my movenet.pytorch repository in the spiking-data-loader branch (here).

The resulting error log contains a graph diff for that trace that is too long to paste in full, but here is the rest of the error:

Traceback (most recent call last):
  File "train.py", line 50, in <module>
    main(cfg)
  File "train.py", line 40, in main
    run_task = Task(cfg, model)
  File "/projects/edpr_hpe/movenet.pytorch/lib/task/task.py", line 59, in __init__
    self.tb.add_graph(self.model, torch.randn(40, 1, 1, 192, 192).to(self.device))
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/tensorboard/writer.py", line 736, in add_graph
    self._get_file_writer().add_graph(graph(model, input_to_model, verbose, use_strict_trace))
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/tensorboard/_pytorch_graph.py", line 289, in graph
    trace = torch.jit.trace(model, args, strict=use_strict_trace)
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_trace.py", line 741, in trace
    return trace_module(
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_trace.py", line 983, in trace_module
    _check_trace(
  File "/usr/local/lib/python3.8/dist-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_trace.py", line 526, in _check_trace
    raise TracingCheckError(*diag_info)
torch.jit._trace.TracingCheckError: Tracing failed sanity checks!
ERROR: Graphs differed across invocations!

My first guess is that maybe the data and the model need to be parallelized separately? I'd hoped PyTorch and Norse would handle this together internally.
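
For what it's worth, the traceback shows the failure comes from TensorBoard's add_graph call (which runs torch.jit.trace and then re-checks the traced graph), not from the training step itself. A hedged sketch of one way to keep that logging from being fatal (the call is the one from task.py in the traceback above; this is only a workaround, not a fix for the parallelization):

# In lib/task/task.py, around the add_graph call from the traceback:
# graph logging is optional, so a stateful model that traces differently
# across invocations can skip it instead of crashing.
try:
    self.tb.add_graph(self.model, torch.randn(40, 1, 1, 192, 192).to(self.device))
except Exception as e:
    print(f"Skipping TensorBoard graph logging: {e}")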


Jegp commented Sep 14, 2023

Happy to help! Did you see that I updated the receptive fields in a PR? I fixed some distributions and added some cleverness that should give you better performance in both runtime and accuracy.

I can't seem to reproduce your error, I'm afraid. I tried a small toy example pasted below, and it worked for me. Could I ask you to try it out and see if it works for you?

import torch
import norse.torch as norse

# Model on cuda:0; DataParallel told to scatter/gather along dim 1 (the batch dimension)
model = norse.SequentialState(norse.LIBoxCell(), torch.nn.Linear(1, 10)).to("cuda:0")
par = torch.nn.DataParallel(model, dim=1)
# Input lives on cuda:1 while the model lives on cuda:0
par(torch.empty(100, 100, 100, 1).to("cuda:1"))

Another, separate point is that I have had much better experience using PyTorch Lightning for the parallelization, because it takes care of mapping both the data and the model to the various devices. I pasted a small (pretty dumb) example below. The "parallelization magic" comes from using pl.Trainer(..., strategy=pl.strategies.DDPStrategy()) (Distributed Data Parallel), which is a one-device-per-process, multi-processing strategy that could maybe also work in your case. Would that be an option? Although I realize that may require a bit of rewriting of your code :P

import os
from torch import optim, nn, utils, Tensor
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import lightning.pytorch as pl

import norse.torch as norse

# define any number of nn.Modules (or use your current ones)
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = norse.SequentialState(norse.LIBoxCell(), nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))


# define the LightningModule
class LitAutoEncoder(pl.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # it is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        s = None
        for i in range(2):
            x_hat, s = self.decoder(z, s)
        loss = nn.functional.mse_loss(x_hat, x)
        # Logging to TensorBoard (if installed) by default
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


# init the autoencoder
autoencoder = LitAutoEncoder(encoder, decoder)

dataset = MNIST(os.getcwd(), download=True, transform=ToTensor())
train_loader = utils.data.DataLoader(dataset)

# train the model (hint: here are some helpful Trainer arguments for rapid idea iteration)
trainer = pl.Trainer(limit_train_batches=100, max_epochs=1, strategy=pl.strategies.DDPStrategy())
trainer.fit(model=autoencoder, train_dataloaders=train_loader)
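
If you have both GPUs visible, I believe you would also tell the Trainer explicitly how many devices to use, something along these lines (treat the exact arguments as an assumption on my side):

# Same Trainer, but explicitly requesting two GPUs with the DDP strategy
trainer = pl.Trainer(
    limit_train_batches=100,
    max_epochs=1,
    accelerator="gpu",
    devices=2,  # one process per GPU; Lightning handles data/model placement
    strategy=pl.strategies.DDPStrategy(),
)
trainer.fit(model=autoencoder, train_dataloaders=train_loader)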

gaurvigoyal reopened this Sep 14, 2023
gaurvigoyal (Author):

Right. I don't remember closing this, it must've been a mistake (I clicked "Close with comment", I guess). Thanks! I hadn't seen the PR! Clever things indeed :) I'll incorporate that.
And I'll try these and get back to you! Thanks!


Jegp commented Jan 8, 2024

Any news @gaurvigoyal? Anything I can do?

gaurvigoyal (Author):

Hey @Jegp, I rewrote the code in Lightning, but distributing it over multiple GPUs still ran into errors. At some point other projects got higher priority and this went on the back burner; right now I don't have as much time to dedicate to it as it needs. Did you already publish the paper on these integrators?


Jegp commented May 2, 2024

> Hey @Jegp, I rewrote the code in Lightning, but distributing it over multiple GPUs still ran into errors. At some point other projects got higher priority and this went on the back burner; right now I don't have as much time to dedicate to it as it needs. Did you already publish the paper on these integrators?

Hey @gaurvigoyal, we just put up a preprint on exactly this: https://arxiv.org/abs/2405.00318
I know it's like a couple of months late, but I'm happy to pick up the discussion :-)
