
add return_token_timestamps to WhisperProcessor #30812

Conversation

Contributor

@kamilakesbi commented May 14, 2024

What does this PR do?

This PR fixes #30433 by making sure timestamps can be computed with both WhisperForConditionalGeneration and AutomaticSpeechRecognitionPipeline.

We add a return_timestamps argument to WhisperProcessor.feature_extractor, to be used when we want to compute timestamps. When True, the processor returns a num_frames entry containing the number of frames of each input audio. num_frames is then passed to generate and used to compute the timestamps.

Prior to that, timestamps were broken for whisper-large-v3 when used with WhisperForConditionalGeneration.
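For context, a minimal sketch of the end-to-end usage this PR enables (parameter names follow the final PR title, return_token_timestamps; treat the exact signatures and output structure as assumptions):

from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset, Audio

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")

dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
dataset = dataset.cast_column("audio", Audio(16_000))
sample = next(iter(dataset))

# Asking the processor for token timestamps makes it also return `num_frames`,
# the un-padded length of each input measured in mel frames.
inputs = processor(
    sample["audio"]["array"],
    sampling_rate=16_000,
    return_tensors="pt",
    return_token_timestamps=True,
)

# `num_frames` travels with the unpacked inputs into generate, which uses it to
# restrict the cross-attention weights when computing word-level timestamps.
outputs = model.generate(**inputs, return_token_timestamps=True)
print(outputs["token_timestamps"])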

Who can review ?

cc @sanchit-gandhi

@kamilakesbi changed the title from "[WIP] - Add return_num_frames in WhisperProcessor" to "[WIP] - add num_frames parameter in WhisperProcessor to compute word level timestamps in WhisperForConditionalGeneration" on May 14, 2024
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@kamilakesbi changed the title from "[WIP] - add num_frames parameter in WhisperProcessor to compute word level timestamps in WhisperForConditionalGeneration" to "[WIP] - add return_timestamps in WhisperProcessor" on May 15, 2024
@kamilakesbi changed the title from "[WIP] - add return_timestamps in WhisperProcessor" to "[WIP] - add return_timestamps to WhisperProcessor" on May 15, 2024
Contributor

@sanchit-gandhi left a comment


Think this is indeed the cleanest and most reliable approach for computing num_frames. The alternative method we discussed offline is detailed below. Leaving it here for the next reviewer to consider, in case they believe it's a superior strategy.

Anything beyond len(input_speech) is padded with zeros to 30 seconds in the feature extractor. If we know what zeros correspond to in log-mel space, then we know how many padded zeros we have in our spectrogram, and thus what the original input length was.

Note that this won’t be perfect: the last frame where the audio stops is going to be affected by the end of the audio, so we’ll be looking for the first frame where there is entirely padding (rather than finding the frame in which the audio stops).

However, the original method by OpenAI (and the one implemented in this PR) is also imperfect: if a user took a 10-second audio and padded it by hand to 15 seconds with zeros, then num_frames would be computed on the length of the padded input, not the original one.
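A rough sketch of that alternative (not what this PR implements), assuming Whisper's log-mel extraction, which clamps values to the spectrogram maximum minus 8 so that padded zeros all land on a constant floor:

import numpy as np

def estimate_num_frames(input_features: np.ndarray) -> int:
    # `input_features` is one (num_mels, num_frames) log-mel spectrogram.
    # Zero-padded audio sits at the clamp floor, so a fully padded frame is
    # one where every mel bin equals the spectrogram's minimum value.
    floor = input_features.min()
    fully_padded = np.all(np.isclose(input_features, floor), axis=0)
    padded_frames = np.flatnonzero(fully_padded)
    # No fully padded frame means the audio filled the whole 30-second window.
    return int(padded_frames[0]) if len(padded_frames) else input_features.shape[-1]

As noted above, this finds the first entirely padded frame, so the estimate can be off by a frame around the point where the audio stops, and a genuinely silent stretch sitting at the floor would also be mistaken for padding.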

@@ -474,6 +475,13 @@ def generate(
"The input name `inputs` is deprecated. Please make sure to use `input_features` instead.",
FutureWarning,
)

if input_features is not None and isinstance(input_features, BatchFeature):
Contributor


Not sure why this has crept in? input_features should be a tensor of shape (bsz, num_mels, num_frames), not a BatchFeature encoding. Thus, this new logic isn't required.

The correct way of using the feature extractor should be:

from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset, Audio

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")

dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
dataset = dataset.cast_column("audio", Audio(16_000))

sample = next(iter(dataset))
inputs = processor(sample["audio"]["array"], return_tensors="pt")

# note here how we un-pack the batch feature encoding
pred_ids = model.generate(**inputs, language="english")

Contributor Author


The output of the processor would be a BatchFeature, as indicated here, no?

Contributor


Yes, but then we un-pack the BatchFeature when we pass it to the model, i.e. we do:

pred_ids = model.generate(**inputs)

Not:

pred_ids = model.generate(inputs)

Contributor Author


In this case it will work with both packed and unpacked inputs. Isn't that better?

Collaborator


I'm aligned with @sanchit-gandhi here - handling packed and unpacked inputs isn't something any of our other processing classes handle, so it's not something we need to introduce here.

kamilakesbi and others added 4 commits May 15, 2024 18:58
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
@kamilakesbi changed the title from "[WIP] - add return_timestamps to WhisperProcessor" to "[WIP] - add return_token_timestamps to WhisperProcessor" on May 16, 2024
@kamilakesbi changed the title from "[WIP] - add return_token_timestamps to WhisperProcessor" to "add return_token_timestamps to WhisperProcessor" on May 16, 2024
Contributor

@sanchit-gandhi left a comment


Looks good! Mostly formatting now, then we can get a final review.

src/transformers/models/whisper/generation_whisper.py (outdated, resolved)

tests/models/whisper/test_modeling_whisper.py (outdated, resolved)
tests/models/whisper/test_modeling_whisper.py (resolved)
kamilakesbi and others added 5 commits May 16, 2024 17:29
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Collaborator

@amyeroberts left a comment


Thanks for adding this feature and tests!

All looks good to me - just the handling of unpacked features to remove.


kamilakesbi and others added 2 commits May 16, 2024 23:57
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
@kamilakesbi added the Audio and Good Second Issue labels on May 17, 2024
@kamilakesbi self-assigned this on May 17, 2024
Collaborator

@amyeroberts left a comment


Thanks for adding!

…b.com:kamilakesbi/transformers into timestamps_whisper_for_conditional_generation
@@ -1927,7 +1927,117 @@ def test_large_timestamp_generation(self):

generated_ids = model.generate(input_features, max_length=448, return_timestamps=True).to("cpu")

EXPECTED_OUTPUT = torch.tensor([50258, 50259, 50360, 50365, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 11, 293, 321, 366, 5404, 281, 2928, 702, 14943, 13, 50629, 50682, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50870, 50911, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256, 450, 10539, 949, 505, 11, 51245, 51287, 1034, 4680, 10117, 490, 3936, 293, 1080, 3542, 5160, 881, 26336, 281, 264, 1575, 13, 51494, 51523, 634, 575, 12525, 22618, 1968, 6144, 35617, 1456, 397, 266, 311, 589, 307, 534, 10281, 934, 439, 11, 51799, 51815, 50257])
EXPECTED_OUTPUT = torch.tensor(
Collaborator


I don't think we want to split across the lines like this.

You can wrap EXPECTED_OUTPUT in # fmt: off and # fmt: on comments to avoid this.
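For illustration, with a shortened stand-in for the full expected tensor:

import torch

# fmt: off
EXPECTED_OUTPUT = torch.tensor([50258, 50259, 50360, 50365, 2221, 13, 2326, 388, 391, 307, 264, 50244])
# fmt: on

black (and ruff format) leave everything between the two markers untouched, so the tensor stays on a single line.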

Contributor Author


Ok, thanks for the tips! Will be useful ;)

Contributor


You can also use # fmt: skip for single lines, cf. the previous comment #30812 (comment)
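Again with a shortened stand-in tensor:

import torch

# `# fmt: skip` disables formatting for this one statement only.
EXPECTED_OUTPUT = torch.tensor([50258, 50259, 50360, 50365, 2221, 13, 2326, 388, 391, 307, 264, 50244])  # fmt: skip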

@kamilakesbi
Contributor Author

cc @amyeroberts @sanchit-gandhi Could you please merge this PR as I don't have the rights to do so?
Thanks!

@amyeroberts merged commit 1c2bb3a into huggingface:main May 20, 2024
21 checks passed
itazap pushed a commit that referenced this pull request May 24, 2024
* compute num_frames in WhisperFeatureExtractor

* add return_num_frames in WhisperFeatureProcessor + adapt pipeline

* return_timestamps renaming + pipeline fix

* fix

* fix

* fix

* add tests

* Update src/transformers/models/whisper/feature_extraction_whisper.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* apply review changes

* fix

* Update src/transformers/models/whisper/feature_extraction_whisper.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update tests/models/whisper/test_modeling_whisper.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* apply review

* fix

* review changes

* Update src/transformers/models/whisper/feature_extraction_whisper.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* make style quality

* EXPECTED_OUTPUT in single line

* small numpy->torch fix

* fix

---------

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Successfully merging this pull request may close these issues.

Timestamps are broken for whisper large with WhisperForConditionalGeneration