improve efficiency of warps #38

Open
maxwellzh opened this issue Dec 30, 2022 · 4 comments
Labels: enhancement (New feature or request)

Comments

@maxwellzh (Contributor)

In the current implementation, the warps along the T axis are computed in a fully serialized manner:

warp-rnnt/core.cu, lines 112 to 134 at edd5857:

if (t < actual_t && u < actual_u) {
    // Ready to compute alphas[t, u]
    unsigned int l = labels[idx2(n, u-1, U-1)];
    float bias = log_probs[idx4(n, t-1, u, blank, T, U, V)];
    float skip = alphas[idx3(n, p, u, T, U)] + bias;
    float emit = alphas[idx3(n, t, u-1, T, U)] + log_probs[idx4(n, t, u-1, l, T, U, V)];
    float r = log_sum_exp(skip, emit);
    float output = r;
    for (unsigned int i = 1; i < W; i++) {
        r = __shfl_up_sync(0xffffffff, r, 1);
        if (i == d) {
            r = log_sum_exp(r + bias, emit);
            output = r;
        }
    }
    alphas[idx3(n, t, u, T, U)] = output;
}

The for loop of each warp is executed one step at a time, which means the i-th warp at a given row u has to wait for all of its preceding warps to finish their loops; that amounts to roughly i (the number of preceding warps) * W (the per-loop overhead, i.e. the warp size, 32 here) in time.

However, we don't necessarily have to wait for the previous warps to finish before entering the loop in the current warp.

Let's take the forward computation of alphas as an example, with warpsize = 4:

[figure: warp_sample]

Here d denotes the lane index inside a warp, so 0 <= d < W. B is the result from row u-1 and is assumed to be ready.

The forward computation of alpha follows the recurrence below (we actually compute in the log domain; the product form here is just for discussion):

    alpha_d = alpha_{d-1} * e_{d-1} + B_d,    d = 0, 1, 2, 3

where e_{d-1} is the probability of advancing along T at step d-1 (the bias term in the code above) and B_d is the contribution from row u-1 (the emit term). Note that alpha_0 relies on the result from the last warp.

Here comes the trick: I rewrote the alpha_3 formula as follows:

    alpha_3 = alpha_0 * (e_0 * e_1 * e_2) + (B_1 * e_1 * e_2 + B_2 * e_2 + B_3)

The bracketed second term is warp-independent. The factor in the first term (the product of the step probabilities e_2 e_1 e_0) can be computed with a prefix-sum (scan) algorithm in the log domain, which introduces only log2(W) complexity.
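For concreteness, such a scan could be a standard Hillis-Steele shuffle scan carried out in the log domain. A minimal sketch (my own illustration, not code from this repo; the function name is hypothetical, while W and d follow the conventions of the kernel above):

// Hypothetical helper (not in the repo): inclusive prefix sum of log
// probabilities across one warp. After the loop, lane d holds
// log(e_0) + ... + log(e_d), i.e. the log of the running product,
// computed in log2(W) shuffle steps instead of W serial ones. The exclusive
// prefix (up to e_{d-1}) can be obtained by shuffling this result up one lane.
__device__ float warp_log_prefix_scan(float log_e, unsigned int d) {
    for (unsigned int offset = 1; offset < W; offset <<= 1) {
        float up = __shfl_up_sync(0xffffffff, log_e, offset);
        if (d >= offset) {
            log_e += up;
        }
    }
    return log_e;
}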

Finally, the new procedure looks like this:

  1. Compute the local path-combination probability (the warp-independent second term above): O(W) complexity;
  2. Compute the products of the step probabilities (e_0 e_1 e_2, ...) with a prefix-sum algorithm: O(log2(W)) complexity;
  3. Wait for the previous warps to finish and compute the final results: constant complexity.

For all warps at row u, steps 1 & 2 can be done in parallel; the i-th warp only has to wait for all previous warps to finish step 3. The new procedure should be considerably faster than the current serialized execution, especially when T is large.
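A rough sketch of the step-3 combine (again my own illustration; finalize_alpha, log_P_d and log_S_d are hypothetical names, while log_sum_exp is the helper already used in core.cu): once alpha_0 of the warp is known, every lane finishes with a single log_sum_exp.

// Hypothetical helper (not in the repo): constant-time finalization per lane.
// In the probability domain this computes
//     alpha_d = alpha_0 * (e_0 * ... * e_{d-1}) + S_d
// where log_P_d is the scanned log prefix product from step 2 and
// log_S_d is the warp-local combination from step 1.
__device__ float finalize_alpha(float log_alpha_0, float log_P_d, float log_S_d) {
    return log_sum_exp(log_alpha_0 + log_P_d, log_S_d);
}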

@1ytic (Owner) commented Dec 30, 2022

Hello Huahuan Zheng, interesting theory! But I don't think it will be useful in practice. Optimising the forward pass doesn't make much sense; you can check the CUDA profiler logs. The big issue is memory IO, and I really like your previous MR with the compact memory version. I hope to finish reviewing it and reopen your MR in the near future.

1ytic added the enhancement label on Dec 30, 2022
@maxwellzh (Contributor, Author)

Will do further investigation later :)

As for the IO issue, I remember reading somewhere that a thread block implicitly loads nearby memory in whole cache lines, whether it is used or not. Have you ever tried using an (N, U, T, V) layout instead of (N, T, U, V)? With the former (and especially when gather=True), a warp (also a thread block) would be able to load a chunk of consecutive memory and reuse it.
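Just to illustrate the layout point (my own sketch, not existing code; the helper names are made up), the flattened offsets for the two layouts would be:

// Current layout (N, T, U, V): consecutive t values within a warp are
// U * V elements apart, so each lane issues a separate strided load.
__device__ __forceinline__ unsigned int idx4_ntuv(unsigned int n, unsigned int t, unsigned int u,
                                                  unsigned int v, unsigned int T, unsigned int U, unsigned int V) {
    return ((n * T + t) * U + u) * V + v;
}

// Proposed layout (N, U, T, V): consecutive t values are only V elements
// apart, so a warp reads one contiguous chunk per row u and can reuse it
// (smaller still if gather=True shrinks the V dimension).
__device__ __forceinline__ unsigned int idx4_nutv(unsigned int n, unsigned int t, unsigned int u,
                                                  unsigned int v, unsigned int T, unsigned int U, unsigned int V) {
    return ((n * U + u) * T + t) * V + v;
}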

Indeed, I've been using the compact-version loss function in our speech recognition tasks for a while. It should be technically correct (it's in my dev branch now; the main branch hasn't been updated for some time). I'll finish merging from my dev branch into the main branch, and once that's done, I'll reopen the MR.

@1ytic (Owner) commented Jan 1, 2023

I'm not familiar with the memory manager for CUDA threads. But you're right, the T x U matrix is the main bottleneck. Fortunately, there is a solution for this: fast_rnnt. It looks really promising.

@maxwellzh (Contributor, Author)

I've been following the fast_rnnt work for a while, but haven't managed a successful pruned RNN-T training yet.

They also have a paper about the implementation. https://arxiv.org/pdf/2206.13236.pdf
