delete use of cooperative groups in kernels #292

Open · karpathy opened this issue Apr 29, 2024 · 2 comments

@karpathy (Owner)

We use a lot of cooperative groups functionality in our kernels. The dependency is probably mildly convenient, but it is likely that the code could be written without it, with little added complexity and at the same speed. As a general principle, llm.c should ideally be very careful about the "dependency surface" of its code, which keeps it very portable, easy to skim/read even if slightly longer, and easy to run or port to any hardware: old, new, edge, exotic, or otherwise not yet thought of.

I would accept PRs that develop cooperative-groups-free kernels in dev/cuda that:

  1. aren't too much more complex or too many more LOC
  2. have the same speed

On top of dev/cuda I'd be happy to merge these into "mainline" train_gpt2.cu and the fp32 version train_gpt2_fp32.cu.

@ChrisDryden (Contributor)

Just posting some notes here from my research on how to remove all of the CG-related code and drop the dependency:

    sum = cg::reduce(warp, sum, cg::plus<float>{});

Can be replaced with the following:

    __device__ float warpReduceSum(float val) {
        // butterfly reduction: after 5 xor-shuffle steps, every lane
        // in the warp holds the sum of all 32 lanes' values
        for (int offset = 16; offset > 0; offset /= 2) {
            val += __shfl_xor_sync(0xFFFFFFFF, val, offset);
        }
        return val;
    }

    sum = warpReduceSum(sum);

No explicit thread syncs are needed: __shfl_xor_sync already synchronizes the participating lanes within the warp.
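For completeness, here's a minimal sketch of how warpReduceSum composes into a block-wide sum without cooperative groups; the blockReduceSum name and the shared-memory staging are illustrative, not code from the repo:

    // hypothetical block-level reduction built on warpReduceSum() above;
    // assumes blockDim.x is a multiple of 32 and at most 1024
    __device__ float blockReduceSum(float val) {
        __shared__ float shared[32];            // one partial sum per warp
        int laneId = threadIdx.x % 32;
        int warpId = threadIdx.x / 32;
        int warpsPerBlock = blockDim.x / 32;

        val = warpReduceSum(val);               // step 1: reduce within each warp
        if (laneId == 0) shared[warpId] = val;  // lane 0 publishes its warp's partial
        __syncthreads();                        // cross-warp communication does need a barrier

        // step 2: the first warp reduces the per-warp partials
        val = (threadIdx.x < warpsPerBlock) ? shared[laneId] : 0.0f;
        if (warpId == 0) val = warpReduceSum(val);
        return val;                             // full sum, valid in all lanes of warp 0
    }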

Also, the other cooperative-groups variables can be replaced with plain index arithmetic (note that warpSize is already a built-in variable in CUDA device code, so it's cleaner to use a literal 32 or a differently named constant than to redeclare it):

    int laneId = threadIdx.x % 32;
    int warpId = threadIdx.x / 32;
    int warpsPerBlock = blockDim.x / 32;

which map to the cooperative-groups accessors as:

    warp.thread_rank()     == laneId
    warp.size()            == 32
    warp.meta_group_size() == warpsPerBlock
    warp.meta_group_rank() == warpId
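As a concrete sketch of the substitution, a kernel prologue in the cooperative-groups style (the exact pattern varies per kernel; this one is illustrative) becomes plain index math:

    // cooperative-groups version (illustrative):
    //   namespace cg = cooperative_groups;
    //   cg::thread_block block = cg::this_thread_block();
    //   cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    //   int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();

    // CG-free equivalent:
    int laneId = threadIdx.x % 32;
    int warpId = threadIdx.x / 32;
    int warpsPerBlock = blockDim.x / 32;
    int idx = blockIdx.x * warpsPerBlock + warpId;  // one data row per warp, as before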

I have replaced most of the kernels to test for performance impact, and I was not able to see any noticeable change from removing the cooperative groups.

@ngc92 (Contributor) commented May 2, 2024

In many cases, I also find it quite convenient to just use a block size of 32 in the x direction and put the rest in the y direction. Then threadIdx.x corresponds to laneId and threadIdx.y to warpId. This doesn't work when the kernel already needs the other block dimensions for something else.
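A minimal sketch of that layout (the kernel, a row-wise copy, is just a placeholder, and the dim3(32, 8) launch shape is an assumption):

    // one warp per row: threadIdx.x is the lane, threadIdx.y the warp
    __global__ void kernel2d(float* out, const float* in, int rows, int cols) {
        int laneId = threadIdx.x;                       // 0..31, since blockDim.x == 32
        int warpId = threadIdx.y;                       // which warp within the block
        int warpsPerBlock = blockDim.y;                 // no division or modulo needed
        int row = blockIdx.x * warpsPerBlock + warpId;  // one row per warp
        if (row >= rows) return;
        for (int c = laneId; c < cols; c += 32) {       // lanes stride across the row
            out[row * cols + c] = in[row * cols + c];
        }
    }

launched as, e.g., kernel2d<<<(rows + 7) / 8, dim3(32, 8)>>>(out, in, rows, cols);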
