feat: generation peft (#822)
Jonathan Gomes Selman committed Dec 26, 2023
1 parent 908af1d commit a30ca91
Showing 5 changed files with 22 additions and 4 deletions.
2 changes: 1 addition & 1 deletion dataquality/__init__.py
@@ -31,7 +31,7 @@
"""


__version__ = "1.4.1"
__version__ = "1.4.2"

import sys
from typing import Any, List, Optional
4 changes: 3 additions & 1 deletion dataquality/integrations/seq2seq/core.py
@@ -1,6 +1,7 @@
from typing import List, Optional, Union
from warnings import warn

+from peft import PeftModel
from tokenizers import Tokenizer
from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizerFast

@@ -172,8 +173,9 @@ def watch(
    # A model of the correct type is required if we need to generate
    if generation_splits:
        assert isinstance(
-            model, PreTrainedModel
+            model, (PreTrainedModel, PeftModel)
        ), "model must be an instance of transformers PreTrainedModel"
+
        assert (
            model.can_generate()
        ), "model must contain a `generate` method for seq2seq"
5 changes: 3 additions & 2 deletions dataquality/loggers/logger_config/seq2seq/seq2seq_base.py
@@ -1,6 +1,7 @@
from collections import defaultdict
-from typing import Dict, List, Optional, Set
+from typing import Dict, List, Optional, Set, Union

+from peft import PeftModel
from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizerFast

from dataquality.loggers.logger_config.base_logger_config import BaseLoggerConfig
@@ -15,7 +16,7 @@ class Seq2SeqLoggerConfig(BaseLoggerConfig):
    max_target_tokens: Optional[int] = None
    # For each split/inference-name, store sample id -> List[token_id] for the label
    id_to_tokens: Dict[str, Dict[int, List[int]]] = defaultdict(dict)
-    model: Optional[PreTrainedModel] = None
+    model: Optional[Union[PreTrainedModel, PeftModel]] = None
    generation_config: Optional[GenerationConfig] = None
    generation_splits: Set[Split] = set()
    model_type: Optional[Seq2SeqModelType] = None
14 changes: 14 additions & 0 deletions dataquality/utils/seq2seq/generation.py
@@ -142,6 +142,17 @@ def add_generated_output_to_df(
        Updated Dataframe with the generated columns added (see above)
    """
    model.eval()
+    # When generating it is important to set `use_cache = True`.
+    # - WHAT? Caching stores intermediate token activations / representations.
+    # During autoregressive generation, the cache is updated each time a token
+    # is generated.
+    # - WHY? Caching prevents re-computing token information during auto-regressive
+    # generation, DRAMATICALLY speeding up performance. Every time a new token is
+    # generated, we only need to do the forward pass for a single new token, as we
+    # leverage the cached information to compute transformer based attention.
+    model_cache_flag = model.config.use_cache
+    model.config.use_cache = True
+
    generated_data = BatchGenerationData()

    num_batches = math.ceil(len(df) / GENERATION_BATCH_SIZE)
@@ -183,4 +194,7 @@ def add_generated_output_to_df(
        generated_data.generated_top_logprobs, type=TOP_LOGPROBS_SCHEMA
    )

+    # Reset the cache flag for the model
+    model.config.use_cache = model_cache_flag
+
    return df
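The `use_cache` handling above follows a save/restore pattern around generation. A hedged sketch of the same idea as a context manager; `force_use_cache` is a hypothetical helper, not part of dataquality, and it assumes the transformers convention of a `config.use_cache` flag:

# Hypothetical helper (not in this commit): restore the original flag
# even if generation raises partway through.
from contextlib import contextmanager

@contextmanager
def force_use_cache(model):
    original_flag = model.config.use_cache
    model.config.use_cache = True  # enable the KV cache for generation
    try:
        yield model
    finally:
        model.config.use_cache = original_flag  # always restore

Wrapping the batched generation loop in `with force_use_cache(model):` would make the explicit reset at the end of `add_generated_output_to_df` unnecessary.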
1 change: 1 addition & 0 deletions pyproject.toml
@@ -41,6 +41,7 @@ dependencies = [
"ipywidgets>=8.1.0",
"imagededup>=0.3.1",
"pyjwt>=2.8.0",
"peft"
]
[[project.authors]]
name = "Galileo Technologies, Inc."