Hardcoded block_size in kernels #261
Yeah I think we should do that.
Not sure I understand why a larger block_size doesn’t work for you; it shouldn’t have anything to do with VRAM, and 12GB is plenty anyway! What’s the exact error you are getting? As a quick experiment, can you try forcing the kernel to ask for 8192 bytes of shared memory? (It does use shared memory, but we don’t explicitly declare it as such, so there’s a small chance that’s related.)
Same error with 8192:

```
[CUDA ERROR] at file D:\SRC\llm.c\train_gpt2_fp32.cu:1079
```

```cuda
void fused_classifier3(float* logits, float* losses,
                       const float* dlosses, const int* targets,
                       int B, int T, int V, int P) {
    const int block_size = 1024;
    const int N = B * T;
    const int grid_size = N;
    fused_classifier_kernel3<<<grid_size, block_size, 8192>>>(logits, losses, NULL, dlosses, targets, B, T, V, P);
    cudaCheck(cudaGetLastError());
}
```
If you consider the maxGridSize, maxThreadDim, etc. below from the deviceProps, what limits could I be running into?
Typically this happens when the compiled kernel uses more registers per thread than the maximum allowed to run that many warps per SM… There’s 256KiB of register file, so with 1024 threads and 4 bytes per register, that’s a maximum of 64 registers per thread. I’m not sure why/how the compiler ends up using that many registers for you, maybe too aggressive unrolling… Can you try adding `__launch_bounds__(MAX_THREADS_PER_BLOCK, MIN_BLOCKS_PER_MP)` to the kernel declaration, so it becomes:
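Based on the follow-up below that mentions launch_bounds(1024, 1), the suggested edit was presumably along these lines (the parameter list is copied from the call site above; treat this as a sketch, not the repo's exact code):

```cuda
// __launch_bounds__(maxThreadsPerBlock, minBlocksPerMultiprocessor) caps the
// compiler's register allocation so a 1024-thread block can actually launch:
// 256 KiB register file / 1024 threads / 4 bytes per register = 64 registers.
__global__ void __launch_bounds__(1024, 1)
fused_classifier_kernel3(float* logits, float* losses, float* probs,
                         const float* dlosses, const int* targets,
                         int B, int T, int V, int P) {
    // ... kernel body unchanged ...
}
```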
The `__launch_bounds__(1024, 1)` bypassed the "too many resources" error. However, all the losses go kaboom with block_size = 1024, but everything passes with block_size = 32.
I also noticed we don't do CEIL_DIV for fused_classifier3. Is it designed that way?
I'm running fp32 btw. B = 4, T = 64. So pretty small batch.
Is that at the same time as the previous change setting shared memory at 8192? If so I’m at a loss :( I have an RTX 4090, so same generation but slightly older drivers; I might try upgrading next week and see if I hit the same issue, but… I just realised your output says WDDM: is this on Windows or WSL? I don’t think that typically matters, but some things like timeouts and shared memory allocation restrictions might be different etc…
It's Windows. |
I'm seeing the same error on Windows too. I do NOT see it on WSL2. Same spot as above:
@azret - fixed it! I was missing the PFLAGS in the build - can you check to see if you are building with those new flags? |
For block sizes, maybe we should look into the occupancy calculator, which computes the occupancy for a given function.
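Presumably that refers to something like `cudaOccupancyMaxPotentialBlockSize` in the CUDA runtime API, which asks the driver for the block size that maximizes occupancy given the kernel's actual register and shared-memory usage. A hedged sketch (the kernel here is a stand-in; llm.c doesn't currently do this):

```cuda
#include <cuda_runtime.h>
#include <stdio.h>

__global__ void dummy_kernel(float* out) {
    // stand-in for a real kernel such as fused_classifier_kernel3
    if (out) out[threadIdx.x] = 0.0f;
}

int main(void) {
    int min_grid_size = 0, block_size = 0;
    // The runtime accounts for the compiled kernel's register count and
    // static shared memory; pass dynamic shared memory size if used.
    cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size,
                                       dummy_kernel,
                                       /*dynamicSMemSize=*/0,
                                       /*blockSizeLimit=*/0);
    printf("suggested block_size = %d\n", block_size);
    return 0;
}
```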
@PeterZhizhin - I bumped into the -G debug issue too: "[CUDA ERROR] at file train_gpt2.cu:1410: …"
@rosslwheeler yes, I had exactly the same issue. Seems like the kernel uses too many registers? Reducing block size to 512 in debug mode makes the code work. |
Can we make the block_size in the kernels more adaptive or parameterized? e.g. 1024 is pretty big for my GPU with 12GB of memory.
I have to run with block_size = 32
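One way to make it adaptive (a sketch of the idea, not the repo's actual approach; names are illustrative): clamp a requested default against the limits reported in the deviceProps mentioned earlier. Note that `maxThreadsPerBlock` is 1024 on most recent GPUs, and as the rest of the thread establishes, the binding constraint here turned out to be registers per thread, which `__launch_bounds__` or the occupancy API handles more directly.

```cuda
#include <cuda_runtime.h>

// Sketch: fall back to smaller power-of-two block sizes until the
// device's reported per-block thread limit is satisfied.
static int pick_block_size(int requested) {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, 0);
    int bs = requested;
    while (bs > 32 && bs > prop.maxThreadsPerBlock) {
        bs /= 2;  // shrink until the device can launch this block size
    }
    return bs;
}
```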