DeepSensor with Pytorch Lightning #43

Open
nilsleh opened this issue Aug 31, 2023 · 12 comments

Comments

@nilsleh

nilsleh commented Aug 31, 2023

Hi, thank you for the fantastic work, this is very exciting. I am looking to apply the DeepSensor library to different problems, and I aim to set up my experiments with PyTorch Lightning, because it reduces boilerplate code and offers many benefits for code organization, experiment logging, GPU training, etc. I am aware that DeepSensor aims to support both TensorFlow and PyTorch, so Lightning might not be a priority, but I think it would be a great addition to have a "template" for using Lightning with DeepSensor, especially for research code. I was hoping to lay out my current idea for such a template, but I have realized that I need some pointers, for which I'd be very grateful.

For Lightning, you roughly need two puzzle pieces: a LightningModule that defines your training and validation steps (specifically, how you compute the loss), and a LightningDataModule that gives you a PyTorch DataLoader for each stage.

Focusing on the latter first, I have gone through the task_loader_tour to get a better understanding of what the TaskLoader can do. However, for training you need to generate a set of tasks first, which is done manually by calling the TaskLoader repeatedly, and to create a training batch of tasks one uses the concat_tasks function, which seems somewhat analogous to a DataLoader collate function. But there are also additional data processing steps inside the ConvNP module, specifically in loss_fn(), before the neuralprocesses library is used to both conduct the forward pass and compute the loss.

Before moving to Lightning, I was hoping to get a "standard" PyTorch training loop going. To this end, I was wondering whether one could leverage a PyTorch IterableDataset with something like the following to yield tasks, and subsequently wrap it in a DataLoader:

from typing import Any, Iterator

import pandas as pd
import deepsensor.torch  # selects DeepSensor's PyTorch backend
from torch.utils.data import IterableDataset
from deepsensor.model.convnp import remove_nans_from_task_Y_t_if_present
from deepsensor.model.nps import convert_task_to_nps_args
from deepsensor.model.convnp import ConvNP
from deepsensor.data.loader import TaskLoader

class TaskStreamDataset(IterableDataset):
    """Pytorch Dataset Class for Deep Sensor Task Streams."""

    def __init__(self, task_loader: TaskLoader, context_sampling=20, target_sampling=50, min_time: str = None, max_time: str = None) -> None:
        """Initialize a new instance of Task Stream Dataset.
        
        Args:
            task_loader: Defined task loader
            context_sampling: context sampling strategy for the task loader
            target_sampling: target sampling strategy for the task loader
            min_time: minimum time frame to sample
            max_time: maximum time frame to sample
        """
        super().__init__()

        self.task_loader = task_loader
        self.min_time = min_time
        self.max_time = max_time

        self.context_sampling = context_sampling
        self.target_sampling = target_sampling

    def __iter__(self) -> Iterator[dict[str, Any]]:
        """Define how data samples are generated."""

        for date in pd.date_range(self.min_time, self.max_time):
            task = self.task_loader(date, context_sampling=self.context_sampling, target_sampling=self.target_sampling)
            task, nans_present = remove_nans_from_task_Y_t_if_present(task)

            task = ConvNP.check_task(task)

            context_data, xt, yt, model_kwargs = convert_task_to_nps_args(task)
            
            yield {"context_data": context_data, "xt": xt, "yt": yt}

This of course assumes that the DataLoader can form a batch; if not, one would have to define a collate function, which I assume would look something like concat_tasks. Additionally, for num_workers>0 one also has to define how the data is distributed across workers within the IterableDataset.
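For num_workers>0, each DataLoader worker runs its own copy of `__iter__`, so without sharding every worker would yield every date and the epoch would contain duplicates. A common pattern is to call `torch.utils.data.get_worker_info()` inside `__iter__` (it returns `None` in the main process, otherwise an object with `id` and `num_workers`) and keep only every `num_workers`-th date. The minimal sketch below isolates that sharding arithmetic in plain Python so it can be checked without torch; `daily_dates` and `shard_dates` are hypothetical helpers, not DeepSensor functions.

```python
from datetime import date, timedelta
from typing import List, Optional


def daily_dates(start: date, end: date) -> List[date]:
    """Inclusive list of daily dates, similar in spirit to pd.date_range(start, end)."""
    n_days = (end - start).days
    return [start + timedelta(days=i) for i in range(n_days + 1)]


def shard_dates(dates: List[date], worker_id: Optional[int], num_workers: int) -> List[date]:
    """Return the dates a given DataLoader worker should generate tasks for.

    worker_id is None when iterating in the main process (num_workers=0),
    in which case all dates are returned.
    """
    if worker_id is None:
        return list(dates)
    # Round-robin split: worker k takes dates k, k + num_workers, k + 2 * num_workers, ...
    return list(dates[worker_id::num_workers])


# Hypothetical wiring inside TaskStreamDataset.__iter__:
#   info = torch.utils.data.get_worker_info()
#   worker_id = None if info is None else info.id
#   num_workers = 1 if info is None else info.num_workers
#   for d in shard_dates(all_dates, worker_id, num_workers):
#       ...  # call self.task_loader(d, ...) as before
```

A round-robin split keeps each worker's share spread over the whole period; contiguous chunks would also work, as long as the shards are disjoint and together cover every date.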

The idea is then to use a "standard" pytorch training scheme as follows:

from torch.utils.data import DataLoader
import neuralprocesses.torch as nps
import torch

data_loader = DataLoader(TaskStreamDataset(task_loader, context_sampling=20, target_sampling=50, min_time=min_time, max_time=max_time), batch_size = 4)

model = ConvNP(target_processor, task_loader)

optimizer = torch.optim.Adam(model.model.parameters(), lr=0.001)

for epoch in range(1):
    for batched_task in data_loader:
        optimizer.zero_grad()
        
        # compute loss
        logpdfs = nps.loglik(
            model,
            batched_task["context_data"],
            batched_task["xt"],
            batched_task["yt"],
        )

        train_loss = -torch.mean(logpdfs)

        train_loss.backward()
        optimizer.step()

However, here I am running into the following error, and I am not sure how to get around it:

File "/lib/python3.10/site-packages/plum/function.py", line 390, in __call__
    method, return_type = self.resolve_method(args, types)
  File "lib/python3.10/site-packages/plum/function.py", line 332, in resolve_method
    raise e
plum.resolver.NotFoundLookupError: For function `loglik`, `(<deepsensor.model.convnp.ConvNP object at 0x7f77898ce1d0>,

If PyTorch training could be conducted in that way, one could then move to Lightning by defining a DataModule. It could take more init arguments to support different tasks etc., but would basically look something like this:

import deepsensor.torch
from deepsensor.data.loader import TaskLoader
from lightning import LightningDataModule
from torch.utils.data import DataLoader

class DeepSensorDataModule(LightningDataModule):
    """Lightning Data Module to serve data to train and evaluate Deep Sensor models."""

    def __init__(self, task_loader: TaskLoader):
        """Initialize a new instance of the Deep Sensor Lightning Data Module."""

        super().__init__()

        self.task_loader = task_loader

    def train_dataloader(self) -> DataLoader:
        """Define Training Data Loading."""
        min_time, max_time = "2016-01-01", "2016-02-01"
        return DataLoader(dataset=TaskStreamDataset(self.task_loader, min_time=min_time, max_time=max_time), batch_size=4)

    def val_dataloader(self) -> DataLoader:
        """Define Validation Data Loading."""
        min_time, max_time = "2016-02-01", "2016-03-01"
        return DataLoader(dataset=TaskStreamDataset(self.task_loader, min_time=min_time, max_time=max_time), batch_size=4)

    def test_dataloader(self) -> DataLoader:
        """Define Test Data Loading."""
        min_time, max_time = "2016-03-01", "2016-04-01"
        return DataLoader(dataset=TaskStreamDataset(self.task_loader, min_time=min_time, max_time=max_time), batch_size=4)

And a LightningModule, here just shown in a very basic form:

import torch
from torch import Tensor
import deepsensor.torch
from lightning import LightningModule
import neuralprocesses.torch as nps

from deepsensor.model.convnp import ConvNP

class DeepSensorModule(LightningModule):
    """Lightning Module to train Deep Sensor models."""

    def __init__(self, model: ConvNP):
        """Initialize a new instance of the Deep Sensor Module.
        
        Args:
            model: model to train
        """
        super().__init__()

        self.model = model

    def training_step(self, batched_task: dict, batch_idx: int) -> Tensor:
        """Define training step."""
        logpdfs = nps.loglik(
            self.model,
            batched_task["context_data"],
            batched_task["xt"],
            batched_task["yt"],
        )

        train_loss = -torch.mean(logpdfs)

        # logging here

        return train_loss

    def configure_optimizers(self):
        """Configure optimizers."""
        return torch.optim.Adam(self.model.model.parameters(), lr=0.001)

With these two pieces in place, one can leverage the Trainer, with all its flexibility, to conduct training and evaluation:

from lightning import Trainer

task_loader = define_task_loader()  # placeholder for the user's own TaskLoader setup
model = ConvNP(task_loader)

dm = DeepSensorDataModule(task_loader)
module = DeepSensorModule(model)
trainer = Trainer(max_epochs=10, devices=[0])
trainer.fit(module, dm)

Apologies for the long post, and also if all of this is something you have already tried or are working on. In any case, I would be grateful for any other pointers or thoughts on this idea. If you find it interesting, I would be happy to provide more details or implement this more formally in a PR, in whatever form you see fit.

@tom-andersson
Collaborator

Hi @nilsleh, thank you for opening this issue and for all the details! Firstly, regarding the plum error, can you post the rest of the stack trace? The error means that multiple dispatch failed to find a matching nps.loglik method for the arguments and types provided.

Secondly, from a bigger-picture perspective, further thought is needed on whether integrating with PyTorch Lightning is within the scope of DeepSensor. But at the very least, there should be some documentation on how to get Lightning working. Once you've got your basic implementation working, let's discuss the pros and cons.

Note, the code you've provided is specific to the ConvNP class, so the docstring """Lightning Module to train Deep Sensor models.""" isn't quite correct. In general (and if possible) we'd like functionality in DeepSensor to be model agnostic. This is the case in the simple train_epoch method provided, since it just calls the model.loss_fn method (which is a required method in the model interface, although not yet enforced).
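One way the "required but not yet enforced" loss_fn contract could be expressed in Python is a `typing.Protocol`. The sketch below is hypothetical (`SupportsLossFn`, `model_agnostic_step`, `ToyModel` and `NoLossModel` are illustrative names, not part of DeepSensor): it shows a step function that depends only on `loss_fn`, so a LightningModule built around it would not be tied to ConvNP.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class SupportsLossFn(Protocol):
    """Hypothetical interface: any model exposing loss_fn(task) -> scalar loss."""

    def loss_fn(self, task: Any) -> Any:
        ...


def model_agnostic_step(model: SupportsLossFn, task: Any) -> Any:
    """Only relies on model.loss_fn, not on ConvNP internals such as model.model."""
    # runtime_checkable isinstance checks that the method exists, not its signature.
    if not isinstance(model, SupportsLossFn):
        raise TypeError(f"{type(model).__name__} does not implement loss_fn")
    return model.loss_fn(task)


class ToyModel:
    """Stand-in model for illustration; its 'loss' is just the task length."""

    def loss_fn(self, task: Any) -> float:
        return float(len(task))


class NoLossModel:
    """Stand-in model that does not satisfy the interface."""
```

A Protocol keeps the contract structural: existing model classes would not need a shared base class to qualify, only a `loss_fn` method.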

Regarding wrapping the TaskLoader in a more standard generator-style class, there is an issue on this: #24. Similar to the above, there is an open question about whether we want to hard-code functionality like this into DeepSensor and make some simplifying assumptions, or if we want to encourage users to implement things themselves in the custom way they want.

On the above, if the added functionality is useful and doesn't sacrifice too much flexibility, then that's a good argument to add it.

@nilsleh
Author

nilsleh commented Sep 4, 2023

Hi @tom-andersson, thanks for your reply. I have tried to create a minimal reproducible example as a Google Colab notebook here that downloads some data from meteonet, just to try and get the "mechanics" to work. Maybe that can enhance the discussion. At the moment, I have just implemented the PyTorch training loop that shows the mentioned error. Once that is resolved, I would add the Lightning training loop.

@tom-andersson
Collaborator

Hi @nilsleh, thank you for the colab notebook, that's very helpful. I've had a look at the 'standard PyTorch training loop' part and found the cause of the plum error. It's because you are passing the deepsensor ConvNP model object to nps.loglik, rather than the raw PyTorch model constructed by the neuralprocesses package. You can access the raw PyTorch model with the ConvNP.model attribute (so you need model.model there).
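Plum's multiple dispatch is more general than the standard library's `functools.singledispatch`, but the failure mode is analogous: methods are selected by argument type, so a wrapper object that merely holds a matching model is not itself a match. The toy sketch below uses `singledispatch` as a stand-in for plum; `RawModel`, `Wrapper` and this `loglik` are hypothetical and only mimic the roles of the neuralprocesses model, the ConvNP wrapper and `nps.loglik`.

```python
from functools import singledispatch


class RawModel:
    """Stands in for the raw PyTorch model built by neuralprocesses."""


class Wrapper:
    """Stands in for a DeepSensor-style wrapper that stores the raw model on .model."""

    def __init__(self) -> None:
        self.model = RawModel()


@singledispatch
def loglik(model, *args):
    # Fallback used when no registered method matches type(model).
    raise NotImplementedError(f"no loglik method for {type(model).__name__}")


@loglik.register
def _(model: RawModel, *args):
    # Registered from the type annotation of the first parameter.
    return "dispatched to RawModel method"
```

Calling `loglik(Wrapper())` hits the fallback and raises, whereas `loglik(Wrapper().model)` resolves to the registered method, which mirrors passing `model.model` instead of `model`.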

Once you fix this, you'll bump into two other issues caused by the torch.utils.data.DataLoader:

  1. The tensors returned by your DataLoader contain two batch dimensions. This is because ConvNP.check_task prepends a size-1 batch dimension (as well as converting to PyTorch tensors and masking any NaNs), so the automatic batching of the PyTorch DataLoader prepends a second batch dim. You can turn off the automatic batching with batch_size=None. Note, if you want to batch tasks, you should handle this within TaskStreamDataset and use deepsensor.model.convnp.concat_tasks, which will deal with padding when there's a variable number of observations between tasks.
  2. The DataLoader converts the context_data entry from a list of tuples to a list of lists, which leads to a very esoteric neuralprocesses error. You instead have to convert it back to a list of tuples with batched_task["context_data"] = [tuple(x) for x in batched_task["context_data"]]
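The second fix can be wrapped in a small helper applied to each sample right after it comes out of the DataLoader. In recent PyTorch versions, the conversion applied when `batch_size=None` turns plain tuples into lists, which is consistent with the behaviour described above. The sketch below is a hypothetical pure-Python helper, `restore_context_tuples`, that applies exactly the `[tuple(x) for x in ...]` fix without mutating the input dict.

```python
from typing import Any, Dict


def restore_context_tuples(batch: Dict[str, Any]) -> Dict[str, Any]:
    """Turn batch["context_data"] back into a list of (xc, yc) tuples.

    Returns a shallow copy so the original sample dict is left unchanged.
    """
    fixed = dict(batch)
    fixed["context_data"] = [tuple(pair) for pair in batch["context_data"]]
    return fixed
```

In the training loop, one would call it on each sample from `DataLoader(..., batch_size=None)` before passing `context_data`, `xt` and `yt` to `nps.loglik(model.model, ...)`.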

After all three of the above changes, the code runs.

P.S. You should use the same DataProcessor object for all your data, rather than a separate one for each variable.

@nilsleh
Author

nilsleh commented Sep 6, 2023

Thanks a lot for your feedback and for making it work; I updated the Colab notebook accordingly. Looking at the batching of tasks and at model.loss_fn, I refactored the code to do the batching via a collate_fn passed to the PyTorch DataLoader, which simply calls the DeepSensor concat_tasks function. This also simplifies the IterableDataset.
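The padding that concat_tasks is described as handling (a variable number of observations between tasks) can be illustrated in isolation. The sketch below is not DeepSensor's implementation: it is a hypothetical pure-Python `pad_and_stack` that pads per-task observation lists to a common length with a fill value and returns a mask marking the real entries, which is the general idea behind batching tasks of different sizes.

```python
import math
from typing import List, Sequence, Tuple


def pad_and_stack(
    per_task_values: Sequence[Sequence[float]], fill: float = math.nan
) -> Tuple[List[List[float]], List[List[bool]]]:
    """Pad variable-length observation lists into a rectangular batch.

    Returns (padded, mask) where mask[i][j] is True for real observations
    and False for padding appended to task i.
    """
    max_len = max((len(v) for v in per_task_values), default=0)
    padded: List[List[float]] = []
    mask: List[List[bool]] = []
    for values in per_task_values:
        n_pad = max_len - len(values)
        padded.append(list(values) + [fill] * n_pad)
        mask.append([True] * len(values) + [False] * n_pad)
    return padded, mask
```

A DataLoader `collate_fn` receives the list of samples for one batch, so a collate built on concat_tasks would take that list of tasks and return one concatenated task. In this sketch, the mask is what a loss would use to ignore the padded points.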

Additionally, I have also implemented the logic for what I think is needed to have a DataLoader with num_workers>0 since that is often the bottleneck in training pipelines. It works but I want to check that more rigorously since there are some caveats with IterableDataset.

Another question I have run into is how GPU training is supposed to work "out of the box" with a PyTorch or Lightning training procedure. I found the set_gpu_default_device function; however, it hard-codes device 0, so it is not as flexible for multi-GPU training. I have some idea of where putting everything on the GPU could go in the code for the PyTorch or Lightning module, but beforehand I wanted to ask how you envision GPU training from the DeepSensor perspective.

@tom-andersson
Collaborator

Thanks for the updates @nilsleh,

I have some idea where putting everything on gpu could go in the code for pytorch or lightning module, but beforehand wanted to ask how you thought about doing training on GPU should look like from the DeepSensor perspective?

I'm not sure I understand your question. IIUC, we just need to generalise set_gpu_default_device to support multiple GPUs. For example, we could have this interface to use GPUs 0, 1, and 2: set_default_device(device="gpu", id=[0, 1, 2]).

Under the hood, set_gpu_default_device uses the backends library to set the device using PyTorch or TF depending on what version of DeepSensor the user has imported. However, on a brief skim of the backends code that handles devices, it appears to only relate to single devices. Maybe @wesselb can confirm?

@nilsleh
Author

nilsleh commented Sep 6, 2023

In Lightning, normally the only thing one has to do is set the devices flag in the Trainer object to the desired GPUs, and Lightning takes care of the rest. For example, the data is put on the desired device before it arrives in the training_step or validation_step, and at that point the model is also on the same device. However, if the data processing steps in model.loss_fn() and the subsequent forward pass are taken care of by the backends library anyway, then it might even suffice to just call set_default_device somewhere at the start of Lightning training, and Lightning will not complain. Not sure if I explained that well enough, but I will continue to play around with it.

Edit: At the moment, just calling set_gpu_default_device() for single-GPU training with Lightning leads to RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method. I think Lightning and the backend move things to the GPU at different moments, so some work within the Lightning module will be needed to make this work even for single-GPU training.

@wesselb
Collaborator

wesselb commented Sep 7, 2023

@tom-andersson backends indeed only handles single GPUs, and this is intentional. When using multi-GPU training, you first need to decide how to spread things across GPUs and handle communication between them. The most common choice here is DDP (distributed data parallel), where different GPUs compute the forward and backward pass for different elements of the batch. A more complex alternative strategy is model parallelism, where a single model’s forward pass spans multiple GPUs. @nilsleh most likely intends to do DDP, so let’s focus on that. (Please correct me if I’m wrong!)

In PyTorch Lightning’s DDP strategy, every GPU is managed by a separate Python process. This Python process holds the whole model in memory and computes the forward and backward pass for particular batch elements. After computing the forward and backward pass, the process communicates with all other processes to update and synchronise the model parameters. Importantly, in this Python process, the device is always set to a single GPU, which is the GPU associated with the Python process.
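The data-parallel idea can be checked with a toy example. In the sketch below (plain Python, hypothetical names, a one-parameter least-squares model standing in for the neural process), each "process" computes the gradient of the mean loss on its own shard of the batch, and averaging those gradients, which is what DDP's all-reduce step does, reproduces the full-batch gradient when the shards have equal size.

```python
from typing import List, Sequence


def grad_mse(w: float, xs: Sequence[float], ys: Sequence[float]) -> float:
    """Gradient of mean((w * x - y) ** 2) with respect to w."""
    n = len(xs)
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / n


def shard(seq: Sequence[float], rank: int, world_size: int) -> List[float]:
    """Interleaved split: rank r gets elements r, r + world_size, ..."""
    return list(seq[rank::world_size])


def ddp_style_grad(w: float, xs: Sequence[float], ys: Sequence[float], world_size: int) -> float:
    """Each rank computes a local gradient; the all-reduce step averages them."""
    local_grads = [
        grad_mse(w, shard(xs, r, world_size), shard(ys, r, world_size))
        for r in range(world_size)
    ]
    return sum(local_grads) / world_size
```

The equal-shard-size condition is why PyTorch's `DistributedSampler` by default pads the index list so every rank receives the same number of samples; with unequal shards, the plain average of local mean gradients would weight samples unevenly.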

But this is all very conceptual. In reality, all of this is abstracted away by PyTorch Lightning. So how does it work then?

The answer is simple: the models in neuralprocesses are standard PyTorch models, so you should not try to set the device via backends, but let PTL handle the DDP strategy and associated device management. (Note that it is possible and not too hard to implement DDP without PTL. If you’d like that, a different approach is needed.)

Now, this should work, but doesn’t quite seem to, because of the forking problem. I think it’s worthwhile diving into that. My guess is that somewhere the device is set to CUDA where that shouldn’t be done: you’ll need to completely hand over control to PTL.

@nilsleh
Author

nilsleh commented Sep 14, 2023

@wesselb @tom-andersson From my perspective, my main aim was to find a PyTorch Lightning setup that could facilitate the training of models via DeepSensor. Single-GPU support, just by setting the Trainer flag, would already be awesome. But keeping in mind how simple it is to change to DDP training in Lightning, I thought it would be worth considering when creating a template, to keep the code changes required of a potential user minimal; DDP could also just be something to consider at a later point.

@wesselb
Collaborator

wesselb commented Sep 20, 2023

@nilsleh My impression is that single-GPU and multi-GPU training via PTL DDP should most certainly be possible! I think it's a matter of getting to the bottom of the CUDA multiprocessing error, where perhaps the right approach is to defer all device management 100% to PTL.

@nilsleh
Author

nilsleh commented Sep 22, 2023

@wesselb @tom-andersson I have updated the colab notebook with my progress of GPU training so far and can try to summarize what I found so far. Maybe that is helpful.

Writing a PyTorch training loop (no DeepSensor training utilities), moving tensors to the device manually with a transfer_batch_to_device() function, and putting the ConvNP on the device as well, works. Attempts to use set_gpu_default_device somewhere in this pipeline failed.

With the Lightning approach, I observed that when setting the Trainer to GPU training, it reports that CUDA is being used; however, the training actually runs on the CPU. I am guessing that the deepsensor.Task is not transferred to the device properly.

When handling the data device movement in the DataModule by overriding transfer_batch_to_device(), it seems to do the job; however, the model is not put on the device correctly. More specifically, the decoder part remains on the CPU.
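A custom `transfer_batch_to_device(self, batch, device, dataloader_idx)` override mostly comes down to applying a "move" function to every array in a nested structure while keeping the container types intact (the tuples in `context_data` matter here, as noted earlier in the thread). The sketch below is a hypothetical, torch-free helper `map_leaves` showing that recursion. In a Lightning hook, one might pass something like `lambda t: t.to(device)` as `fn` with `is_leaf` checking for tensors; that wiring is an assumption, not code from the notebook. If the leaves are still NumPy arrays at this point, they would need converting to tensors first, and `fn` is where that would go.

```python
from typing import Any, Callable


def map_leaves(obj: Any, fn: Callable[[Any], Any], is_leaf: Callable[[Any], bool]) -> Any:
    """Recursively apply fn to leaves, preserving dict/list/tuple container types.

    Note: dict subclasses come back as plain dicts and namedtuples as plain
    tuples; a real implementation may need to rebuild those types explicitly.
    """
    if is_leaf(obj):
        return fn(obj)
    if isinstance(obj, dict):
        return {k: map_leaves(v, fn, is_leaf) for k, v in obj.items()}
    if isinstance(obj, tuple):
        return tuple(map_leaves(v, fn, is_leaf) for v in obj)
    if isinstance(obj, list):
        return [map_leaves(v, fn, is_leaf) for v in obj]
    return obj  # strings, None and other non-leaf values pass through unchanged
```

Keeping tuples as tuples avoids reintroducing the list-of-lists problem that the DataLoader caused for `context_data`.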

From my understanding, moving the model weights to the device happens in the configure_optimizers stage of the Lightning pipeline, but it seems that this does not handle it completely, given the object structure of the ConvNP.

@tom-andersson
Collaborator

Hey @nilsleh, sorry for the delay. Have you made any progress on this? You're more familiar with PTL than me, so I can't be of much help here. It's hard to know specifically what's going wrong. Under the hood of the ConvNP we are using @wesselb's neuralprocesses package to construct the model. The model is itself just pure PyTorch, albeit a quite custom model, so it seems possible that something is interfering with PTL transferring the model to the device.

@wesselb suggested deferring all device management to PTL. In that case, you may want to avoid using deepsensor.train.set_gpu_default_device.

Regarding the Task potentially not being passed to the device, my only thought is that by default the Task data will exist as NumPy arrays until it is passed to the ConvNP, at which point the model calls task.convert_to_tensor. For deepsensor.torch, this just calls torch.tensor on all the NumPy data. Again, not sure how relevant that is, but if PTL is sensitive to when tensor conversions occur, then you may want to call task.convert_to_tensor() before passing the task to the ConvNP.

@nilsleh
Author

nilsleh commented Oct 17, 2023

Hi Tom, no worries. Yes, I have done some more work on this, however, won't have time to properly update it until after the vacation.


3 participants