Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Start / Stop / Dump trace hooks for NCCL profiler for a tracing ecosystem integration #1210

Open
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

sanrise
Copy link

@sanrise sanrise commented Mar 5, 2024

NCCL already provides a profiling/tracing capability to record various operations during collectives including setting up buffers, sending data to and from GPU etc. Today it uses a compile time flag and traces the whole application and does not support any start and stop knobs.

Having such a control allows the application layer to gain temporary visibility into collective communication. This allows us to introduce abstractions within PyTorch's (libkineto) profiler to do things like (a) report NCCL traces along with GPU traces for only a subset of training iterations or (b) simply trace collectives simulated using torch.distributed; to gain efficient, fine grain visibility of collective operations.

This change will add such lifecycle management interfaces (start/stop/dump-trace) to the existing profiler along with other improvements to the collection.

Enhancements

  • Annotate the start and stop of the overall collective and provide collective name in the traces.
  • Missing chunk/data size measurement.
  • Add NCCL API markers.
  • Improve clean up for profiler and collective event buffers.
  • Make trace dumping not dependent on collective markings.

Example:
Upon Kineto integration, the PyTorch profiler module can generate NCCL traces. This file was generated by the profiler:

[
{"name": "allreduce-1", "cat": "COL", "id": 1, "ph": "b", "pid": -1, "tid": 1, "ts": 1.363281 },
{"name": "allreduce-1", "cat": "COL", "id": 1, "ph": "e", "pid": -1, "tid": 1, "ts": 0.000000 },
{"name": "wait-2", "cat": "COL", "id": 2, "ph": "b", "pid": -1, "tid": 1, "ts": 169.007812 },
{"name": "wait-2", "cat": "COL", "id": 2, "ph": "e", "pid": -1, "tid": 1, "ts": 0.000000 },
{"name": "reduce-3", "cat": "COL", "id": 3, "ph": "b", "pid": -1, "tid": 1, "ts": 258.093750 },
{"name": "reduce-3", "cat": "COL", "id": 3, "ph": "e", "pid": -1, "tid": 1, "ts": 0.000000 },
{"name": "wait-4", "cat": "COL", "id": 4, "ph": "b", "pid": -1, "tid": 1, "ts": 359.410156 },
{"name": "wait-4", "cat": "COL", "id": 4, "ph": "e", "pid": -1, "tid": 1, "ts": 0.000000 },
{"name": "broadcast-5", "cat": "COL", "id": 5, "ph": "b", "pid": -1, "tid": 1, "ts": 419.800781 },
{"name": "broadcast-5", "cat": "COL", "id": 5, "ph": "e", "pid": -1, "tid": 1, "ts": 0.000000 },
{"name": "wait-6", "cat": "COL", "id": 6, "ph": "b", "pid": -1, "tid": 1, "ts": 507.437500 },
{"name": "wait-6", "cat": "COL", "id": 6, "ph": "e", "pid": -1, "tid": 1, "ts": 0.000000 },
{"name": "_reduce_scatter_base-7", "cat": "COL", "id": 7, "ph": "b", "pid": -1, "tid": 1, "ts": 575.015625 },
{"name": "_reduce_scatter_base-7", "cat": "COL", "id": 7, "ph": "e", "pid": -1, "tid": 1, "ts": 0.000000 },
{"name": "wait-8", "cat": "COL", "id": 8, "ph": "b", "pid": -1, "tid": 1, "ts": 659.773438 },
{"name": "wait-8", "cat": "COL", "id": 8, "ph": "e", "pid": -1, "tid": 1, "ts": 0.000000 },
{"name": "_allgather_base-9", "cat": "COL", "id": 9, "ph": "b", "pid": -1, "tid": 1, "ts": 720.613281 },
{"name": "_allgather_base-9", "cat": "COL", "id": 9, "ph": "e", "pid": -1, "tid": 1, "ts": 0.000000 },
{"name": "wait-10", "cat": "COL", "id": 10, "ph": "b", "pid": -1, "tid": 1, "ts": 801.953125 },
{"name": "wait-10", "cat": "COL", "id": 10, "ph": "e", "pid": -1, "tid": 1, "ts": 0.000000 },
{} ]

In above example, we simply profiled this PyTorch snippet:

with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        with_stack=True,
        # on_trace_ready=trace_handler,
    ) as prof:
            dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
            dist.reduce(tensor, dst=0, op=dist.ReduceOp.SUM)
            dist.broadcast(tensor, src=0)
            dist.reduce_scatter_tensor(
                reduce_scatter_output, reduce_scatter_input, op=dist.ReduceOp.SUM
            )
            dist.all_gather_into_tensor(all_gather_output, reduce_scatter_output)

(These changes have been authored and iterated on by @briancoutinho and various engineers at Meta over a period of time)

NCCL provides a profiling/tracing capability to record various operations during collectives including setting up buffers, sending data to and from GPU etc. This change will enable us to control NCCL profiling from the application layer through a start/stop interface.

Enhancements
* It uses a compile time flag and traces the whole application. So it does not support start and stop API.
* Does not annotate the start and stop of the overall collective and provide collective name.
* Missing chunk/data size measurement.
* Add nccl API markers.
* Improve clean up for profiler and collective event buffers.
* Make trace dumping not dependent on collective markings.
* Future enhancements will include sampling to enable always-on
  collection.
@sanrise sanrise changed the title Modify NCCL profiler invocation, collection, dump Start / Stop / Dump trace hooks for NCCL profiler for a tracing ecosystem integration Mar 5, 2024
@sanrise sanrise marked this pull request as ready for review March 5, 2024 00:20
@sjeaugey
Copy link
Member

sjeaugey commented Mar 5, 2024

Can you explain more in details how this integrates into Kineto? I'm failing to see how Kineto would call ncclProfilerStart/ncclProfilerStop, and also how it would get it's profiling data back into the global profiling (we generate a file, but will this file become part of the global profiling?).

@sanrise
Copy link
Author

sanrise commented Mar 5, 2024

Hey Sylvain, thanks for taking the time out to understand this.

Can you explain more in details how this integrates into Kineto?

We plan to integrate these start stop methods to our other tooling as well. Although, Kineto is the main user.

For Kineto, on a high level, we implement theIActivityProfilerSession interface within Kineto. In fact, these methods of start/stop/dump are influenced from this interface. Then within Kineto we have an ability to start such child profilers along with the main CUPTI profiler using a registerProfilerFactory(...) on instantialtion, this is also where all our other such internal child profilers are enrolled and their lifecycle is tied to the main (parent) profiler. We implement the IActivityProfilerSession interface by wrapping these new NCCL profiler methods.

Now when you wrap your model in a torch.profile context manager as shown in the example above, Kineto will start, which starts all these child profilers. When the model context exits - the main profiler is shut down, causing all children to shut down too. We have also implemented tooling to, in-fact, start stop Kineto profilers during training runtime for a subset of iterations, and the same profiler lifecycle applies there too.

Today our IActivityProfilerSession interface implementation for NCCL profiler, uses the NCCL_PROXY_PROFILER_ENABLED macro to check if the NCCL library has these methods because our Kineto code has to work with NCCL libraries that do not have these methods too as they are not part of the main trunk. Ideally that macro can go away if we can make this a part of the official interface

how it would get it's profiling data back into the global profiling (we generate a file, but will this file become part of the global profiling?)

So consolidating the data with GPU trace (aka global profiling) was not an initial design goal. The generated file is stored at the same location at the GPU trace.

@sanrise
Copy link
Author

sanrise commented Mar 8, 2024

@sjeaugey Hey, do let me know your thoughts on adding these methods to the profiler to - start/stop/check if started/dump trace.

