[one-optimize] Optimize part of the transformer's attention-head #12917

Open · BalyshevArtem opened this issue Apr 24, 2024 · 12 comments

BalyshevArtem commented Apr 24, 2024

What

Let's introduce two new optimization passes to simplify and accelerate part of the transformer's attention head.
Originally it has the following pattern, which we can optimize:

[Screenshot: original attention-head pattern before optimization]

1. First, we can fuse the pattern

StridedSlice ---------------\
                             Concatenation
StridedSlice --- Neg -------/

into a Mul operation whose constant consists of 1s, with -1 wherever the Neg operation was.

As a result we will have:
[Screenshot: pattern after fusing StridedSlice/Neg/Concatenation into Mul]

2. Then we can fuse the Muls into the FullyConnected nodes (twice) and get:

[Screenshot: pattern after fusing the Muls into the FullyConnected nodes]

3. Finally, we can fuse the horizontal FC layers and get a single FC node (a numeric sketch of all three steps follows below):

[Screenshot: pattern after horizontal FC fusion, a single FC node]
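
For reference, here is a minimal numpy sketch of the algebra behind the three steps. The shapes, tensor names, and FC bias are illustrative assumptions (not taken from the actual model), and step 1 assumes the slices are concatenated back in their original order; the swapped-order case is discussed further down in this thread.

```python
import numpy as np

np.random.seed(0)

# Hypothetical shapes: a batch of 2 tokens with hidden size 80.
x = np.random.randn(2, 80).astype(np.float32)

# Step 1: StridedSlice / Neg / Concatenation -> Mul with a +/-1 constant,
# assuming the two slices are concatenated back in their original order.
sliced = np.concatenate([x[:, :40], -x[:, 40:]], axis=1)
sign_const = np.concatenate([np.ones(40), -np.ones(40)]).astype(np.float32)
assert np.allclose(sliced, x * sign_const)

# Step 2: Mul(FC(x, W, b), c) -> FC(x, W * c, b * c) for a per-output-channel
# constant c (TFLite-style weight layout [out_units, in_units] assumed).
W = np.random.randn(40, 80).astype(np.float32)
b = np.random.randn(40).astype(np.float32)
c = np.random.randn(40).astype(np.float32)
fc_then_mul = (x @ W.T + b) * c
fused_fc = x @ (W * c[:, None]).T + b * c
assert np.allclose(fc_then_mul, fused_fc, atol=1e-5)

# Step 3: horizontal fusion -- two FCs that share the same input become one FC
# whose weights and biases are concatenated along the output axis.
W2 = np.random.randn(40, 80).astype(np.float32)
b2 = np.random.randn(40).astype(np.float32)
two_fcs = np.concatenate([x @ W.T + b, x @ W2.T + b2], axis=1)
one_fc = x @ np.concatenate([W, W2], axis=0).T + np.concatenate([b, b2])
assert np.allclose(two_fcs, one_fc, atol=1e-5)
```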

Why

To speed up and simplify attention-based models.

How

  • Introduce a pass that fuses the StridedSlice/Neg/Concatenation pattern into a Mul.
  • Introduce a pass that fuses a Mul into a FullyConnected node.
@seanshpark

@BalyshevArtem, this is awesome!
I've resized the images a bit for better readability :)

@periannath

Would you let me know which model you used?

In the model I used, only one FullyConnected layer was created in the corresponding part, so it seems that the structure varies slightly depending on the model.


BalyshevArtem commented Apr 25, 2024

> Would you let me know which model you used?

I used a model generated in one of the internal repos, a modified Llama2 (split head).
It is the decoder part.

@jinevening

@BalyshevArtem Thanks for the good idea :) As @periannath mentioned, the original pattern seems to have duplicate FCs, i.e., the two FCs are in fact the same. So the baseline would be the pattern with a single FC layer.

For the second fusion, the second Mul is for applying rotary embedding, which would be a user input (not a constant) if the model supports dynamic behavior.

If a model only supports fixed positions (all input tokens' positions are fixed, which means the number of previously cached tokens is also fixed), this would be an effective optimization.
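
For reference, here is a rough numpy sketch of why the rotary Mul operand is a compile-time constant only when positions are fixed. It uses the textbook RoPE formulation with made-up sizes, not anything taken from the model in question.

```python
import numpy as np

def rope_cos_sin(position_ids, head_dim, base=10000.0):
    # Standard RoPE cos/sin tables for the given positions.
    inv_freq = 1.0 / (base ** (np.arange(0, head_dim, 2) / head_dim))
    angles = np.outer(position_ids, inv_freq)            # [seq, head_dim / 2]
    angles = np.concatenate([angles, angles], axis=-1)   # [seq, head_dim]
    return np.cos(angles), np.sin(angles)

# Fixed positions (e.g. always tokens 0..3): the cos/sin tables can be folded
# into constants at compile time, so the Mul has a constant operand.
cos_fixed, sin_fixed = rope_cos_sin(np.arange(4), head_dim=80)

# Dynamic positions (offset by the number of previously cached tokens): the
# tables change per inference, so the Mul operand is a runtime input.
cos_dyn, sin_dyn = rope_cos_sin(np.arange(4) + 7, head_dim=80)
assert not np.allclose(cos_fixed, cos_dyn)
```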

@jinevening

> Introduce a pass that fuses the StridedSlice/Neg/Concatenation pattern into a Mul.

This fusion looks good to me. One minor concern is that this will reduce the operator count but create a new constant tensor. We have to be careful not to increase the model size too much.

@BalyshevArtem

> This fusion looks good to me. One minor concern is that this will reduce the operator count but create a new constant tensor. We have to be careful not to increase the model size too much.

We can fuse this pattern only if we can then fuse the Mul-with-const into the FullyConnected operation located above it. As I understand it, whether a dynamic or static rotary embedding is used does not affect this fusion, since there will always be a FullyConnected layer in front of this pattern, right?

> the original pattern seems to have duplicate FCs, i.e., the two FCs are in fact the same. So the baseline would be the pattern with a single FC layer.

I'm not sure I got it right :) In the example that I used:

[Screenshot: original attention-head pattern (same as above)]

these two FC operations have different constants.
Or do you mean another pattern?


jinevening commented Apr 29, 2024

> whether a dynamic or static rotary embedding is used does not affect this fusion, since there will always be a FullyConnected layer in front of this pattern, right?

Yes :)

> these two FC operations have different constants.

Ah, your model seems to be the one whose attention heads are split. I was thinking about the pattern without head split. Below is the original pattern of rotary embedding where the heads are not split.

[Image: rotary embedding pattern without head split]

After the heads are split, it seems that a new FC is created when the FC is fused with the Mul (the left Mul in the above graph).

I think that kind of fusion should be applied carefully (or suppressed), because it will increase the model size quite a lot (model size is a performance bottleneck as of now). Thanks for finding this.

@jinevening

@BalyshevArtem Could you share any preliminary result after this optimization, e.g., impacts on cycles/traffic? If there is some sensitive information, please use our internal repo.

@BalyshevArtem

> @BalyshevArtem Could you share any preliminary result after this optimization, e.g., impacts on cycles/traffic? If there is some sensitive information, please use our internal repo.

Sure, I will post the results in the internal repo :)

> Below is the original pattern of rotary embedding where the heads are not split.

In this example, we can also apply some optimizations:

  1. Fuse the StridedSlice/Neg/Concatenation pattern into a Mul operation whose constant consists of 1s, with -1 wherever the Neg operation was.
  2. Fuse the two Muls with constant values.
  3. Fuse the pattern Add(Mul(Input, Const1), Mul(Input, Const2)) into Mul(Input, Const3), where Const3 = Const1 + Const2 (see the sketch below).
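
For items 2 and 3, a minimal numpy check of the underlying algebra (assuming plain elementwise Muls with broadcastable constants):

```python
import numpy as np

np.random.seed(0)
x = np.random.randn(2, 80).astype(np.float32)
c1 = np.random.randn(80).astype(np.float32)
c2 = np.random.randn(80).astype(np.float32)

# Item 2: two consecutive Muls with constants collapse into a single Mul.
assert np.allclose((x * c1) * c2, x * (c1 * c2))

# Item 3: Add(Mul(x, C1), Mul(x, C2)) == Mul(x, C1 + C2).
assert np.allclose(x * c1 + x * c2, x * (c1 + c2))
```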

@jinevening

It seems that the first fusion is invalid. Please check the begin/end of StridedSlice.

StridedSlice A(begin:0, end:40) --------------\
                                                Concatenation (B + A)
StridedSlice B(begin:40, end:80) --- Neg ------/

The order of the two sliced tensors is changed, so it is impossible to convert the pattern to a simple Mul.

@BalyshevArtem

> It seems that the first fusion is invalid. Please check the begin/end of StridedSlice.
>
> StridedSlice A(begin:0, end:40) --------------\
>                                                 Concatenation (B + A)
> StridedSlice B(begin:40, end:80) --- Neg ------/
>
> The order of the two sliced tensors is changed, so it is impossible to convert the pattern to a simple Mul.

Yes, you're right, thank you! Indeed, the tensor is split in half and the halves are swapped.

Such a pattern can still be optimized, but it gets more complicated. Let's expand the pattern in question by adding the FullyConnected:

        ---- Weight_Const
       |
FullyConnected ----> StridedSlice A(begin:0, end:40) --------------\
              \                                                      Concatenation (B + A)
               ----> StridedSlice B(begin:40, end:80) --- Neg ------/

So the idea is to first split and rotate the weights in the same way that StridedSlices->Concatenation does. In the example from #12917 (comment), we need to change the weights of the FullyConnected (with shape 80 x 240): split them into two parts by rows, a 40 x 240 first_part and a 40 x 240 second_part, and reverse their order, so that second_part comes first and first_part comes second. After that, introduce a Mul with negative values (for the first part), then fuse it into the FC, and so on (as in #12917 (comment)).
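
A minimal numpy sketch of that equivalence, assuming the TFLite weight layout [out_units, in_units] and adding a bias for completeness (the bias handling is an assumption, not from the comment above):

```python
import numpy as np

np.random.seed(0)

# Shapes from the comment: FC weights 80 x 240, i.e. 240 input features
# and 80 output units.
x = np.random.randn(2, 240).astype(np.float32)
W = np.random.randn(80, 240).astype(np.float32)
b = np.random.randn(80).astype(np.float32)   # bias is an assumption

# Original subgraph: FC, then slice the output in half, negate the second
# half, and concatenate the halves in swapped order (Concat(Neg(B), A)).
y = x @ W.T + b
original = np.concatenate([-y[:, 40:], y[:, :40]], axis=1)

# Rewritten subgraph: swap the two 40 x 240 row blocks of the weights (and
# bias), negate the block that now comes first, and run a single FC.
W_rot = np.concatenate([-W[40:], W[:40]], axis=0)
b_rot = np.concatenate([-b[40:], b[:40]])
rewritten = x @ W_rot.T + b_rot

assert np.allclose(original, rewritten, atol=1e-5)
```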

It turns out to be a highly specialized optimization pattern, but at the same time it allows us to greatly reduce unnecessary calculations and even reduce the binary size, thanks to fusing constants and weights.
@jinevening,
The question is: does this pattern occur in our target models? If you find it helpful to implement such an optimization, I will do it, but if you think this pattern is too rare to be useful to us, then it is better to postpone this task. What do you think? :)

@jinevening

@BalyshevArtem I've answered the question in the internal repo.
