enable yuan autotp & add conv tp #5428

Open · wants to merge 7 commits into master

Conversation

Yejing-Lai
Contributor

@Yejing-Lai Yejing-Lai commented Apr 17, 2024

This PR aims to enable Yuan model AutoTP and add conv TP.

The Yuan model uses shared QK.
For example:
q_linear_out = [q1, q2, q3, q4, q5, ... , q16]
k_linear_out = [k1, k2, k3, k4, k5, ... , k16]

After sharing QK:
TP=1:
q' = [q1,q2,q3,q4, q9,q10,q11,q12, k1,k2,k3,k4, k9,k10,k11,k12]
k' = [q5,q6,q7,q8, q13,q14,q15,q16, k5,k6,k7,k8, k13,k14,k15,k16]
v' = [v1,v2,v3,v4, v5,v6,v7,v8, v9,v10,v11,v12, v13,v14,v15,v16]

TP=2:
rank0:
q'_0 = [q1,q2,q3,q4, k1,k2,k3,k4]
k'_0 = [q5,q6,q7,q8, k5,k6,k7,k8]
v'_0 = [v1,v2,v3,v4, v5,v6,v7,v8] -> v'_0 is wrong! The expected value is: [v1,v2,v3,v4, v9,v10,v11,v12]
rank1:
q'_1 = [q9,q10,q11,q12, k9,k10,k11,k12]
k'_1 = [q13,q14,q15,q16, k13,k14,k15,k16]
v'_1 = [v9,v10,v11,v12, v13,v14,v15,v16] -> v'_1 is wrong! The expected value is: [v5,v6,v7,v8, v13,v14,v15,v16]

To avoid modifying the modeling code, we adjust the value and o_proj weights to fit this shared-QK layout (a sketch of the idea is shown below).
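The following is a minimal, hypothetical sketch of that re-sharding idea, not the PR's actual implementation: the helper name, the "unit" bookkeeping, and the assumed row layout of the value weight are all illustrative choices made to mirror the 16-unit example above.

```python
import torch

def shard_value_for_shared_qk(v_weight: torch.Tensor, unit_dim: int, rank: int, world_size: int):
    # Hypothetical helper: select the value "units" (v1..v16 in the example above)
    # that correspond to this rank's shared-QK slices, instead of taking one
    # contiguous chunk of rows. Assumes v_weight rows are laid out as
    # [unit_0, unit_1, ..., unit_{n-1}] blocks of size unit_dim.
    total_units = v_weight.shape[0] // unit_dim
    half = total_units // 2                  # e.g. 8 when there are 16 units
    per_rank = half // world_size            # e.g. 4 when world_size == 2

    # Rank r takes its block from the first half and the matching block from the
    # second half, so rank0 -> v1..v4 and v9..v12, rank1 -> v5..v8 and v13..v16.
    first = range(rank * per_rank, (rank + 1) * per_rank)
    second = range(half + rank * per_rank, half + (rank + 1) * per_rank)

    rows = [v_weight[u * unit_dim:(u + 1) * unit_dim] for u in list(first) + list(second)]
    return torch.cat(rows, dim=0)
```

With 16 units and world_size=2, this reproduces the expected TP=2 layout shown above instead of the naive contiguous split.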

We also added conv TP to support models that include heavy convolution computation. It is similar to the linear TP policy (a sketch follows the list below).
if not last_conv_layer:

    1. Divide the conv weight across ranks along the output channel dimension.
    2. Apply conv2d.

else:

    1. Divide the conv weight across ranks along the input channel dimension.
    2. Apply conv2d.
    3. Use allreduce to sum the outputs.
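As a rough illustration (not the PR's actual code), the two modes might look like the sketch below in plain PyTorch; the function names, and the assumptions that groups == 1 and that a process group is already initialized, are mine.

```python
import torch
import torch.nn.functional as F
import torch.distributed as dist

def shard_conv2d_weights(conv: torch.nn.Conv2d, rank: int, world_size: int, last_conv_layer: bool):
    # Hypothetical sketch. Earlier conv layers are split along output channels;
    # the last conv layer is split along input channels (its input is already
    # channel-sharded by the previous layer). Assumes conv.groups == 1.
    w = conv.weight.data
    b = conv.bias.data if conv.bias is not None else None
    if not last_conv_layer:
        w = torch.chunk(w, world_size, dim=0)[rank]   # [OC/ws, IC, kH, kW]
        if b is not None:
            b = torch.chunk(b, world_size, dim=0)[rank]
    else:
        w = torch.chunk(w, world_size, dim=1)[rank]   # [OC, IC/ws, kH, kW]
        if b is not None and rank != 0:
            b = torch.zeros_like(b)                   # add the bias only once after the allreduce
    return w, b

def conv2d_tp_forward(x, w, b, conv: torch.nn.Conv2d, last_conv_layer: bool):
    # Each rank convolves with its shard; only the input-channel-sharded
    # (last) layer needs an allreduce to sum the partial outputs.
    out = F.conv2d(x, w, b, stride=conv.stride, padding=conv.padding, dilation=conv.dilation)
    if last_conv_layer:
        dist.all_reduce(out)
    return out
```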

@delock
Contributor

delock commented Apr 17, 2024

@@ -123,3 +123,54 @@ def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None, module=None):
warning_once(f"Unrecognized fusedkqv weight type, default to using bloom type,"
f"please check in prepare_tp_fused_qkvw() to avoid potential calculation errors")
return _bloom_type_transpose(src, mp_size)


def shard_value_with_share_qk(
Contributor

Comments are needed here (with an example) to help understand the functionality of shard_value_with_share_qk().

self.shard_by_oc = shard_by_oc
self.shard_weights(conv)

def shard_weights(self, conv):
Contributor

There should be some comments here to explain the sharding scheme, ideally with a simple example to aid understanding.

@@ -350,6 +372,9 @@ def set_lm_head(module):
pbar.update(1)
gc.collect()
replaced_module = set_lm_head(replaced_module)
# conv2d tp module replace
if 'Yuan' in str(replaced_module):
Contributor

Does this mean we apply conv sharding only to models where we know there is a conv layer?

Contributor Author

Yes. I added a comment to help explain this situation~

@delock
Contributor

delock commented Apr 23, 2024

Hi @tjruwase, we got a request to support Yuan model AutoTP (https://huggingface.co/IEITYuan/Yuan2-102B-hf). This model has a special QKV format and also has convolution layers, both of which need special treatment in tensor parallelism. This PR addresses both model features and supports them inside DeepSpeed AutoTP. Can this PR be reviewed? Thanks!

@loadams
Contributor

loadams commented May 15, 2024

Hi @delock - FYI could you resolve the merge conflicts on this PR so it can be reviewed/tests run?

@Yejing-Lai
Contributor Author

> Hi @delock - FYI could you resolve the merge conflicts on this PR so it can be reviewed/tests run?

Hi @loadams. The conflicts have been resolved. Please review~

@loadams
Contributor

loadams commented May 22, 2024

> Hi @delock - FYI could you resolve the merge conflicts on this PR so it can be reviewed/tests run?

> Hi @loadams. The conflicts have been resolved. Please review~

This looks fine to me, anything else you want to review @delock ?

@delock
Contributor

delock commented May 22, 2024

> Hi @delock - FYI could you resolve the merge conflicts on this PR so it can be reviewed/tests run?

> Hi @loadams. The conflicts have been resolved. Please review~

> This looks fine to me, anything else you want to review @delock ?

@loadams looks fine to me, thanks!
