Enabling word-level timestamps for all W2L Decoders #5403

Open · wants to merge 5 commits into main · showing changes from 3 commits
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
@@ -5,7 +5,7 @@ default_language_version:

 repos:
 - repo: https://github.com/pre-commit/pre-commit-hooks
-  rev: v4.1.0
+  rev: v4.5.0
   hooks:
   - id: trailing-whitespace
   - id: check-ast
@@ -17,13 +17,13 @@ repos:
   - id: end-of-file-fixer

 - repo: https://github.com/ambv/black
-  rev: 22.3.0
+  rev: 23.12.0
   hooks:
   - id: black
     language_version: python3.8

-- repo: https://gitlab.com/pycqa/flake8
-  rev: 3.9.2
+- repo: https://github.com/pycqa/flake8
+  rev: 6.1.0
   hooks:
   - id: flake8
     args: [
@@ -32,7 +32,7 @@
     ]

 - repo: https://github.com/pycqa/isort
-  rev: 5.10.1
+  rev: 5.13.2
   hooks:
   - id: isort
     exclude: README.md
106 changes: 72 additions & 34 deletions examples/speech_recognition/w2l_decoder.py
@@ -23,6 +23,7 @@
 from fairseq.utils import apply_to_sample
 from omegaconf import open_dict
 from fairseq.dataclass.utils import convert_namespace_to_omegaconf
+from fairseq.data.data_utils import post_process


 try:
@@ -51,6 +52,7 @@ def __init__(self, args, tgt_dict):
         self.tgt_dict = tgt_dict
         self.vocab_size = len(tgt_dict)
         self.nbest = args.nbest
+        self.symbols = tgt_dict.symbols  # symbols (usually characters) understood by the ASR model and predicted in the emission matrix

         # criterion-specific init
         self.criterion_type = CriterionType.CTC
@@ -93,11 +95,52 @@ def get_tokens(self, idxs):
         idxs = filter(lambda x: x != self.blank, idxs)
         return torch.LongTensor(list(idxs))

+    def get_timesteps(self, token_idxs: List[int]) -> List[int]:
+        """Returns frame numbers corresponding to every non-blank token.
+
+        Parameters
+        ----------
+        token_idxs : List[int]
+            IDs of decoded tokens (including blanks), i.e. one token per
+            frame of the emission matrix.
+
+        Returns
+        -------
+        List[int]
+            Frame numbers corresponding to every non-blank token.
+        """
+        timesteps = []
+        for i, token_idx in enumerate(token_idxs):
+            if token_idx == self.blank:
+                continue
+            if i == 0 or token_idx != token_idxs[i - 1]:
+                timesteps.append(i)
+
+        return timesteps
+
+    def get_symbols(self, token_idxs: List[int]) -> List[str]:
+        """Returns the character corresponding to every non-blank token.
+
+        Parameters
+        ----------
+        token_idxs : List[int]
+            IDs of non-blank tokens.
+
+        Returns
+        -------
+        List[str]
+            The character corresponding to each non-blank token.
+        """
+        chars = []
+        for token_idx in token_idxs:
+            chars.append(self.symbols[token_idx])
+
+        return chars


 class W2lViterbiDecoder(W2lDecoder):
     def __init__(self, args, tgt_dict):
         super().__init__(args, tgt_dict)

     def decode(self, emissions):
         B, T, N = emissions.size()
         hypos = []
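For intuition, here is a minimal standalone sketch (not part of the patch) of the collapse rule that get_timesteps implements: each non-blank token contributes the index of its first frame, repeated frames of the same token are skipped, and blanks separate genuine repeats. The blank id and the toy frame sequence below are made up.

from typing import List

BLANK = 0  # assumed blank token id for this sketch

def ctc_first_frames(token_idxs: List[int], blank: int = BLANK) -> List[int]:
    """Frame index of the first frame of each collapsed non-blank token."""
    frames = []
    for i, tok in enumerate(token_idxs):
        if tok == blank:
            continue  # blanks carry no symbol
        if i == 0 or tok != token_idxs[i - 1]:
            frames.append(i)  # first frame of a new token
    return frames

# frames:               0  1  2  3  4  5  6
print(ctc_first_frames([0, 1, 1, 0, 2, 0, 3]))  # -> [1, 4, 6]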
@@ -116,10 +159,21 @@ def decode(self, emissions):
             get_data_ptr_as_bytes(viterbi_path),
             get_data_ptr_as_bytes(workspace),
         )
-        return [
-            [{"tokens": self.get_tokens(viterbi_path[b].tolist()), "score": 0}]
-            for b in range(B)
-        ]

+        for b in range(B):
+            hypos.append(
+                [
+                    {
+                        "tokens": self.get_tokens(viterbi_path[b].tolist()).tolist(),  # non-blank token idxs
+                        "symbols": self.get_symbols(self.get_tokens(viterbi_path[b].tolist()).tolist()),  # characters (symbols) corresponding to the non-blank token idxs
+                        "score": 0,
+                        "timesteps": self.get_timesteps(viterbi_path[b].tolist()),  # frame numbers of the non-blank tokens
+                        "words": post_process(self.tgt_dict.string(self.get_tokens(viterbi_path[b].tolist()).int().cpu()), 'letter').split(' '),  # the transcript as a list of words
+                    }
+                ]
+            )
+
+        return hypos
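For reference, a single Viterbi hypothesis now carries these aligned fields. The values below are illustrative only; "|" is the word boundary symbol in fairseq letter vocabularies.

hypo = {
    "tokens": [14, 9, 4],        # collapsed non-blank token ids
    "symbols": ["h", "i", "|"],  # one character per entry in "tokens"
    "score": 0,                  # fixed at 0 by the Viterbi decoder
    "timesteps": [3, 5, 8],      # first emission frame of each token
    "words": ["hi"],             # post-processed transcript
}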


class W2lKenLMDecoder(W2lDecoder):
@@ -195,27 +249,6 @@ def __init__(self, args, tgt_dict):
                 self.decoder_opts, self.lm, self.silence, self.blank, []
             )

-    def get_timesteps(self, token_idxs: List[int]) -> List[int]:
-        """Returns frame numbers corresponding to every non-blank token.
-
-        Parameters
-        ----------
-        token_idxs : List[int]
-            IDs of decoded tokens.
-
-        Returns
-        -------
-        List[int]
-            Frame numbers corresponding to every non-blank token.
-        """
-        timesteps = []
-        for i, token_idx in enumerate(token_idxs):
-            if token_idx == self.blank:
-                continue
-            if i == 0 or token_idx != token_idxs[i-1]:
-                timesteps.append(i)
-        return timesteps
-
     def decode(self, emissions):
         B, T, N = emissions.size()
         hypos = []
@@ -227,12 +260,11 @@ def decode(self, emissions):
             hypos.append(
                 [
                     {
-                        "tokens": self.get_tokens(result.tokens),
+                        "tokens": self.get_tokens(result.tokens).tolist(),  # non-blank token idxs
+                        "symbols": self.get_symbols(self.get_tokens(result.tokens).tolist()),  # characters (symbols) corresponding to the non-blank token idxs
                         "score": result.score,
-                        "timesteps": self.get_timesteps(result.tokens),
-                        "words": [
-                            self.word_dict.get_entry(x) for x in result.words if x >= 0
-                        ],
+                        "timesteps": self.get_timesteps(result.tokens),  # frame numbers of the non-blank tokens
+                        "words": [self.word_dict.get_entry(x) for x in result.words if x >= 0],  # the transcript as a list of words; empty when decoding without a lexicon
                     }
                     for result in nbest_results
                 ]
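Note that timesteps are emission-frame indices, not seconds. Converting them requires the acoustic model's frame stride; the sketch below assumes the 20 ms stride (320-sample hop at 16 kHz) of wav2vec 2.0 style encoders, an assumption that should be verified for the model in use.

FRAME_STRIDE_SEC = 0.02  # assumption: one emission frame per 20 ms

def frames_to_seconds(timesteps, stride=FRAME_STRIDE_SEC):
    return [round(t * stride, 3) for t in timesteps]

print(frames_to_seconds([3, 7, 11]))  # -> [0.06, 0.14, 0.22]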
@@ -470,9 +502,15 @@ def idx_to_word(idx):
                 return self.word_dict[idx]

         def make_hypo(result):
-            hypo = {"tokens": self.get_tokens(result.tokens), "score": result.score}
+            hypo = {
+                "tokens": self.get_tokens(result.tokens).tolist(),  # non-blank token idxs
+                "score": result.score,
+            }
             if self.lexicon:
-                hypo["words"] = [idx_to_word(x) for x in result.words if x >= 0]
+                hypo["words"] = [idx_to_word(x) for x in result.words if x >= 0]  # the transcript as a list of words; only set when a lexicon is used
+            hypo["symbols"] = self.get_symbols(hypo["tokens"])  # characters (symbols) corresponding to the non-blank token idxs
+            hypo["timesteps"] = self.get_timesteps(result.tokens)  # frame numbers of the non-blank tokens
+
             return hypo

         for b in range(B):
@@ -483,4 +521,4 @@ def make_hypo(result):
             hypos.append([make_hypo(result) for result in nbest_results])
             self.lm.empty_cache()

-        return hypos
\ No newline at end of file
+        return hypos
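Taken together, the aligned symbols and timesteps are what enable word-level timestamps downstream: a word starts at the timestep of its first character. A consumer-side sketch under the same assumptions as above (made-up hypothesis values, "|" as the word separator, an assumed 20 ms frame stride):

from typing import List

WORD_SEP = "|"  # word boundary symbol in fairseq letter vocabularies

def word_start_frames(symbols: List[str], timesteps: List[int]) -> List[int]:
    """Emission-frame index at which each word begins."""
    starts, new_word = [], True
    for sym, frame in zip(symbols, timesteps):
        if sym == WORD_SEP:
            new_word = True  # the next non-separator symbol opens a word
        elif new_word:
            starts.append(frame)
            new_word = False
    return starts

# Toy hypothesis in the shape produced by the patched decoders:
hypo = {
    "symbols": ["h", "i", "|", "y", "o", "u"],
    "timesteps": [2, 4, 6, 9, 11, 13],
    "words": ["hi", "you"],
}
for word, start in zip(hypo["words"], word_start_frames(hypo["symbols"], hypo["timesteps"])):
    print(word, round(start * 0.02, 3))  # -> hi 0.04 / you 0.18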