[Compilation] Optimization for compilation process #426

Closed · wants to merge 2 commits

Conversation

@hunbssfy (Collaborator) commented Feb 13, 2024

Optimize the compilation behavior for each core. Instead of generating a CUDA source file for each core, aggregate them into a single source file to reduce compilation overhead.

Initial tests show a ~20% decrease in the total compilation+run time of the BERT model (222s -> 177s), and roughly the same percentage decrease in the operator tests.
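
For illustration only, a minimal Python sketch of the batching idea, assuming each candidate's generated CUDA source is already wrapped in its own namespace; the helper name and file paths are hypothetical, not the actual hidet driver code.

# Hypothetical sketch: concatenate per-candidate CUDA sources and compile once.
import subprocess

def compile_candidates(candidate_sources, output='candidates.o'):
    # Each candidate's code lives in its own `namespace candidate_<n> { ... }`
    # block, so merging is a plain concatenation.
    with open('candidates.cu', 'w') as f:
        f.write('\n\n'.join(candidate_sources))
    # One nvcc invocation for the whole batch instead of one per candidate.
    subprocess.run(['nvcc', '-c', 'candidates.cu', '-o', output], check=True)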

@hunbssfy added the enhancement (New feature or request) label on Feb 13, 2024
@hunbssfy self-assigned this on Feb 13, 2024
@vadiklyutiy (Collaborator)

Did you measure the compilation speedup?

@hunbssfy marked this pull request as ready for review on February 15, 2024, 17:54
@hunbssfy (Collaborator, Author)

> Did you measure the compilation speedup?

32-core i9-13900K + 4090

  • batch_matmul (this measurement covers compile time only: IRModule -> lower -> .cu file -> .o -> .so)
    local compile: 80s -> 60s (~25% faster)
    local compiler server: 104s -> 63s (~40% faster; the baseline is slower, possibly due to a CPU core restriction when booting the Docker compiler server locally)
  • bert (this measures end-to-end time, including inference and the whole compilation process)
    local compile (total time): 222s -> 177s (~20% faster)
    local compiler server (total time): 321s -> 261s (~19% faster)

@vadiklyutiy (Collaborator)

Did you note how many candidates there are in matmul and bert?

@hunbssfy (Collaborator, Author)

https://docs.google.com/spreadsheets/d/1uVBh1hCafwhOEeRdrFnuoaz4HOFECyr5cl7WnuMm7Ns/edit?usp=sharing
I did a compilation-time breakdown that shows the improvement this optimization brings to the whole compilation time. The breakdown contains measurements of several common operators. I measured three parts of the compilation time:

a. 'lower' refers to the time spent on IR lowering (kernel-level optimization)
b. 'codegen' refers to the time spent converting an IRModule to a .cu source file
c. 'compile' refers to the time spent compiling the .cu source file into object files

'compile_opt' is the optimized branch.

@vadiklyutiy (Collaborator)

Useful data, thx @hunbssfy
Could you make some clarifications:

  1. How did you measure? Did you put time counters around functions? If so, could you say which functions you covered?
  2. Could you add, for reference, the full compilation time of that test, if you have it?
  3. To be honest, I did not get why lower and codegen sped up. It seems you did not touch them.

@hunbssfy (Collaborator, Author) commented Feb 19, 2024

@vadiklyutiy

  1. python/hidet/drivers/build_module.py: build_ir_module. If you look at the code in this function, you can easily find the three functions (lower/codegen/compile) I measured. I used the Python 'time' module to measure the duration of each call. In the multiprocessing scenario, I use a shared value (with a lock) to accumulate the time used; the time in the sheet is the total summed across all the processes' individual times (see the sketch after this list).
  2. I didn't measure the full compilation time because I didn't modify any code outside this function. But I will do a full breakdown of the compilation time later this week, because right now I'm not very familiar with the whole compilation process yet.
  3. I'm not quite sure why this is the case either. But if you look at the code I changed, I did change build_ir_module's signature: it can now take a sequence of IRModules and do lower/codegen/compile in a batch, whereas previously it could only take one IRModule at a time. I assume there is some overhead in the measurement itself and in the Python interpreter: in the batch case we measure lower/codegen/compile once per batch, while in the original code we measure them per task.
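
For illustration, here is a minimal sketch of the timing approach from point 1, assuming a placeholder lower pass; the function and variable names are hypothetical, while the real timers sit around the lower/codegen/compile calls inside build_ir_module.

# Hypothetical sketch: accumulate per-phase time across worker processes
# with a shared double and its lock.
import time
from multiprocessing import Process, Value

def lower(ir_module):
    # placeholder standing in for the real IR-lowering pass
    time.sleep(0.01)
    return ir_module

def worker(ir_module, lower_time):
    start = time.time()
    lower(ir_module)
    with lower_time.get_lock():                 # accumulate under the shared lock
        lower_time.value += time.time() - start

if __name__ == '__main__':
    lower_time = Value('d', 0.0)                # 'd' = double, shared across processes
    procs = [Process(target=worker, args=(m, lower_time)) for m in range(4)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    # the reported number is the sum of all workers' individual times
    print(f'total lower time across workers: {lower_time.value:.3f}s')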

@vadiklyutiy (Collaborator) commented Feb 19, 2024

  1. Thanks for the clarification. How many CPU threads did you have on the machine where you measured?
  3. It seems that packing into batches also speeds up 'lower'. A bit unexpected for me, but possible.

BTW, from your spreadsheet it is clear that more candidates means a bigger improvement. matmul, conv, and attn have similar numbers of candidates, but while matmul and conv improve by 8.05x and 7.22x, attn improves by only 3.83x. Maybe it's worth looking into this. The reason may simply be that attn has much more code in its .cu file, but maybe there is a misprint or something we didn't take into account.

@hunbssfy (Collaborator, Author)

@vadiklyutiy I used 32 cores, so that is 32 threads.

I also noticed this unbalanced acceleration. I'll post new findings if I get any.

@yaoyaoding (Member)

Thanks @hunbssfy @vadiklyutiy !

What if different IRModules have functions with the same name but different implementations? For example, we might have a function in the schedule template which will have different implementations for different hyper-parameters. When we merge the generated code for multiple IRModules, there might be a name conflict issue.

@vadiklyutiy (Collaborator)

> Thanks @hunbssfy @vadiklyutiy !
>
> What if different IRModules have functions with the same name but different implementations? For example, we might have a function in the schedule template which will have different implementations for different hyper-parameters. When we merge the generated code for multiple IRModules, there might be a name conflict issue.

But as I understand it, they are eventually linked into lib.o. Does this issue exist in the current implementation?

@vadiklyutiy (Collaborator)

@yaoyaoding, could you respond?

@yaoyaoding (Member) commented Feb 26, 2024

> But as I understand it, they are eventually linked into lib.o. Does this issue exist in the current implementation?

They are private functions in each lib.a, and the exposed public functions (e.g., those with the API macro at the head of the function declaration, like launch_1, launch_2) have different names. Thus the current implementation does not have this problem.

@hunbssfy (Collaborator, Author) commented Feb 27, 2024

> Thanks @hunbssfy @vadiklyutiy !
>
> What if different IRModules have functions with the same name but different implementations? For example, we might have a function in the schedule template which will have different implementations for different hyper-parameters. When we merge the generated code for multiple IRModules, there might be a name conflict issue.

@yaoyaoding I went through the generated .cu files during compilation. Here is one of them:

namespace candidate_0 {
static __device__ __forceinline__ void cuda_mma_sync_aligned_m16n8k8_row_col_f32_tf32_tf32_f32(tfloat32_t * __restrict__ a, tfloat32_t * __restrict__ b, float * __restrict__ c) {
...
}
static __global__ void __launch_bounds__(64) batch_matmul_kernel(float * __restrict__ a, float * __restrict__ b, float * __restrict__ c) {
...
}
void launch(float * __restrict__ a, float * __restrict__ b, float * __restrict__ c) {
...
}
}

namespace candidate_1 {
// two static 'internal' functions
// launch function
}

namespace candidate_2 {
...
}
....

As you can see, every candidate has a launch function, but each one lives in its own namespace. This patch didn't change the candidates' namespaces, which means that even though candidates are now grouped into shared .cu files, their namespace names still map to their original candidate numbers: candidate 0 keeps namespace candidate_0, candidate 700 keeps namespace candidate_700. So there will be no namespace conflicts. I also looked at the symbols in the generated .so file; here is what they look like (the mangled names of each candidate_N::launch):

0000000000040c60 T _ZN11candidate_06launchEPfS0_S0_
000000000003eb60 T _ZN11candidate_16launchEPfS0_S0_
0000000000040000 T _ZN11candidate_26launchEPfS0_S0_
0000000000040420 T _ZN11candidate_36launchEPfS0_S0_
0000000000040840 T _ZN11candidate_46launchEPfS0_S0_
000000000003ef80 T _ZN11candidate_56launchEPfS0_S0_
0000000000041ce0 T _ZN11candidate_66launchEPfS0_S0_
000000000003fbe0 T _ZN11candidate_76launchEPfS0_S0_
000000000003f7c0 T _ZN11candidate_86launchEPfS0_S0_
000000000003f3a0 T _ZN11candidate_96launchEPfS0_S0_
0000000000043180 T _ZN12candidate_106launchEPfS0_S0_
0000000000041080 T _ZN12candidate_116launchEPfS0_S0_
000000000003e320 T _ZN12candidate_126launchEPfS0_S0_
0000000000042d60 T _ZN12candidate_136launchEPfS0_S0_
00000000000435a0 T _ZN12candidate_146launchEPfS0_S0_
0000000000042520 T _ZN12candidate_156launchEPfS0_S0_
000000000003e740 T _ZN12candidate_166launchEPfS0_S0_
00000000000414a0 T _ZN12candidate_176launchEPfS0_S0_
0000000000042940 T _ZN12candidate_186launchEPfS0_S0_
00000000000418c0 T _ZN12candidate_196launchEPfS0_S0_
0000000000042100 T _ZN12candidate_206launchEPfS0_S0_
00000000000478e0 T _ZN12candidate_216launchEPfS0_S0_
00000000000457e0 T _ZN12candidate_226launchEPfS0_S0_
0000000000046c80 T _ZN12candidate_236launchEPfS0_S0_
00000000000470a0 T _ZN12candidate_246launchEPfS0_S0_
00000000000474c0 T _ZN12candidate_256launchEPfS0_S0_
0000000000045c00 T _ZN12candidate_266launchEPfS0_S0_
0000000000048960 T _ZN12candidate_276launchEPfS0_S0_
0000000000046860 T _ZN12candidate_286launchEPfS0_S0_
0000000000046440 T _ZN12candidate_296launchEPfS0_S0_
0000000000046020 T _ZN12candidate_306launchEPfS0_S0_
...

They look normal. So I don't think there will be a function name conflict issue.

@hunbssfy (Collaborator, Author) commented Feb 27, 2024

> > But as I understand it, they are eventually linked into lib.o. Does this issue exist in the current implementation?
>
> They are private functions in each lib.a, and the exposed public functions (e.g., those with the API macro at the head of the function declaration, like launch_1, launch_2) have different names. Thus the current implementation does not have this problem.

@yaoyaoding I inspected the final lib.so file. There are two types of GLOBAL symbols. The first is the launch function of each candidate:

3171: 00000000000cb5d0 1043 FUNC GLOBAL DEFAULT 14 candidate_460::launch(float*, float*, float*)

The second is the hidet_launch functions:

3199: 00000000001190f0 9 FUNC GLOBAL DEFAULT 14 hidet_launch_229

@vadiklyutiy (Collaborator)

I read all the comments again and it seems there is some misunderstanding.
Max in his patch compiles several candidates of the same operator together. Hence it all happens for the same IRModule. He did not try to compile several ops (IRModules) together.

So, this

> What if different IRModules have functions with the same name but different implementations? For example, we might have a function in the schedule template which will have different implementations for different hyper-parameters. When we merge the generated code for multiple IRModules, there might be a name conflict issue.

cannot happen. Does it make sense?

@yaoyaoding (Member)

> Hence it all happens for the same IRModule. He did not try to compile several ops (IRModules) together.

Even though the candidates are for the same operator, they are represented by different IRModules.

The problem I previously worried about will not happen when all the batch-compiled IRModules come from the same operator, because we set a different namespace for each IRModule. I used namespaces to avoid name conflicts in the shared library when I grouped all the kernels into the same shared library (previously they were in different shared libraries). This alleviates the potential function name conflict problem.

There is still a potential name conflict problem if the batch compile API (i.e., the build_ir_module function) is used to compile IRModules with conflicting names. As there is no such use case in hidet for now, we can proceed without dealing with it. But @hunbssfy, could you add an assertion in the function to make sure there is no name conflict before we group (i.e., that each "namespace::function_name" is unique)?

@hunbssfy (Collaborator, Author) commented Mar 4, 2024

> > Hence it all happens for the same IRModule. He did not try to compile several ops (IRModules) together.
>
> Even though the candidates are for the same operator, they are represented by different IRModules.
>
> The problem I previously worried about will not happen when all the batch-compiled IRModules come from the same operator, because we set a different namespace for each IRModule. I used namespaces to avoid name conflicts in the shared library when I grouped all the kernels into the same shared library (previously they were in different shared libraries). This alleviates the potential function name conflict problem.
>
> There is still a potential name conflict problem if the batch compile API (i.e., the build_ir_module function) is used to compile IRModules with conflicting names. As there is no such use case in hidet for now, we can proceed without dealing with it. But @hunbssfy, could you add an assertion in the function to make sure there is no name conflict before we group (i.e., that each "namespace::function_name" is unique)?

Hi @yaoyaoding, I added a function-name uniqueness check after regrouping the IRModules.
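
For illustration, a minimal sketch of what such a check could look like, assuming each IRModule exposes a namespace string and a functions mapping of name to function; this is an assumption for the sketch, not necessarily hidet's exact API or the exact code added in this PR.

# Hypothetical sketch: assert that every "namespace::function_name" is unique
# before grouping IRModules into one source file.
from typing import Sequence

def assert_unique_function_names(ir_modules: Sequence) -> None:
    seen = set()
    for ir_module in ir_modules:
        for func_name in ir_module.functions:
            qualified = f'{ir_module.namespace}::{func_name}'
            assert qualified not in seen, f'duplicate function: {qualified}'
            seen.add(qualified)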
