Inside Pallas kernels, we often want a loop, and to speed up compilation, we typically use a scan function such as jax.lax.fori_loop. (For example, in the attention kernel example here.)
As the length of the loop grows, fori_loop slows down execution substantially (relative to using a Python for-loop). I put together a minimal script to isolate the issue and, running it on an A6000, saw a 2-3x slowdown on long loops:
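A minimal sketch of the comparison being described (not the original benchmark script; the loop length, body function, and array size are arbitrary placeholders): a jitted `jax.lax.fori_loop` is lowered as a single sequential loop, while the Python for-loop is fully unrolled at trace time.

```python
import jax
import jax.numpy as jnp

N_STEPS = 1000  # hypothetical loop length, for illustration only

def body(i, x):
    return x * 1.0001 + 0.01

@jax.jit
def with_fori(x):
    # Compiled as one sequential loop construct.
    return jax.lax.fori_loop(0, N_STEPS, body, x)

@jax.jit
def with_python(x):
    # Unrolled during tracing; the compiler sees N_STEPS separate ops,
    # which costs compile time but can execute faster.
    for i in range(N_STEPS):
        x = body(i, x)
    return x

x = jnp.ones((1024,))
a = with_fori(x)
b = with_python(x)
assert jnp.allclose(a, b)  # same math, different lowering
```

Timing the two versions (e.g. with `%timeit` after a warm-up call) is what exposes the gap the issue describes.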
I think the best explanation I found online is the following:
To elaborate on this: the reason a GPU is so fast at vectorized operations is not that its individual floating point operations are particularly fast (they are often slower than comparable operations on a CPU!), but that it can run many operations in parallel very efficiently. For an operation like scan, where each step depends on the output of the previous one, the sequence of operations as a whole cannot be parallelized. So you end up taking advantage of none of the GPU's inherent parallelism, and the result is slow execution.
Contrast this with a CPU, where individual floating point operations are relatively fast but far less built-in parallelism is available. Because of this, scan does not incur as much of a performance penalty.
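The sequential dependency can be made concrete with a small example: each `jax.lax.scan` step consumes the carry produced by the previous step, so the steps must run one after another, whereas the same result computed via an associative primitive like `jnp.cumsum` can use a parallel scan on the GPU.

```python
import jax
import jax.numpy as jnp

xs = jnp.arange(8, dtype=jnp.float32)

def step(carry, x):
    new = carry + x          # depends on the previous step's output
    return new, new          # (new carry, per-step output)

# Inherently sequential: step i cannot start until step i-1 finishes.
_, seq = jax.lax.scan(step, 0.0, xs)

# Same values, but cumsum is associative, so the backend is free to
# compute it with a parallel prefix-sum.
par = jnp.cumsum(xs)

assert jnp.allclose(seq, par)
```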
I think this falls into the same bucket. As in my example, the input of each layer depends on the output of the previous one (the carry), so this is a purely sequential loop. There is probably a sweet spot for the unroll parameter where compilation time and loop execution time are both acceptable.
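A sketch of that trade-off using the `unroll` parameter of `jax.lax.scan` (the unroll factor of 8 and the loop length below are arbitrary examples, not recommendations): unrolling runs several steps per compiled loop iteration, reducing loop overhead at the cost of longer compile times.

```python
import jax
import jax.numpy as jnp

def step(carry, _):
    return carry * 1.0001 + 0.01, None

@jax.jit
def rolled(x):
    # One step per loop iteration: cheapest to compile, most loop overhead.
    out, _ = jax.lax.scan(step, x, None, length=1024)
    return out

@jax.jit
def partially_unrolled(x):
    # Eight steps per loop iteration: more compile time, fewer iterations.
    out, _ = jax.lax.scan(step, x, None, length=1024, unroll=8)
    return out

x = jnp.ones((256,))
assert jnp.allclose(rolled(x), partially_unrolled(x))
```

Sweeping the unroll factor while measuring both compile and run time is the practical way to find the sweet spot for a given kernel.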
Description
Here is the script that generated these results:
System info (python version, jaxlib version, accelerator, etc.)