
IPEX LLM serving example #3068

Merged · 19 commits · May 16, 2024
Conversation

@bbhattar (Contributor) commented Apr 3, 2024

This PR adds an example for deploying text generation with Large Language Models (LLMs) using IPEX. It can use (1) IPEX weight-only quantization to convert the model to INT8 precision, (2) IPEX SmoothQuant quantization, or (3) the default bfloat16 optimization.

Files:

README.md
llm_handler.py - custom handler for quantizing and deploying the model
model-config-llama2-7b-bf16.yaml - config file for bfloat16 optimizations
model-config-llama2-7b-int8-sq.yaml - config file for smooth-quant quantization
model-config-llama2-7b-int8-woq.yaml - config file for weight-only quantization
sample_text_0.txt - A sample prompt you can use to test the text generation model.

Type of change

  • New feature (non-breaking change which adds functionality)

Checklist:

  • Did you have fun?
  • Have you added tests that prove your fix is effective or that this feature works?
  • Has code been commented, particularly in hard-to-understand areas?
  • Have you made corresponding changes to the documentation?

@lxning (Collaborator) left a comment:

Thanks for the contribution. Could you please add a test for this example? See the pytest example: https://github.com/pytorch/serve/blob/master/test/pytest/test_example_gpt_fast.py

examples/large_models/ipex_llm_int8/README.md (resolved review comments)
self_jit = torch.jit.trace(converted_model.eval(), example_inputs, strict=False, check_trace=False)
self_jit = torch.jit.freeze(self_jit.eval())

self_jit.save(self.quantized_model_path)

Collaborator:

Is it possible to add logic to check whether the quantized_model_path exists? If it exists, skip this step to reduce model loading latency.

Contributor Author (@bbhattar):

Added the logic. Users can choose to re-quantize through the clear_cache_dir flag in the config.
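
For reference, a minimal sketch of that caching logic, assuming the handler keeps clear_cache_dir and quantized_model_path as attributes as in the quoted snippet (the structure here is illustrative, not the exact code in the PR):

```python
import os

import torch

# Illustrative: only re-trace and re-quantize when the user asks for it or
# when no cached TorchScript artifact exists at the configured path.
if self.clear_cache_dir or not os.path.exists(self.quantized_model_path):
    self_jit = torch.jit.trace(
        converted_model.eval(), example_inputs, strict=False, check_trace=False
    )
    self_jit = torch.jit.freeze(self_jit.eval())
    self_jit.save(self.quantized_model_path)
else:
    # Reuse the previously saved model to reduce loading latency.
    self_jit = torch.jit.load(self.quantized_model_path)
```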

Comment on lines 566 to 567:

for i, x in enumerate(outputs):
    inferences.append(self.tokenizer.decode(outputs[i], skip_special_tokens=True))

Collaborator:

Usually batch_decode is faster.

Contributor Author (@bbhattar):

Changed to batch_decode.
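
For reference, the per-item loop from the quoted snippet collapses to a single call with the Hugging Face tokenizer API; a minimal sketch:

```python
# Decode all generated sequences in one call instead of looping.
inferences = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
```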

@lxning requested a review from @mreso on May 8, 2024 04:21

@mreso (Collaborator) left a comment:

Left some comments; please verify that the smooth quant branch is actually working.

torchserve --ncs --start --model-store model_store
```

4. From the client, set up batching parameters. I couldn't make it work by putting the max_batch_size and max_batch_delay in config.properties.

Collaborator:

Let's figure this out and update the README before merging.

Contributor Author (@bbhattar):

Fixed it in the new commit.
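
One common way to do this, shown here only as an illustrative sketch and not necessarily what the commit does, is to pass the batching parameters when registering the model through the TorchServe management API; the archive name and values below are assumptions:

```python
import requests

# Register the model with server-side batching parameters; the management
# API accepts batch_size and max_batch_delay as registration options.
resp = requests.post(
    "http://localhost:8081/models",
    params={
        "url": "ipex-llm.mar",   # illustrative archive name
        "initial_workers": 1,
        "batch_size": 4,         # max batch size
        "max_batch_delay": 100,  # milliseconds
    },
)
print(resp.status_code, resp.text)
```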

examples/large_models/ipex_llm_int8/llm_handler.py (resolved review comment)
try:
    import intel_extension_for_pytorch as ipex
    try:
        ipex._C.disable_jit_linear_repack()

Collaborator:

Is this LLM specific, or is this something we should also do in the IPEX integration in the base handler?

def initialize(self, ctx: Context):
    model_name = ctx.model_yaml_config["handler"]["model_name"]
    # path to quantized model, if we are quantizing on the fly, we'll use this path to save the model
    self.clear_cache_dir = ctx.model_yaml_config["handler"]["clear_cache_dir"]

Collaborator:

clear_cache_dir does not seem to be set in any of the example model config YAML files. Better to use .get("clear_cache_dir", DEFAULT_VALUE) here and replace DEFAULT_VALUE with whatever you think is appropriate.

Contributor Author (@bbhattar):

Added a default value for every parameter except the model name.
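
A minimal sketch of the suggested pattern; the default values here are illustrative, not the ones chosen in the PR:

```python
handler_config = ctx.model_yaml_config["handler"]

# model_name stays required; everything else falls back to a default
# when the key is absent from the model config YAML.
model_name = handler_config["model_name"]
self.clear_cache_dir = handler_config.get("clear_cache_dir", False)
self.quantized_model_path = handler_config.get(
    "quantized_model_path", "quantized_model.pt"  # illustrative default
)
```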

model_name = ctx.model_yaml_config["handler"]["model_name"]
# path to quantized model, if we are quantizing on the fly, we'll use this path to save the model
self.clear_cache_dir = ctx.model_yaml_config["handler"]["clear_cache_dir"]
self.quantized_model_path = ctx.model_yaml_config["handler"]["quantized_model_path"]

Collaborator:

Same here and below: it would be good to set a default value using .get() and remove the entry from the YAML file, to concentrate on the important settings there.

Contributor Author (@bbhattar):

Done!

if hasattr(self.user_model.config, n):
    return getattr(self.user_model.config, n)
logger.error(f"Not found target {names[0]}")
exit(0)

Collaborator:

Better to exit with 1 here, as this is an error condition.

Contributor Author (@bbhattar):

Thanks, changed!
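
For illustration, a self-contained sketch of the corrected error path (the helper name is hypothetical):

```python
import logging
import sys

logger = logging.getLogger(__name__)


def fail_missing_target(name: str) -> None:
    # Log the missing config attribute and exit non-zero so the failure
    # is reported as an error rather than a clean exit(0).
    logger.error(f"Not found target {name}")
    sys.exit(1)
```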

# need to recompute these
def _get_target_nums(names):
    for n in names:
        if hasattr(self.user_model.config, n):

Collaborator:

Is this code tested? self here seems to be Evaluator, which does not have user_model as an attribute.

Contributor Author (@bbhattar):

Added a test for the smooth-quant path and fixed the scope issue for user_model.

torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(),
torch.zeros([1, n_heads, 1, head_dim]).contiguous(),
torch.zeros([1, n_heads, 1, head_dim]).contiguous(),
self.beam_idx_tmp,

Collaborator:

Same here: self (Evaluator) will not have beam_idx_tmp, which is part of IpexLLMHandler.

Contributor Author (@bbhattar):

Recomputed beam_idx_tmp here.


example_inputs = self.get_example_inputs()

with torch.no_grad(), torch.cpu.amp.autocast(

Collaborator:

nit: The following lines (tracing and saving the model) are the same for all three conditions and could be replaced by a single appearance after the if-else.

Contributor Author (@bbhattar):

Replaced by a single trace_and_export function.
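
A sketch of what that shared helper might look like, based on the trace/freeze/save snippet quoted earlier; the signature is an assumption:

```python
import torch


def trace_and_export(self, converted_model, example_inputs):
    # Shared tail of all three optimization branches: trace, freeze, and
    # save the converted model to the configured quantized_model_path.
    self_jit = torch.jit.trace(
        converted_model.eval(), example_inputs, strict=False, check_trace=False
    )
    self_jit = torch.jit.freeze(self_jit.eval())
    self_jit.save(self.quantized_model_path)
    return self_jit
```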

example_inputs = None
input_ids = torch.ones(32).to(torch.long)
attention_mask = torch.ones(len(input_ids))
if self.example_inputs_mode == "MASK_POS_KV":

Collaborator:

nit: this code could be shared with collate_batch by moving it into a utility function.

Contributor Author (@bbhattar):

Removed this part. The example input is now generated with the dataloader built from collate_batch.

@bbhattar requested reviews from @mreso and @lxning on May 15, 2024 15:29

@mreso (Collaborator) left a comment:

LGTM now. Please address the linting issue before merging, and since our CI worker does not have access to the llama weights, please skip execution of the test for now. Thanks!

## Model Config
In addition to usual torchserve configurations, you need to enable ipex specific optimization arguments.

In order to enable IPEX, ```ipex_enable=true``` in the ```config.parameters``` file. If not enabled it will run with default PyTorch with ```auto_mixed_precision``` if enabled. In order to enable ```auto_mixed_precision```, you need to set ```auto_mixed_precision: true``` in model-config file.

Collaborator:

config.properties?
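
For illustration only, the two settings the README paragraph describes might look like this, assuming the file is indeed config.properties as the comment suggests, and assuming auto_mixed_precision sits under the handler section of the model config YAML:

```properties
# config.properties
ipex_enable=true
```

```yaml
# model-config-llama2-7b-bf16.yaml (excerpt; placement is an assumption)
handler:
  auto_mixed_precision: true
```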

@lxning enabled auto-merge on May 15, 2024 19:48
auto-merge was automatically disabled on May 15, 2024 20:12 (head branch was pushed to by a user without write access)
@mreso enabled auto-merge on May 16, 2024 04:09
@mreso added this pull request to the merge queue on May 16, 2024
Merged via the queue into pytorch:master with commit 34bc370 on May 16, 2024
12 checks passed