Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimized rope_rotation_llama and apply temperature to logits with vectorization #59

Open
wants to merge 2 commits into
base: master
Choose a base branch
from

Conversation

andresnowak
Copy link

No description provided.

@andresnowak andresnowak changed the title Optimized rope_rotation_llama with vectorization Optimized rope_rotation_llama and apply temperature to logits with vectorization Oct 23, 2023
@tairov
Copy link
Owner

tairov commented Oct 23, 2023

Hi @andresnowak thanks for your PR. I'll review it shortly..
Did you get any performance improvement when switched to vectorized rope rotation?

@andresnowak
Copy link
Author

Not really, at least when testing with the 15m and 110m models i didn't see a difference or the difference was to little, probably the matmul operation just takes the majority of the time and you can't see to much of a difference optimizing other operations (but i was checking if it would be possible to tile the matmul operation), but i think in bigger models there should be somewhat of a difference. But i can't also test to well for improvements, because each run gives very different toks/s, i sometimes get 330 and sometimes 410 in a ryzen 3600x.

@tairov
Copy link
Owner

tairov commented Oct 23, 2023

Approach you implemented looks interesting.
I did some benchmark, this change doesn't give any performance boost for models up to 110M. However, on bigger models it can show some difference.

image

I'll try to benchmark it on tinyllama.

@mikowals
Copy link
Contributor

I vectorized ROPE as you have a while back but the networks were much too small to see any impact. I also did micro benchmarks focused on the ROPE function and even head size 25000 (512 * 512) had no improvement. That is already far bigger than any LLM uses. At bigger sizes it did get faster. I am just not sure it is relevant.

But it wasn't slower as far as I could tell. My tests were on 4 core 32 GB ram Ubuntu and also the playground which is usually 32 cores.

@andresnowak
Copy link
Author

andresnowak commented Oct 23, 2023

hmm interesting, but shouldn't vectorized be faster with a 1000 values already, or is strided load and store not as efficient because it has to access to separated parts in memory?, or is the speedup just not noticeable until you get to bigger iteration sizes

@mikowals
Copy link
Contributor

mikowals commented Nov 4, 2023

I think this will get a performance improvement by removing the parallelize call of the loop over heads in ROPE and replacing with a simple for loop. Setting up threads with parallelize is time consuming and I think removing it is a net performance gain. vectorize on its own seems slightly faster than master branch and the out performance probably increases if larger networks get added.

Also there is a version of tile that steps down through nelts values that might be cleaner (and ever so slightly faster) than 1 value at a time when the head_size isn't a multiple of nelts. Like this:

alias nelts_alternatives = VariadicList[Int](nelts, 16, 8, 4, 2, 1)
tile[calc_head, nelts_alternatives](0, head_size // 2) # calc_head needs to multiply iterator by 2

@andresnowak
Copy link
Author

andresnowak commented Nov 4, 2023

Regarding the parallelize part, Why would it be faster to remove it, isn't parallelize supposedly using a cached runtime?, so it shouldn't be creating threads each time it is called no?. Or am I not understanding how the implementation works?

@andresnowak
Copy link
Author

Doing some tests in a 3600x @mikowals, i only saw a difference in the 15m model when comparing rope with parallelization and without it, for 110m and tiny_llama i didn't see a difference. And doing some isolated benchmarking for the rope function i saw that the version without parallelization was slower than the parallelized one when an amount of values was that was 5000 or more (ex. head_size = 1000, n_heads = 6), but with less values, the version without parallelization was faster, so if in general we always have less than 5000 values, then we could remove the parallelization

@mikowals
Copy link
Contributor

mikowals commented Nov 5, 2023

On M1 Pro removing parallelize is better. Not a huge difference in the whole network but clearly better. In isolated benchmarks at the size of the baby llama models parallelizing is about 10x slower than a vanilla for-loop. At huge sizes parallelizing is faster.

In these graphs V1 is current master and V2 is with the two line change to remove parallelize and replace with a for-loop.
Screenshot 2023-11-05 at 4 55 44鈥痯m
Screenshot 2023-11-05 at 4 57 03鈥痯m
Screenshot 2023-11-05 at 4 56 42鈥痯m
Screenshot 2023-11-05 at 4 56 11鈥痯m

@andresnowak
Copy link
Author

andresnowak commented Nov 5, 2023

Maybe i didn't understand correctly, in these benchmarks you are comparing the original implementation and the rope implementation with simd and without parallelization no?, what i was comparing was rope with simd and parallelization and rope with simd and no parallelization. But yeah as you say if the sizes are always in the range of the size of baby llama, then it can be better to remove parallelization. But one last thing, when you where doing the benchmarks did you compare parallelization using all cores on the m1 pro or only the performance cores, because i think using all cores in the m1 can be slower than using only the performance cores, because we are dividing the work exactly for the amount of cores in the machine, and from what i understand the efficiency cores are a lot slower than the performance cores

@mikowals
Copy link
Contributor

mikowals commented Nov 6, 2023

I put the comparisons I did in this branch. The graphs above are done where "V1 is current master and V2 is with the two line change to remove parallelize and replace with a for-loop." I used lamatune and set V1 to current master and V2 to the branch I just linked to.

The isolated benchmarks I ran are in this file. I run all tests on my system with -j 6 because from what I can tell that is optimal (fastest) on my system. It may be that different worker numbers indicate different trade offs between ROPE algorithms but I am mostly interested in trying to improve the global optimum. It may well be that my results only hold on MacOS and Mac M series chips.

My comments on this PR are really just that there seems to be easy performance improvement from removing parallelize.

@andresnowak
Copy link
Author

@mikowals Doing the lamatune and doing your isolated benchmarks I have found the same conclusion as you, simple vanilla for loops are faster than parallelize, vectorize and vectorize parallelize, for all the available models. In the benchmarks (v1 is current master, v2 is your implementation removing the parallelization, and v3 is my implementation with parallelize and vectorize).
image.

And for the isolated benchmarks, vanilla for loops was faster than all the other implementations for all the models up to tinyllama, in tinyllama it was the same speed as vectorize_parallelize and parallelize, and for the imaginary networks the fastest one was parallelize. The only thing is that i don't understand why is vectorization slower than vanilla for loops, i don't know if it has to do with cache misses or that maybe simd_strided is a slower instruction, or i just don't know. But yeah, it seems it would be better to remove parallelization for rope_llama if we don't get to the sizes of the imaginary network, or maybe we can add an if inside the function to run rope_llama with parallelization or not depending on the size

stories15M size dims-heads-size: 288 - 6 - 48
RoPE vectorize_parallelize time: 0.0094327443032383718 ms
RoPE parallelize (current) time: 0.0094180861510420478 ms
RoPE vectorize time: 0.0003703772909307732 ms
RoPE vanilla for loops time: 0.00027608576648576648 ms

stories110M size dims-heads-size: 768 - 12 - 64
RoPE vectorize_parallelize time: 0.0094208203660253401 ms
RoPE parallelize (current) time: 0.0092873126632242193 ms
RoPE vectorize time: 0.0018691157337200571 ms
RoPE vanilla for loops time: 0.00052221472060768118 ms

TinyLlama-1B size dims-heads-size: 2048 - 32 - 64
RoPE vectorize_parallelize time: 0.0093080902034136086 ms
RoPE parallelize (current) time: 0.0092769410737927899 ms
RoPE vectorize time: 0.0027147553415534154 ms
RoPE vanilla for loops time: 0.00090131258923845407 ms

imaginary huge network size dims-heads-size: 262144 - 128 - 512
RoPE vectorize_parallelize time: 0.022488891769826116 ms
RoPE parallelize (current) time: 0.017436351394991213 ms
RoPE vectorize time: 0.097005635083138655 ms
RoPE vanilla for loops time: 0.041574515086885068 ms

@tairov
Copy link
Owner

tairov commented Nov 6, 2023

Probably for simple loops Mojo already has some levels of optimizations by default. Similarly as gcc optimizes for loops.

@andresnowak
Copy link
Author

Hmm, but vectorize also uses for loops no?, so it should have the same optimizations no?, or is it not applying the same optimizations

@tairov
Copy link
Owner

tairov commented Nov 6, 2023

That's hard to say. Possibly we can manually tune the source code to make it "more optimal" than standard optimizations. But if there is no manually optimized loops, the compiler could attempt to do so.
As it was shown with gcc , not all compiler optimizations speed up the code; in some cases they can even slow it down. But in order to dig further into this issue, we may need to dive into the depths, which is not permitted under the Mojo license.
On the source code level we can try to leverage autotune, and some other Mojo features to discover other performance wins.

@andresnowak
Copy link
Author

hhmmm, okay, I understand.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants