Run metrics calculation in a background process #3244

Closed · H4dr1en opened this issue May 2, 2024 · 8 comments

@H4dr1en (Contributor) commented May 2, 2024

❓ Questions/Help/Support

I am training neural networks with pytorch-ignite for a computer vision task. Currently, calculating metrics during validation takes a significant part of the validation time, and I want to optimise this so that most of the time is spent on GPU inference.

What I have in mind:

  1. Move the computation of the metrics from the validation step to a background process/queue
  2. Schedule the computation of the metrics after each iteration by sending the results of the inference to the metrics queue

This would work out of the box, but I also use several handlers at the end of each validation epoch (logging/checkpoint/lr scheduler/early stopping) that depend on the metrics, so I need a synchronization point at the end of the epoch to wait for the metrics computation to finish before triggering the various handlers.

As an alternative, I could rework each metric to make it faster, potentially using the GPU, but this is also a long process as I would need to do it for every current and future metric. I'd rather have them computed in the background, so that I don't have to care about the efficiency of the implementation.

My questions are:

  1. What would be the cleanest way to implement this using pytorch-ignite?
  2. If pytorch-ignite doesn't provide a nice interface for this use case, would it make sense to extend the library to support it?
  3. Is there a blind spot that I am missing - is there a better way to deal with this situation?
@vfdev-5 (Collaborator) commented May 2, 2024

@H4dr1en thanks for asking this interesting question!

Point 1, "Move the computation of the metrics from the validation step to a background process/queue", looks interesting, but here I mostly wonder about compute resource usage. I'm thinking about the following:

  • training has started and the model has done one epoch of updates
  • next we clone the model for validation and schedule a background validation of this copy
  • in parallel we continue training with the original model.
    The question here is: won't we OOM during training or the background validation, since both would use the same GPU(s)?

Point 2, "Schedule the computation of the metrics after each iteration by sending the results of the inference to the metrics queue", I read as computing the metrics on the training dataset with a constantly updated model. I'm not sure this is a good idea: the final value won't correspond to the latest model...

As for scheduling a new process for metrics computation, let me think a bit about what an implementation with ignite would look like, and then we can follow up on your point "If pytorch-ignite doesn't provide a nice interface for this use case, would it make sense to extend the library to support it?".

Currently, calculating metrics during validation takes a significant part of the validation time, and I want to optimise this so that most of the time is spent on GPU inference.

Thinking about this statement, wouldn't it be possible to run validation at larger intervals, so that more of the time goes to training? For example:

- @trainer.on(Events.EPOCH_COMPLETED)
+ @trainer.on(Events.EPOCH_COMPLETED(every=100))
def run_validation():
    ...
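
For reference, a minimal runnable sketch of this scheduling pattern (the trainer/evaluator, step functions, and data here are placeholders, not code from an actual project):

from ignite.engine import Engine, Events

def train_step(engine, batch):
    # placeholder training step
    return batch

def eval_step(engine, batch):
    # placeholder validation step
    return batch

trainer = Engine(train_step)
evaluator = Engine(eval_step)
train_data = range(100)
val_data = range(10)

# run validation only every 100 training epochs instead of after every epoch
@trainer.on(Events.EPOCH_COMPLETED(every=100))
def run_validation():
    evaluator.run(val_data)

trainer.run(train_data, max_epochs=500)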

@H4dr1en (Contributor, Author) commented May 2, 2024

Hi @vfdev-5,

Thanks for the fast answer! I think there is a small confusion: I don't want to move the whole validation logic to a background process, only the computation of the metrics. Both the training and validation steps would still run in the main process. The only part I want to defer to a background process, in order to unblock training/validation, is the computation of the metrics:

@validator.on(Events.ITERATION_COMPLETED)
def compute_iteration_metrics(validation_engine):
    # Current state: long running, CPU bound
    validation_engine.state.output = compute_metrics(validation_engine.state.output)

    # My idea: send to the background to unblock the rest of the program
    # But I don't know how it would play with the updating of metrics (e.g. RunningAverage)
    metrics_computation_queue.put(validation_engine.state.output)

Then at the end of each validation epoch, I would wait and collect the metrics so that handlers depending on them can run:

@validator.on(Events.EPOCH_COMPLETED)  # Must be the first handler on this event because others might depend on state.metrics
def aggregate_iteration_metrics(validation_engine):

    # Here I would pull from the results queue
    metric_results = metrics_results_queue.get()

    # And somehow integrate them / trigger metrics like RunningAverage
    RunningAverage.step(metric_results)   # I have no idea how to do this at this point

Thinking about this statement, wouldn't it be possible to run validation at larger intervals, so that more of the time goes to training? For example

This would certainly help, but at the cost of sparser validation curves and the need for different LRScheduler and EarlyStopping settings.

@vfdev-5 (Collaborator) commented May 2, 2024

So, to confirm: ideally, validation_engine.state.output = compute_metrics(validation_engine.state.output) should run in another process, and at metric_results = metrics_results_queue.get() we join the process and get all the results?
Each iteration would submit a new task, and the join call would wait until all tasks are done.
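
Something like this compact sketch of the submit/collect pattern, shown with concurrent.futures only for illustration (the function and handler names are placeholders, and the ignite-based prototype further down uses torch.multiprocessing instead):

from concurrent.futures import ProcessPoolExecutor

def compute_metrics(output):
    # stand-in for the CPU-bound per-iteration metric computation
    return sum(output)

executor = ProcessPoolExecutor(max_workers=2)
pending = []

def on_iteration_completed(output):
    # each iteration submits a new task without blocking the main process
    pending.append(executor.submit(compute_metrics, output))

def on_epoch_completed():
    # "join": wait for every submitted task and collect the results
    results = [future.result() for future in pending]
    pending.clear()
    return results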

@H4dr1en (Contributor, Author) commented May 2, 2024

Yes exactly 👍

@vfdev-5 (Collaborator) commented May 3, 2024

@H4dr1en here is a prototype of running handlers in a process pool:

import time

import torch
import torch.multiprocessing as mp

from ignite.engine import Engine, Events


def long_running_computation(data):
    # CPU-bound stand-in for an expensive metric computation
    m = data["c"]
    v = m.sum()
    for _ in range(10000):
        v = v - m.mean()
        v = m.log_softmax(dim=1).sum() + v

    return v + data["a"] + data["b"]


def run():

    torch.manual_seed(0)
    eval_data = range(10)

    for with_mp in [True, False]:

        def eval_step(engine, batch):
            # forward pass latency
            time.sleep(0.5)
            print(f"{engine.state.epoch} / {engine.state.max_epochs} | {engine.state.iteration} - batch: {batch}", flush=True)
            return {
                "a": torch.tensor(engine.state.iteration, dtype=torch.float32),
                "b": torch.rand(()).item(),
                "c": torch.rand(128, 5000),
            }

        validator = Engine(eval_step)

        # pick a reasonable value of workers:
        if with_mp:
            pool = mp.Pool(processes=2)

        validator.state.storage = []

        @validator.on(Events.ITERATION_COMPLETED)
        def do_long_running_computation():
            if with_mp:
                # submit the iteration output to the pool without blocking the engine
                validator.state.storage.append(
                    pool.apply_async(long_running_computation, (validator.state.output,))
                )
            else:
                # baseline: compute synchronously in the main process
                validator.state.storage.append(
                    long_running_computation(validator.state.output)
                )

        @validator.on(Events.EPOCH_COMPLETED)
        def gather_results():
            if with_mp:
                # wait for all submitted tasks and replace the async handles with their results
                validator.state.storage = [
                    r.get() for r in validator.state.storage
                ]
            validator.state.metrics["abc"] = sum(validator.state.storage)


        start = time.time()
        validator.run(eval_data)
        elapsed = time.time() - start
        print("Elapsed time:", elapsed)
        if with_mp:
            pool.close()
            pool.join()
            
            
if __name__ == "__main__":
    run()

Output (the first run uses the process pool, the second computes the metrics inline):

python -u script.py

1 / 1 | 1 - batch: 0
1 / 1 | 2 - batch: 1
1 / 1 | 3 - batch: 2
1 / 1 | 4 - batch: 3
1 / 1 | 5 - batch: 4
1 / 1 | 6 - batch: 5
1 / 1 | 7 - batch: 6
1 / 1 | 8 - batch: 7
1 / 1 | 9 - batch: 8
1 / 1 | 10 - batch: 9
Elapsed time: 22.257242918014526
1 / 1 | 1 - batch: 0
1 / 1 | 2 - batch: 1
1 / 1 | 3 - batch: 2
1 / 1 | 4 - batch: 3
1 / 1 | 5 - batch: 4
1 / 1 | 6 - batch: 5
1 / 1 | 7 - batch: 6
1 / 1 | 8 - batch: 7
1 / 1 | 9 - batch: 8
1 / 1 | 10 - batch: 9
Elapsed time: 26.21021580696106

The number of pool processes should be chosen carefully: PyTorch ops can be multi-threaded, and using too many processes can oversubscribe the CPU and degrade performance.
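
For example, a minimal sketch that caps intra-op threads in each worker (the initializer arguments here are an assumption, not something the prototype above does):

import torch
import torch.multiprocessing as mp

# give each worker a single intra-op thread so that 2 workers plus the main
# process do not oversubscribe the available CPU cores
pool = mp.Pool(processes=2, initializer=torch.set_num_threads, initargs=(1,))
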
Let me know if this is what you had in mind.

In case we would like to add something similar to the ignite API, we would have to think carefully about the public interface...

@H4dr1en (Contributor, Author) commented May 15, 2024

Hi @vfdev-5, thanks for this super example 👍 yes, it covers most of my needs!

I have some questions:

  • What is engine.state.storage? Is it something internal? How does it work?
  • I am using ClearMLLogger and multiple Average metrics; how would they interact with the example you provided? More specifically: here we are writing the metric values directly to engine.state.metrics, so would the metrics/logger properly pick up the values of each iteration? How can I ensure that?

@vfdev-5 (Collaborator) commented May 15, 2024

What is engine.state.storage? Is it something internal? How does it work?

It is just a user-defined list manually created on the Engine.state (validator.state.storage = []); it is not something internal to ignite.

More specifically: here we are writing the metric values directly to engine.state.metrics, so would the metrics/logger properly pick up the values of each iteration? How can I ensure that?

Yes, loggers, if configured to log metrics, take the values from engine.state.metrics:

metrics_state_attrs = OrderedDict(engine.state.metrics)

To ensure that the logger picks up the value, you have to attach its handler after the gather_results handler.
While debugging, you can check that the handlers for the event are registered in the desired order, for example:

event = Events.EPOCH_COMPLETED   # or Events.ITERATION_COMPLETED
print(engine._event_handlers[event])
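
For instance, a minimal sketch for ClearMLLogger, assuming the validator engine and the "abc" metric from the prototype above (the project/task names are placeholders):

from ignite.engine import Events
from ignite.handlers.clearml_logger import ClearMLLogger

clearml_logger = ClearMLLogger(project_name="examples", task_name="validation")

# gather_results is attached to Events.EPOCH_COMPLETED first, so attaching the
# logger afterwards guarantees it reads the freshly written state.metrics["abc"]
clearml_logger.attach_output_handler(
    validator,
    event_name=Events.EPOCH_COMPLETED,
    tag="validation",
    metric_names=["abc"],
)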

@H4dr1en (Contributor, Author) commented May 15, 2024

That's perfect 💯
Closing the issue for now, this should do it 👍

H4dr1en closed this as completed on May 15, 2024