Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support automatically calculate max_total_token_num #81

Open
wants to merge 5 commits into
base: main
Choose a base branch
from

Conversation

singularity-s0
Copy link
Contributor

In ApiServerArgs.md, an algorithm was introduced to calculate the optimal max_total_token_num argument. This process can be automated, and this PR introduces this feature.

The max_total_token_num argument now defaults to None. If not set, the API server will automatically calculate the optimal setting according to total GPU RAM and model size. A ratio of 0.8 will also be applied to ensure enough memory is reserved for inference.

Docs have also been updated.

@XHPlus
Copy link
Contributor

XHPlus commented Aug 17, 2023

Thanks for your great PR! We are refactoring part of our code and will merge your PR as soon as the refactored version is ready. Besides, hope to add a WeChat friend with you. (hao95111)

@hiworldwzj
Copy link
Collaborator

@singularity-s0 Hello, Can this feature be modified to support all models? Because different models may have different calculation methods(GQA model is different), should the implementation of this feature be bound to each individual model instance?

@singularity-s0
Copy link
Contributor Author

singularity-s0 commented Aug 21, 2023

Hi,

I'm not entirely sure how GQA or other implementations affect the use of GPU memory, could you please elaborate?

Generally, the formula is max_total_token_num = (total_free_gpu_memory - model_parameter_size) * 0.8 / kv_cache_size according to the docs.

  • total_free_gpu_memory is read using PyTorch CUDA API. This should be the ideal implementation.
  • model_parameter_size is estimated from the size of weight files on disk. This should mostly be accurate, unless some kind of compression is used, which I'm unaware of.
  • kv_cache_size should be dependent on model. If config.json provide enough information to calculate this value for each model, then model-specific implementations are not required. However I'm not sure if this is always the case (maybe GQA somehow affects this?)
  • Some implementations may require additional memory (maybe GQA?). Either config.json tell us enough information or we need model-specific implementations.

@hiworldwzj
Copy link
Collaborator

@singularity-s0 kv_cache_size is more different in the model that use GQA. "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints"

@singularity-s0
Copy link
Contributor Author

From my understanding of the paper mentioned above, GQA reduces the kv_cache_size by num_attention_heads / num_key_value_heads times. These values are available from config.json so the value of kv_cache_size can always be calculated.

The new formula will be
max_total_token_num = (total_free_gpu_memory - model_parameter_size) * 0.8 / original_kv_cache_size * num_attention_heads / num_key_value_heads

For models that do not use GQA, simply default num_key_value_heads to num_attention_heads. All current models would be supported this way.

Is my understanding correct?

@hiworldwzj
Copy link
Collaborator

@singularity-s0 Yes, you are right.

@singularity-s0
Copy link
Contributor Author

This PR has been updated with changes to how kv_cache_size is calculated. Please review.

with open(config_path, 'r') as f:
config = json.load(f)
hidden_size = config['hidden_size']
layer_num = config['num_hidden_layers']
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@singularity-s0 This code may not be very robust when the key name in config.json changes.

total_size = total_size / (1024 ** 3)
return total_size

def get_kv_cache_size(model_dir):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"get_kv_cache_size and xxxx" is best implemented as a member function of TpPartBaseModel and should be inherited and implemented by its subclasses.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that max_total_token_num (and batch_max_tokens that depends on it) gets passed to a lot of subsystems before the model is initialized. We need this value to be ready early.

Is there any way to achieve this if implemented as a member function of TpPartBaseModel?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@singularity-s0 You can try to add a method in TpPartBaseModel, but it is not easy to get and set batch_max_tokens in TpPartBaseModel. Let me think about how to implement it elegantly. What are your suggestions?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally, since each instance of LightLLM server is bound to only one model, model configuration can (and should) be loaded before all other subsystems are initialized (because other subsystems may depend on model configuration, as in the case of max_total_token_num). A refactor would be the most elegant way to address this.

Other parameters like max_req_total_len and dtype (which is currently hardcoded to fp16) might also be dependent on model config.json and would benefit from this refactor.

However I imagine such a refactor would not be easy. Hacky solutions are also available but it is ultimately up to you to decide which way is the best.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@singularity-s0 You can write a standalone recommendation program to generate a value for max_total_token_num. that will be more appropriate。

@hiworldwzj hiworldwzj self-requested a review December 4, 2023 06:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants