Enabling word-level timestamps for all W2L Decoders #5403

Open · wants to merge 5 commits into base: main · showing changes from all commits
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
@@ -5,7 +5,7 @@ default_language_version:
 
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.1.0
+    rev: v4.5.0
     hooks:
       - id: trailing-whitespace
       - id: check-ast
@@ -17,13 +17,13 @@ repos:
       - id: end-of-file-fixer
 
   - repo: https://github.com/ambv/black
-    rev: 22.3.0
+    rev: 23.12.0
     hooks:
       - id: black
         language_version: python3.8
 
-  - repo: https://gitlab.com/pycqa/flake8
-    rev: 3.9.2
+  - repo: https://github.com/pycqa/flake8
+    rev: 6.1.0
     hooks:
       - id: flake8
         args: [
@@ -32,7 +32,7 @@
         ]
 
   - repo: https://github.com/pycqa/isort
-    rev: 5.10.1
+    rev: 5.13.2
     hooks:
       - id: isort
         exclude: README.md
162 changes: 119 additions & 43 deletions examples/speech_recognition/w2l_decoder.py
@@ -12,32 +12,33 @@
 import gc
 import itertools as it
 import os.path as osp
-from typing import List
 import warnings
 from collections import deque, namedtuple
+from typing import List
 
 import numpy as np
 import torch
+from omegaconf import open_dict
+
 from examples.speech_recognition.data.replabels import unpack_replabels
 from fairseq import tasks
-from fairseq.utils import apply_to_sample
-from omegaconf import open_dict
+from fairseq.data.data_utils import post_process
 from fairseq.dataclass.utils import convert_namespace_to_omegaconf
-
+from fairseq.utils import apply_to_sample
 
 try:
-    from flashlight.lib.text.dictionary import create_word_dict, load_words
     from flashlight.lib.sequence.criterion import CpuViterbiPath, get_data_ptr_as_bytes
     from flashlight.lib.text.decoder import (
-        LM,
         CriterionType,
-        LexiconDecoderOptions,
         KenLM,
+        LM,
+        LexiconDecoder,
+        LexiconDecoderOptions,
         LMState,
         SmearingMode,
         Trie,
-        LexiconDecoder,
     )
+    from flashlight.lib.text.dictionary import create_word_dict, load_words
 except:
     warnings.warn(
         "flashlight python bindings are required to use this functionality. Please install from https://github.com/facebookresearch/flashlight/tree/master/bindings/python"
@@ -51,6 +52,9 @@ def __init__(self, args, tgt_dict):
         self.tgt_dict = tgt_dict
         self.vocab_size = len(tgt_dict)
         self.nbest = args.nbest
+        # Symbols (usually characters) understood by the ASR model; these are
+        # what the emission matrix scores at every frame.
+        self.symbols = tgt_dict.symbols
 
         # criterion-specific init
         self.criterion_type = CriterionType.CTC
@@ -82,7 +86,7 @@ def get_emissions(self, models, encoder_input):
         model = models[0]
         encoder_out = model(**encoder_input)
         if hasattr(model, "get_logits"):
-            emissions = model.get_logits(encoder_out) # no need to normalize emissions
+            emissions = model.get_logits(encoder_out)  # no need to normalize emissions
         else:
             emissions = model.get_normalized_probs(encoder_out, log_probs=True)
         return emissions.transpose(0, 1).float().cpu().contiguous()
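
A shape sketch (toy sizes, not from the PR) of why the transpose is needed: the encoder output here is assumed time-major, while the flashlight decoders consume batch-major emissions.

import torch

# Assume the encoder emits T x B x N (frames, batch, symbols); the decoders
# below iterate per utterance, so they want B x T x N.
t_major = torch.randn(50, 2, 32)  # T=50, B=2, N=32
b_major = t_major.transpose(0, 1).float().cpu().contiguous()
print(b_major.shape)  # torch.Size([2, 50, 32])
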
@@ -93,6 +97,47 @@ def get_tokens(self, idxs):
         idxs = filter(lambda x: x != self.blank, idxs)
         return torch.LongTensor(list(idxs))
 
+    def get_timesteps(self, token_idxs: List[int]) -> List[int]:
+        """Returns frame numbers corresponding to every non-blank token.
+
+        Parameters
+        ----------
+        token_idxs : List[int]
+            IDs of decoded tokens (including blanks), one token per frame of
+            the emission matrix.
+
+        Returns
+        -------
+        List[int]
+            Frame numbers corresponding to every non-blank token.
+        """
+        timesteps = []
+        for i, token_idx in enumerate(token_idxs):
+            if token_idx == self.blank:
+                continue
+            # A token starts a new emission if it opens the sequence or
+            # differs from the token in the previous frame.
+            if i == 0 or token_idx != token_idxs[i - 1]:
+                timesteps.append(i)
+
+        return timesteps
+
+    def get_symbols(self, token_idxs: List[int]) -> List[str]:
+        """Returns characters corresponding to every non-blank token.
+
+        Parameters
+        ----------
+        token_idxs : List[int]
+            IDs of non-blank tokens.
+
+        Returns
+        -------
+        List[str]
+            Character corresponding to every non-blank token.
+        """
+        chars = []
+        for token_idx in token_idxs:
+            chars.append(self.symbols[token_idx])
+
+        return chars
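
As a sanity check on the two helpers above, here is a minimal sketch (not part of the diff; the blank id and symbol table are invented for illustration):

blank = 0
symbols = ["<blank>", "c", "a", "t"]  # hypothetical symbol table
path = [0, 1, 1, 0, 2, 2, 2, 3, 0]   # one token id per emission frame

# get_timesteps keeps the frame where each collapsed, non-blank token starts...
timesteps = [
    i for i, t in enumerate(path) if t != blank and (i == 0 or t != path[i - 1])
]
# ...and get_symbols maps the surviving ids to characters.
chars = [symbols[path[i]] for i in timesteps]

print(timesteps)  # [1, 4, 7]
print(chars)      # ['c', 'a', 't']
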


 class W2lViterbiDecoder(W2lDecoder):
     def __init__(self, args, tgt_dict):
@@ -116,10 +161,30 @@ def decode(self, emissions):
             get_data_ptr_as_bytes(viterbi_path),
             get_data_ptr_as_bytes(workspace),
         )
-        return [
-            [{"tokens": self.get_tokens(viterbi_path[b].tolist()), "score": 0}]
-            for b in range(B)
-        ]
+
+        for b in range(B):
+            tokens = self.get_tokens(viterbi_path[b].tolist()).tolist()
+            hypos.append(
+                [
+                    {
+                        # Non-blank token ids.
+                        "tokens": tokens,
+                        # Characters (symbols) corresponding to the non-blank token ids.
+                        "symbols": self.get_symbols(tokens),
+                        "score": 0,
+                        # Frame numbers of the non-blank tokens.
+                        "timesteps": self.get_timesteps(viterbi_path[b].tolist()),
+                        # The transcript as a list of words.
+                        "words": post_process(self.tgt_dict.string(tokens), "letter").split(" "),
+                    }
+                ]
+            )
+
+        return hypos
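
These fields are what make word-level timestamps recoverable downstream. A sketch of one way to combine them; the "|" word delimiter and the 20 ms frame stride are assumptions for illustration, not values taken from this PR:

symbols = ["h", "i", "|", "y", "o", "u"]  # toy decoder output
timesteps = [3, 5, 8, 11, 13, 16]         # emission frame of each symbol
frame_stride_s = 0.02                     # hypothetical model frame stride

word_starts, word, start = [], "", None
for sym, frame in zip(symbols, timesteps):
    if sym == "|":  # letter-based fairseq dicts mark word boundaries with "|"
        if word:
            word_starts.append((word, round(start * frame_stride_s, 2)))
        word, start = "", None
        continue
    if start is None:  # first letter of a new word fixes its start frame
        start = frame
    word += sym
if word:
    word_starts.append((word, round(start * frame_stride_s, 2)))

print(word_starts)  # [('hi', 0.06), ('you', 0.22)]
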


 class W2lKenLMDecoder(W2lDecoder):
@@ -176,8 +241,13 @@ def __init__(self, args, tgt_dict):
                 self.unit_lm,
             )
         else:
-            assert args.unit_lm, "lexicon free decoding can only be done with a unit language model"
-            from flashlight.lib.text.decoder import LexiconFreeDecoder, LexiconFreeDecoderOptions
+            assert (
+                args.unit_lm
+            ), "lexicon free decoding can only be done with a unit language model"
+            from flashlight.lib.text.decoder import (
+                LexiconFreeDecoder,
+                LexiconFreeDecoderOptions,
+            )
 
             d = {w: [[w]] for w in tgt_dict.symbols}
             self.word_dict = create_word_dict(d)
@@ -195,27 +265,6 @@ def __init__(self, args, tgt_dict):
                 self.decoder_opts, self.lm, self.silence, self.blank, []
             )
 
-    def get_timesteps(self, token_idxs: List[int]) -> List[int]:
-        """Returns frame numbers corresponding to every non-blank token.
-
-        Parameters
-        ----------
-        token_idxs : List[int]
-            IDs of decoded tokens.
-
-        Returns
-        -------
-        List[int]
-            Frame numbers corresponding to every non-blank token.
-        """
-        timesteps = []
-        for i, token_idx in enumerate(token_idxs):
-            if token_idx == self.blank:
-                continue
-            if i == 0 or token_idx != token_idxs[i-1]:
-                timesteps.append(i)
-        return timesteps

     def decode(self, emissions):
         B, T, N = emissions.size()
         hypos = []
@@ -227,14 +276,22 @@ def decode(self, emissions):
             hypos.append(
                 [
                     {
-                        "tokens": self.get_tokens(result.tokens),
+                        # Non-blank token ids.
+                        "tokens": tokens,
+                        # Characters (symbols) corresponding to the non-blank token ids.
+                        "symbols": self.get_symbols(tokens),
                         "score": result.score,
-                        "timesteps": self.get_timesteps(result.tokens),
+                        # Frame numbers of the non-blank tokens.
+                        "timesteps": self.get_timesteps(result.tokens),
+                        # The transcript as a list of words; empty for lexicon-free decoding.
                         "words": [
                             self.word_dict.get_entry(x) for x in result.words if x >= 0
                         ],
                     }
                     for result in nbest_results
+                    # The walrus expression binds `tokens` for reuse above; note
+                    # that it also filters out results whose collapsed token
+                    # list is empty.
+                    if (tokens := self.get_tokens(result.tokens).tolist())
                 ]
             )
         return hypos
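
The comprehension above leans on an assignment expression in its filter clause. A self-contained sketch of that pattern with the collapse logic inlined (toy data, not from the diff):

blank = 0
results = [[3, 3, 0, 4], [0, 0]]  # toy per-frame token id lists; the second is all blanks

def collapse(xs):
    # Keep the first token of each non-blank run, mirroring get_tokens.
    return [x for i, x in enumerate(xs) if x != blank and (i == 0 or x != xs[i - 1])]

hypos = [
    {"tokens": tokens, "n": len(tokens)}
    for r in results
    if (tokens := collapse(r))  # binds `tokens` and drops empty (falsy) lists
]
print(hypos)  # [{'tokens': [3, 4], 'n': 2}]
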
@@ -440,8 +497,13 @@ def __init__(self, args, tgt_dict):
                 self.unit_lm,
             )
         else:
-            assert args.unit_lm, "lexicon free decoding can only be done with a unit language model"
-            from flashlight.lib.text.decoder import LexiconFreeDecoder, LexiconFreeDecoderOptions
+            assert (
+                args.unit_lm
+            ), "lexicon free decoding can only be done with a unit language model"
+            from flashlight.lib.text.decoder import (
+                LexiconFreeDecoder,
+                LexiconFreeDecoderOptions,
+            )
 
             d = {w: [[w]] for w in tgt_dict.symbols}
             self.word_dict = create_word_dict(d)
@@ -470,9 +532,23 @@ def idx_to_word(idx):
             return self.word_dict[idx]
 
         def make_hypo(result):
-            hypo = {"tokens": self.get_tokens(result.tokens), "score": result.score}
+            hypo = {
+                # Non-blank token ids.
+                "tokens": self.get_tokens(result.tokens).tolist(),
+                "score": result.score,
+            }
             if self.lexicon:
-                hypo["words"] = [idx_to_word(x) for x in result.words if x >= 0]
+                # The transcript as a list of words; only present when a lexicon is used.
+                hypo["words"] = [idx_to_word(x) for x in result.words if x >= 0]
+            # Characters (symbols) corresponding to the non-blank token ids.
+            hypo["symbols"] = self.get_symbols(hypo["tokens"])
+            # Frame numbers of the non-blank tokens.
+            hypo["timesteps"] = self.get_timesteps(result.tokens)
+
             return hypo

         for b in range(B):