[Text Generation][V2] LinearRouter to accept SPLIT/JOIN #1434

Draft. Wants to merge 50 commits into base branch feature/damian/no_kv_cache.

Commits (changes shown from all 50 commits):
- 3e00175 Pipelines Refactor - Initial Impl (#1287) [bfineran, Oct 26, 2023]
- 224e116 [Pipeline Refactor] Additional functionality, engine operator, linear… [dsikka, Oct 31, 2023]
- 58b0758 [v2] EngineOperator updates to make continuous batching easier (#1371) [bfineran, Nov 1, 2023]
- e1ff108 [Pipeline Refactor] Update routes, text generation initial functional… [dsikka, Nov 3, 2023]
- 59457b7 [Pipeline Refactor] Additional Operators, Route update and completed … [dsikka, Nov 3, 2023]
- f18d5f3 add split/join functionality [dsikka, Nov 3, 2023]
- 2c4d231 update router to include split/join in parent class, refactor pipelin… [dsikka, Nov 7, 2023]
- 672ca20 process multiple generations [dsikka, Nov 7, 2023]
- 304eb35 initial commit [dbogunowicz, Nov 8, 2023]
- 71515ac fix error [dbogunowicz, Nov 8, 2023]
- 6f1b175 Merge remote-tracking branch 'origin/features/v2/run_multiple' into f… [dbogunowicz, Nov 9, 2023]
- 041174b [Pipeline Refactor] Split/Join Functionality for multiple prompts (#1… [dsikka, Nov 9, 2023]
- a508342 unit testing for text generation operators [dsikka, Nov 6, 2023]
- cbb0e86 additional changes [dsikka, Nov 7, 2023]
- 2541581 unit testing completion [dsikka, Nov 8, 2023]
- 8c8989d remove debug [dsikka, Nov 8, 2023]
- f8d75e3 fix [dsikka, Nov 8, 2023]
- fd1e466 add todo [dsikka, Nov 8, 2023]
- 64c0552 more clean-up [dsikka, Nov 8, 2023]
- 913665a fix test [dsikka, Nov 8, 2023]
- e15521f add docstrings/comments [dsikka, Nov 8, 2023]
- 379481e break out tests to individual unit test files; add conftest and make … [dsikka, Nov 9, 2023]
- a90a20a Merge remote-tracking branch 'origin/features/v2/unit_testing' into f… [dbogunowicz, Nov 10, 2023]
- 0a50d1d [Pipeline Refactor] Unit Testing for Text Generation Operators (#1392) [dsikka, Nov 10, 2023]
- c0c4240 Merge branch 'v2' into feature/damian/v2/factor_out_transformation_utils [dbogunowicz, Nov 10, 2023]
- 4f248dd Delete tests/deepsparse/v2/unit/text_generation/test_msic.py [dbogunowicz, Nov 13, 2023]
- 20980a7 [Continuous Batching] Queue Implementation to support batching groupi… [bfineran, Nov 13, 2023]
- d81012d [Continuous Batching] Executor thread for running continuous batching… [bfineran, Nov 13, 2023]
- 5c48505 [ContinuousBatching] ContinuousBatchingScheduler Implementation (#1375) [bfineran, Nov 13, 2023]
- e1b7f37 [continuous batching] singleton pattern for scheduler (#1391) [bfineran, Nov 13, 2023]
- 98f7a6d Merge branch 'v2' into feature/damian/v2/factor_out_transformation_utils [dbogunowicz, Nov 14, 2023]
- bbd534d [Pipeline Refactor][Text-Generation] Create a helper function for cre… [dbogunowicz, Nov 14, 2023]
- d1683b4 Merge branch 'v2' into feature/damian/v2/factor_out_transformation_utils [dbogunowicz, Nov 14, 2023]
- 51c4ee6 pipeline runs, but incorrectly [dbogunowicz, Nov 17, 2023]
- fa96efb it works for a single sequence [dbogunowicz, Nov 20, 2023]
- e41ddf8 cleanup. now lets figure out how to run multiple sequences [dbogunowicz, Nov 20, 2023]
- b80a417 [Pipeline Refactor][Text-Generation] Refactor `transformers` helpers … [dbogunowicz, Nov 20, 2023]
- 1b9238a [Text Generation][V2] End-to-end tests (#1402) [dbogunowicz, Nov 20, 2023]
- 89f11e5 Merge remote-tracking branch 'origin/v2' into feature/damian/no_kv_cache [dbogunowicz, Nov 21, 2023]
- 9b441f5 integration tests pass [dbogunowicz, Nov 21, 2023]
- c858b1f [Pipeline Refactor][Text Generation][Continuous Batching] Integration… [dsikka, Nov 21, 2023]
- bb3ff41 [Pipeline Refactor] Operator Registry (#1420) [dsikka, Nov 21, 2023]
- 19434e7 Merge remote-tracking branch 'origin/v2' into feature/damian/no_kv_cache [dbogunowicz, Nov 22, 2023]
- 90de2b3 fix tricky rebase [dbogunowicz, Nov 22, 2023]
- 66ca295 one more cleanup [dbogunowicz, Nov 22, 2023]
- dcded1d got tests to work after rebase. implementing SPLIT and JOIN in linear… [dbogunowicz, Nov 22, 2023]
- 127aa00 pipeline working, with GraphRouter. Needs some more testing [dbogunowicz, Nov 22, 2023]
- af57698 ready for review [dbogunowicz, Nov 27, 2023]
- 4397c80 cleanup [dbogunowicz, Nov 28, 2023]
- 6f1214e initial commit [dbogunowicz, Nov 28, 2023]
114 changes: 103 additions & 11 deletions src/deepsparse/transformers/helpers.py
@@ -17,24 +17,26 @@
"""


import logging
import os
import re
from pathlib import Path
from tempfile import NamedTemporaryFile
- from typing import List, Optional, Tuple, Union
+ from typing import Any, Dict, List, Optional, Tuple, Union

import numpy
import onnx
import transformers
from onnx import ModelProto

from deepsparse.log import get_main_logger
- from deepsparse.utils.onnx import _MODEL_DIR_ONNX_NAME, truncate_onnx_model
+ from deepsparse.utils.onnx import MODEL_ONNX_NAME, truncate_onnx_model
from sparsezoo import Model
from sparsezoo.utils import save_onnx


__all__ = [
"get_deployment_path",
"setup_transformers_pipeline",
"overwrite_transformer_onnx_model_inputs",
"fix_numpy_types",
"get_transformer_layer_init_names",
@@ -44,7 +46,94 @@
_LOGGER = get_main_logger()


- def get_deployment_path(model_path: str) -> Tuple[str, str]:
+ def setup_transformers_pipeline(
model_path: str,
sequence_length: int,
tokenizer_padding_side: str = "left",
engine_kwargs: Optional[Dict] = None,
onnx_model_name: Optional[str] = None,
) -> Tuple[
str, transformers.PretrainedConfig, transformers.PreTrainedTokenizer, Dict[str, Any]
]:
"""
A helper function that sets up the model path, config, tokenizer,
and engine kwargs for a transformers model.
:param model_path: The path to the model to load
:param sequence_length: The sequence length to use for the model
:param tokenizer_padding_side: The side to pad on for the tokenizer,
either "left" or "right"
:param engine_kwargs: The kwargs to pass to the engine
:param onnx_model_name: The name of the onnx model to be loaded.
If not specified, defaults are used (see setup_onnx_file_path)
:return: The model path, config, tokenizer, and engine kwargs
"""
model_path, config, tokenizer = setup_onnx_file_path(
model_path, sequence_length, onnx_model_name
)

tokenizer.padding_side = tokenizer_padding_side
if not tokenizer.pad_token:
tokenizer.pad_token = tokenizer.eos_token

engine_kwargs = engine_kwargs or {}
if engine_kwargs.get("model_path"):
raise ValueError(
"The engine kwargs already specify "
f"a model path: {engine_kwargs['model_path']}, "
f"but a model path was also provided: {model_path}. "
"Please only provide one."
)
engine_kwargs["model_path"] = model_path
return model_path, config, tokenizer, engine_kwargs


def setup_onnx_file_path(
model_path: str,
sequence_length: int,
onnx_model_name: Optional[str] = None,
task: Optional[str] = None,
) -> Tuple[str, transformers.PretrainedConfig, transformers.PreTrainedTokenizer]:
"""
Parses the ONNX model from the `model_path` provided. Additionally
creates config and tokenizer objects from the deployment path
derived from the `model_path`.
:param model_path: path to the model to be parsed
:param sequence_length: maximum sequence length of the model
:param onnx_model_name: optionally, the precise name of the ONNX model
of interest may be specified. If not specified, the default ONNX model
name will be used (refer to `get_deployment_path` for details)
:param task: optionally, the transformers task to set when loading the config
:return: file path to the processed ONNX file for the engine to compile,
together with the loaded config and tokenizer
"""
deployment_path, onnx_path = get_deployment_path(model_path, onnx_model_name)

hf_logger = logging.getLogger("transformers")
hf_logger_level = hf_logger.level
hf_logger.setLevel(logging.ERROR)

config = transformers.PretrainedConfig.from_pretrained(
deployment_path, finetuning_task=task
)
hf_logger.setLevel(hf_logger_level)

trust_remote_code = False
tokenizer = transformers.AutoTokenizer.from_pretrained(
deployment_path,
trust_remote_code=trust_remote_code,
model_max_length=sequence_length,
)

if not config or not tokenizer:
raise RuntimeError(
"Invalid config or tokenizer provided. Please provide "
"paths to the files or ensure they exist in the `model_path` provided. "
"See `tokenizer` and `config` arguments for details."
)
return onnx_path, config, tokenizer


def get_deployment_path(
model_path: str, onnx_model_name: Optional[str] = None
) -> Tuple[str, str]:
"""
Returns the path to the deployment directory
for the given model path and the path to the mandatory
@@ -53,36 +142,39 @@ def get_deployment_path(model_path: str) -> Tuple[str, str]:
for running the transformers model in the deepsparse pipeline

:param model_path: path to model directory, sparsezoo stub, or ONNX file
:param onnx_model_name: name of the ONNX file to look for in the deployment
directory. Defaults to MODEL_ONNX_NAME
:return: path to the deployment directory and path to the ONNX file inside
the deployment directory
"""
onnx_model_name = onnx_model_name or MODEL_ONNX_NAME
if os.path.isfile(model_path):
# return the parent directory of the ONNX file
return os.path.dirname(model_path), model_path

if os.path.isdir(model_path):
model_files = os.listdir(model_path)

- if _MODEL_DIR_ONNX_NAME not in model_files:
+ if onnx_model_name not in model_files:
raise ValueError(
- f"{_MODEL_DIR_ONNX_NAME} not found in transformers model directory "
+ f"{onnx_model_name} not found in transformers model directory "
f"{model_path}. Be sure that an export of the model is written to "
- f"{os.path.join(model_path, _MODEL_DIR_ONNX_NAME)}"
+ f"{os.path.join(model_path, onnx_model_name)}"
)
- return model_path, os.path.join(model_path, _MODEL_DIR_ONNX_NAME)
+ return model_path, os.path.join(model_path, onnx_model_name)

elif model_path.startswith("zoo:"):
zoo_model = Model(model_path)
deployment_path = zoo_model.deployment_directory_path
- return deployment_path, os.path.join(deployment_path, _MODEL_DIR_ONNX_NAME)
+ return deployment_path, os.path.join(deployment_path, onnx_model_name)
elif model_path.startswith("hf:"):
from huggingface_hub import snapshot_download

deployment_path = snapshot_download(repo_id=model_path.replace("hf:", "", 1))
- onnx_path = os.path.join(deployment_path, _MODEL_DIR_ONNX_NAME)
+ onnx_path = os.path.join(deployment_path, onnx_model_name)
if not os.path.isfile(onnx_path):
raise ValueError(
f"{_MODEL_DIR_ONNX_NAME} not found in transformers model directory "
f"{onnx_model_name} not found in transformers model directory "
f"{deployment_path}. Be sure that an export of the model is written to "
f"{onnx_path}"
)
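For reference, a minimal usage sketch of the new helper. The local
./deployment directory and the num_cores engine kwarg below are
hypothetical, not taken from this PR:

from deepsparse.transformers.helpers import setup_transformers_pipeline

# hypothetical deployment directory containing model.onnx plus the
# exported config and tokenizer files
model_path, config, tokenizer, engine_kwargs = setup_transformers_pipeline(
    model_path="./deployment",
    sequence_length=128,
    tokenizer_padding_side="left",
    engine_kwargs={"num_cores": 4},  # must not already contain "model_path"
)
# the helper injects the resolved ONNX path into the engine kwargs
assert engine_kwargs["model_path"] == model_path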
38 changes: 11 additions & 27 deletions src/deepsparse/transformers/pipelines/pipeline.py
@@ -16,19 +16,18 @@
Base Pipeline class for transformers inference pipeline
"""

import logging

import warnings
from pathlib import Path
from typing import Any, Dict, List, Mapping, Optional, Union

import numpy
import transformers
from transformers.models.auto import AutoTokenizer

from deepsparse import Bucketable, Pipeline
- from deepsparse.transformers.helpers import overwrite_transformer_onnx_model_inputs
+ from deepsparse.transformers.helpers import (
+ get_deployment_path,
+ overwrite_transformer_onnx_model_inputs,
+ setup_onnx_file_path as setup_onnx_file_path_v2,
+ )


@@ -124,24 +123,15 @@ def setup_onnx_file_path(self) -> str:

:return: file path to the processed ONNX file for the engine to compile
"""
- deployment_path, onnx_path = get_deployment_path(self.model_path)
-
- # temporarily set transformers logger to ERROR to avoid
- # printing misleading warnings
- hf_logger = logging.getLogger("transformers")
- hf_logger_level = hf_logger.level
- hf_logger.setLevel(logging.ERROR)
- self.config = transformers.PretrainedConfig.from_pretrained(
- deployment_path,
- finetuning_task=self.task if hasattr(self, "task") else None,
- )
- hf_logger.setLevel(hf_logger_level)
-
- self.tokenizer = AutoTokenizer.from_pretrained(
- deployment_path,
- trust_remote_code=self._trust_remote_code,
- model_max_length=self.sequence_length,
+ # V1 pipelines will be retired soon, so we reuse the functions from
+ # the V2 pipelines in these soon-to-be-legacy pipelines
+ onnx_path, config, tokenizer = setup_onnx_file_path_v2(
+ model_path=self.model_path,
+ sequence_length=self.sequence_length,
+ task=self.task if hasattr(self, "task") else None,
)
+ self.config = config
+ self.tokenizer = tokenizer

if not self._delay_overwriting_inputs:
# overwrite onnx graph to given required input shape
@@ -153,12 +143,6 @@ def setup_onnx_file_path(self) -> str:
onnx_path, max_length=self.sequence_length
)

- if not self.config or not self.tokenizer:
- raise RuntimeError(
- "Invalid config or tokenizer provided. Please provide "
- "paths to the files or ensure they exist in the `model_path` provided. "
- "See `tokenizer` and `config` arguments for details."
- )
return onnx_path

def tokens_to_engine_input(
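Since the V1 method is now a thin wrapper, both pipeline generations
resolve their files through the same code path. A minimal sketch of
calling the shared helper directly, with names as in the diff (the
local path is hypothetical):

from deepsparse.transformers.helpers import setup_onnx_file_path

onnx_path, config, tokenizer = setup_onnx_file_path(
    model_path="./deployment",  # hypothetical local deployment directory
    sequence_length=128,
)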
92 changes: 91 additions & 1 deletion src/deepsparse/transformers/utils/helpers.py
@@ -14,7 +14,7 @@
import logging
import pathlib
import uuid
- from typing import Any, Dict, List, Optional, Tuple, Union
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy
from transformers import AutoTokenizer, GenerationConfig
@@ -33,6 +33,7 @@
"override_config",
"process_generation_config",
"validate_session_ids",
"compute_engine_inputs",
"set_generated_length",
]

@@ -82,6 +83,95 @@ def set_generated_length(
)


def compute_engine_inputs(onnx_input_names: List[str], **kwargs) -> List[numpy.ndarray]:
"""
Given the names of the onnx inputs, compute the inputs
to the engine. The inputs are computed from the
passed kwargs. The information about the required kwargs
can be found in the docstrings of the individual compute
functions.

:param onnx_input_names: The names of the onnx inputs
:param kwargs: The kwargs to compute the inputs from
:return: The computed inputs to the engine
"""
engine_inputs = []
for input_name in onnx_input_names:
if input_name == "causal_mask":
# delay the computation of the causal mask
continue
# fetch the compute function for the
# given input_name
compute_func = _get_compute_func(input_name)
# compute the engine input from the kwargs
# and append it to the engine_inputs
engine_inputs.append(compute_func(**kwargs))

if "causal_mask" in onnx_input_names:
# compute the causal mask and append it to the engine_inputs
input_ids, attention_mask, *_ = engine_inputs
engine_inputs.append(create_causal_mask(input_ids, attention_mask))

return engine_inputs


def _get_compute_func(input_name: str) -> Callable[..., numpy.ndarray]:
# given the input_name, return the appropriate compute function
compute_func = {
"input_ids": _compute_input_ids,
"attention_mask": _compute_attention_mask,
"positions": _compute_positions,
}.get(input_name)
if compute_func is None:
raise ValueError(
"Could not find compute function " f"for the input_name: {input_name}"
)
return compute_func


def _compute_input_ids(token_batch: List[int], **kwargs) -> numpy.ndarray:
# convert the token_batch to a numpy array
return numpy.array([token_batch])


def _compute_attention_mask(
sequence_length: int,
prompt_sequence_length: int,
num_total_processed_tokens: int,
**kwargs,
) -> numpy.ndarray:
# create a fully masked attention mask with the appropriate
# shape (equal to the sequence_length)
attention_mask = numpy.zeros((1, sequence_length), dtype=numpy.int64)
# unmask the appropriate number of tokens, the sum of
# - the number of tokens already processed and cached (num_total_processed_tokens)
# - the number of tokens currently processed (prompt_sequence_length)
# the sum cannot exceed the maximum length of the attention_mask
num_attention_entries_to_unmask = min(
num_total_processed_tokens + prompt_sequence_length, sequence_length
)
# unmask the bits from the right-hand side
attention_mask[:, -num_attention_entries_to_unmask:] = 1
return attention_mask


def _compute_positions(
num_total_processed_tokens: int, prompt_sequence_length: int, **kwargs
) -> numpy.ndarray:
# create the positions array with the appropriate shape
# positions count starts from the number of tokens already processed
# and ends at the number of tokens already processed + the number of tokens
# currently processed
return (
numpy.arange(
num_total_processed_tokens,
num_total_processed_tokens + prompt_sequence_length,
)
.reshape(1, -1)
.astype(numpy.int64)
)


def validate_session_ids(
session_ids: Optional[str], other_attributes: Dict[str, Any]
) -> Optional[List[str]]:
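To illustrate the new helper, a small worked example (a sketch with
arbitrary numbers, assuming a model whose ONNX graph takes exactly
these four inputs):

from deepsparse.transformers.utils.helpers import compute_engine_inputs

engine_inputs = compute_engine_inputs(
    onnx_input_names=["input_ids", "attention_mask", "positions", "causal_mask"],
    token_batch=[42],              # token(s) processed in this step
    sequence_length=16,            # maximum model sequence length
    prompt_sequence_length=1,      # tokens processed in this call
    num_total_processed_tokens=5,  # tokens already processed and cached
)
input_ids, attention_mask, positions, causal_mask = engine_inputs
# input_ids == [[42]]; positions == [[5]]; attention_mask unmasks the
# 6 rightmost entries (5 cached tokens + 1 token in the current pass)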
11 changes: 6 additions & 5 deletions src/deepsparse/transformers/utils/token_generator.py
@@ -77,16 +77,17 @@ def generate(self, logits: numpy.ndarray) -> numpy.ndarray:
:param logits: the logits from the model with shape (vocab_size,)
:return: the sampled token
"""
+ if self.top_k:
+ logits = self.apply_top_k(logits)
+ if self.top_p:
+ logits = self.apply_top_p(logits)

if self.deterministic:
token = numpy.argmax(logits)
self.tokens.append(token)
return token

- if self.top_k:
- logits = self.apply_top_k(logits)
-
- if self.top_p:
- logits = self.apply_top_p(logits)

if self.sampling_temperature != 1.0:
logits /= self.sampling_temperature

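The reorder above means top-k/top-p filtering now runs before the
deterministic branch, so a greedy generator takes its argmax over the
filtered logits. Assuming the filter masks dropped entries to -inf
(an assumption about apply_top_k, not shown in this diff), the greedy
choice is unchanged, since top-k always keeps the largest logit:

import numpy

logits = numpy.array([0.1, 2.0, 0.5, 1.5])

def apply_top_k(logits: numpy.ndarray, k: int = 2) -> numpy.ndarray:
    # assumed behavior: keep the k largest logits, mask the rest to -inf
    filtered = numpy.full_like(logits, -numpy.inf)
    keep = numpy.argpartition(logits, -k)[-k:]
    filtered[keep] = logits[keep]
    return filtered

# both pick index 1 (the logit 2.0)
assert numpy.argmax(apply_top_k(logits)) == numpy.argmax(logits)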