
Add fp16 support for Qwen1.5MoE models (A2.7B) to DeepSpeed-FastGen #5403

Open · wants to merge 1 commit into master
Conversation

@ZonePG (Contributor) commented Apr 12, 2024

This PR adds support for Qwen1.5MoE-A2.7B models.

It adds the support requested in microsoft/DeepSpeed-MII#457.

Test Code

For the MII pipeline:

import mii

pipe = mii.pipeline("/data/zonepg/models/Qwen/Qwen1.5-MoE-A2.7B")
responses = pipe("DeepSpeed is", max_new_tokens=128, do_sample=False)
if pipe.is_rank_0:  # print once; avoids duplicate output on multi-GPU runs
    print(responses[0])
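The same script can also be run sharded across multiple GPUs (as in the 8-way output below). Based on MII's documented multi-GPU usage, it would be launched with the deepspeed launcher, for example: deepspeed --num_gpus 8 pipeline_test.py (the script name here is illustrative).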

For Hugging Face:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("/data/zonepg/models/Qwen/Qwen1.5-MoE-A2.7B")
model = AutoModelForCausalLM.from_pretrained(
    "/data/zonepg/models/Qwen/Qwen1.5-MoE-A2.7B",
    device_map="auto",
    torch_dtype=torch.float16,
    trust_remote_code=True,
).eval()
print(model)

inputs = tokenizer('DeepSpeed is', return_tensors='pt')
inputs = inputs.to(model.device)
# Greedy decoding, so the output is directly comparable with the FastGen runs.
pred = model.generate(**inputs, max_new_tokens=128, do_sample=False, repetition_penalty=1.0)
text = tokenizer.decode(pred.cpu()[0], skip_special_tokens=False)
print(text)

Qwen1.5-MoE-A2.7B

Hugging Face output with prompt "DeepSpeed is":

 a deep learning framework that is designed to accelerate the training of large-scale neural networks. It is built on top of PyTorch and provides a set of tools and techniques for optimizing the performance of deep learning models.

DeepSpeed supports a variety of hardware accelerators, including GPUs, TPUs, and FPGAs, and can be used to train models on distributed systems, such as clusters of GPUs or TPUs.

One of the key features of DeepSpeed is its ability to automatically parallelize the training of deep learning models across multiple GPUs or TPUs. This can significantly reduce the time required to train large models, as it allows the

DeepSpeed-FastGen output with prompt "DeepSpeed is":

 a deep learning framework that is designed to accelerate the training of large-scale neural networks. It is built on top of PyTorch and provides a set of tools and techniques for optimizing the performance of deep learning models.

DeepSpeed supports a variety of hardware accelerators, including GPUs, TPUs, and FPGAs, and can be used to train models on distributed systems, such as clusters of GPUs or TPUs.

One of the key features of DeepSpeed is its ability to automatically parallelize the training of deep learning models across multiple GPUs or TPUs. This can significantly reduce the time required to train large models, as it allows the

DeepSpeed-FastGen output with prompt "DeepSpeed is" with 8-way sharding:

 a deep learning framework that is designed to accelerate the training of large-scale neural networks. It is built on top of PyTorch and provides a set of tools and techniques for optimizing the performance of deep learning models.

DeepSpeed supports a variety of hardware accelerators, including GPUs, TPUs, and FPGAs, and can be used to train models on distributed systems, such as clusters of GPUs or TPUs.

One of the key features of DeepSpeed is its ability to automatically parallelize the training of deep learning models across multiple GPUs or TPUs. This can significantly reduce the time required to train large models, as it allows the
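All three outputs are identical: the FastGen runs (single-GPU and 8-way sharded) match the Hugging Face reference token for token.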

Review thread on the shared-expert gating code:

shared_expert_output = self.shared_expert_mlp_2(shared_expert_output, cur_params.shared_moe_mlp_2, b=None)
shared_expert_gate_output = self.shared_expert_gate(hidden_states, cur_params.shared_moe_gate, b=None)[..., :1]
# shared_expert_gate_output shape[-1] is 1
shared_expert_output.mul_(torch.sigmoid(shared_expert_gate_output))
@ZonePG (Author) commented Apr 12, 2024
I am not sure whether using torch.sigmoid directly here will affect performance.
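For reference, the gating in the quoted diff can be written as a standalone sketch. This is a minimal illustration of the shared-expert math only; the sizes and the plain SiLU MLP are assumptions standing in for the actual DeepSpeed-FastGen kernels:

import torch

torch.manual_seed(0)
tokens, hidden = 4, 64  # illustrative sizes, not the real model dims
x = torch.randn(tokens, hidden)

# Stand-in for the shared expert MLP (assumed structure, not the real kernels).
shared_expert = torch.nn.Sequential(
    torch.nn.Linear(hidden, 4 * hidden),
    torch.nn.SiLU(),
    torch.nn.Linear(4 * hidden, hidden),
)
# The gate yields a single logit per token (the diff slices with [..., :1] to get this).
shared_expert_gate = torch.nn.Linear(hidden, 1, bias=False)

shared_out = shared_expert(x)                # (tokens, hidden)
gate = torch.sigmoid(shared_expert_gate(x))  # (tokens, 1): one scalar gate per token
shared_out = shared_out * gate               # broadcasts across the hidden dimension
# In the full MoE block, shared_out is added to the routed experts' output.

The sigmoid runs over a (tokens, 1) tensor, which suggests its elementwise cost is small relative to the expert matmuls.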