We don't really have to change the functionality of the stock profiler that comes with the library if you don't want to, but just having a better interface allows cheap, quick, basic, out-of-the-box collective tracing for everything that sits above NCCL in the stack (and that's a lot of prod-like things) without having to re-compile each of them with a plugin (not so prod-like) :)

We can always defer any code that improves the specialization of a default profiler to be added as a plugin.

@dfyz
Copy link

dfyz commented Mar 29, 2024

@sanrise
disclaimer: I'm not affiliated with NCCL in any way, I'm just writing this as someone who spent a ton of time trying to profile NCCL with PyTorch

The ability to seamlessly profile NCCL collectives with Kineto/PyTorch would be absolutely amazing! However, I'm a little skeptical about using the existing proxy profiler for this. I'll try to summarize the potential problems below, and propose a hopefully more robust solution based on a different approach.

We don't really have to change the functionality of the stock profiler that comes with the library if you don't want to

First, the existing profiler has to be changed in some way because it simply doesn't work for anything non-trivial. I tried to fix this (see this issue, this PR, and this commit message for more details); actually, while digging through old PRs I just found out I was not the first person to attempt this.

Those PR were never merged, and I believe the reason why is that the existing profiler was never intended to be used for prod-like things (see this and this, for example). The NCCL maintainers simply don't have the bandwidth to review large changes to something that is only meant to serve as proof of concept (which is very understandable).

Second, the proxy profiler events only trigger for the NCCL channels that are proxied over actual network. Anything happening over NVLink, for example, is invisible to the proxy profiler. The kernels used in NCCL do a lot of performance-sensitive operations on the GPU even when there's no network involved (more on that at the end of this comment), and it would be really nice to have visibility into this, too.

Third, perhaps the most fundamental problem is that the proxy profiler events are inherently hard to match with the actual collective kernels launched from PyTorch, since the CPU proxy thread(s) are asynchronous with respect to both GPU kernels and the other CPU threads. From cursory reading, this PR assigns a event triggered by a proxy thread to the collective most recently recorded with ncclCollectiveRecord(), which is really suspicious to me (e.g., does it work if the Python code starts a new torch.distributed operation when the previous asynchronous one is still in progress?). From your description, I presume this has been tested internally at Meta for a while, so perhaps this all works perfectly and I'm missing something, but from personal experience, it seems like a hard problem to solve.

With all this in mind, what if we profiled the actual NCCL kernels instead of the proxy thread? I have a proof-of-concept implementation in my private NCCL fork which essentially does this:

  • allocates per-communicator per-channel buffers for kernel timings (somewhere around here), which are then passed to the kernel
  • in the kernel primitives code, records any events of interest to the buffer; e.g., the wall time when the GPU is ready to process the new portion of data (e.g., here)
  • schedules a host callback with cudaLaunchHostFunc() after the kernel is scheduled here

The diff to add this to NCCL is surprisingly small – a hundred of lines tops. The host callback can then do anything it wants with the raw timings obtained after the kernel finishes. I just collect them in host memory to eventually enrich the raw Kineto trace so that it looks like this:
image

This basically gives you full visibility into what is happening within the pink AllGather over each one of 8 channels. E.g., as I said above, here you can see that the kernel is mostly bottlenecked by data processing (copying data from the user buffer to the network buffer) and not by waiting for the network. If you squint hard enough, you can also see that some of those processing steps are much longer than the other – it turns out that the slow steps are those where the data is not aligned to 16 bytes, which makes NCCL fall back to a much slower code path. I don't think one can understand this using only the data from the proxy profiler.

Having said that, I definitely don't want to hijack this PR with unrelated proposals. Just integrating the existing proxy profiler to PyTorch in some form would also be great. But if this idea of fine-grained kernel profiling sounds interesting to PyTorch/NCCL developers, I can try to create a separate PR (or issue) to discuss this further.

@wangfakang
Copy link

@sanrise disclaimer: I'm not affiliated with NCCL in any way, I'm just writing this as someone who spent a ton of time trying to profile NCCL with PyTorch

The ability to seamlessly profile NCCL collectives with Kineto/PyTorch would be absolutely amazing! However, I'm a little skeptical about using the existing proxy profiler for this. I'll try to summarize the potential problems below, and propose a hopefully more robust solution based on a different approach.

We don't really have to change the functionality of the stock profiler that comes with the library if you don't want to

First, the existing profiler has to be changed in some way because it simply doesn't work for anything non-trivial. I tried to fix this (see this issue, this PR, and this commit message for more details); actually, while digging through old PRs I just found out I was not the first person to attempt this.

Those PR were never merged, and I believe the reason why is that the existing profiler was never intended to be used for prod-like things (see this and this, for example). The NCCL maintainers simply don't have the bandwidth to review large changes to something that is only meant to serve as proof of concept (which is very understandable).

Second, the proxy profiler events only trigger for the NCCL channels that are proxied over actual network. Anything happening over NVLink, for example, is invisible to the proxy profiler. The kernels used in NCCL do a lot of performance-sensitive operations on the GPU even when there's no network involved (more on that at the end of this comment), and it would be really nice to have visibility into this, too.

Third, perhaps the most fundamental problem is that the proxy profiler events are inherently hard to match with the actual collective kernels launched from PyTorch, since the CPU proxy thread(s) are asynchronous with respect to both GPU kernels and the other CPU threads. From cursory reading, this PR assigns a event triggered by a proxy thread to the collective most recently recorded with ncclCollectiveRecord(), which is really suspicious to me (e.g., does it work if the Python code starts a new torch.distributed operation when the previous asynchronous one is still in progress?). From your description, I presume this has been tested internally at Meta for a while, so perhaps this all works perfectly and I'm missing something, but from personal experience, it seems like a hard problem to solve.

With all this in mind, what if we profiled the actual NCCL kernels instead of the proxy thread? I have a proof-of-concept implementation in my private NCCL fork which essentially does this:

  • allocates per-communicator per-channel buffers for kernel timings (somewhere around here), which are then passed to the kernel
  • in the kernel primitives code, records any events of interest to the buffer; e.g., the wall time when the GPU is ready to process the new portion of data (e.g., here)
  • schedules a host callback with cudaLaunchHostFunc() after the kernel is scheduled here

The diff to add this to NCCL is surprisingly small – a hundred of lines tops. The host callback can then do anything it wants with the raw timings obtained after the kernel finishes. I just collect them in host memory to eventually enrich the raw Kineto trace so that it looks like this: image

This basically gives you full visibility into what is happening within the pink AllGather over each one of 8 channels. E.g., as I said above, here you can see that the kernel is mostly bottlenecked by data processing (copying data from the user buffer to the network buffer) and not by waiting for the network. If you squint hard enough, you can also see that some of those processing steps are much longer than the other – it turns out that the slow steps are those where the data is not aligned to 16 bytes, which makes NCCL fall back to a much slower code path. I don't think one can understand this using only the data from the proxy profiler.

Having said that, I definitely don't want to hijack this PR with unrelated proposals. Just integrating the existing proxy profiler to PyTorch in some form would also be great. But if this idea of fine-grained kernel profiling sounds interesting to PyTorch/NCCL developers, I can try to create a separate PR (or issue) to discuss this further.

@dfyz It's looks coooool, how much performance loss will this trace capability have on NCCL? In addition, how can we use the feature you mentioned above? could you share a PR? And this is useful for analyzing performance-related issues, thanks.

dfyz added a commit to dfyz/nccl that referenced this pull request May 21, 2024
@dfyz
Copy link

dfyz commented May 21, 2024

how much performance loss will this trace capability have on NCCL?

I didn't see any noticeable performance impact on the kernels themselves, but the host callback I wrote to process the timings from the kernels is very naive. It does have a noticeable impact on performance when the number of ranks (and hence the number of timestamps collected) is very large. In extreme cases, it could double the effective time it was needed to run the collectives.

In addition, how can we use the feature you mentioned above? could you share a PR?

There is no easy way, because my proof-of-concept never evolved beyond that (mostly because I wasn't sure it is of interest to anyone). It works for my purposes, but the code quality is not that great and only some NCCL kernels can be traced (more exactly, ring-based AllGather and ReduceScatter for the Simple protocol that use no more than 8 channels).

Having said that, I just published this PoC to my NCCL fork, hoping it might serve as inspiration to someone:

  • this is the commit that adds two functions to the NCCL API: ncclSetSaveTimingsState() (to enable collecting timestamps from kernels) and ncclAppendTimingsToJson() (to append the collected raw timestamps to a PyTorch JSON trace)
  • this is an example patch to modify PyTorch to call these functions
  • this is the post-processing script that correlates the collected timestamps with NCCL kernels to form the actual JSON that can be opened in Perfetto

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants