Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multires hubert #5363

Merged
merged 10 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
185 changes: 185 additions & 0 deletions examples/mr_hubert/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
# MR-HuBERT

## Pre-trained models

### Main models
Model | Pretraining Data | Model | Paper Reference
|---|---|---|---
MR-HuBERT Base (~97M) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/mono_base/mrhubert_mono_base.pt) | mono\_base
MR-HuBERT Base (~321M) | [Libri-Light](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/mono_large/mrhubert_mono_large.pt) | mono\_large
Multilingual MR-HuBERT Base (~97M) | [Voxpopuli]() 100k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/multi_base/multi_base.pt) | multi\_base
Multilingual MR-HuBERT Large (~321M) | [Voxpopuli]() 100k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/multi_large/multi_large.pt) | Not in the paper


### Abalation models
Model | Pretraining Data | Model | Paper Reference
|---|---|---|---
MR-HuBERT Base (2-4-6 lyrs) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b1-a/b1-a.pt) | (B.1)-a
MR-HuBERT Base (5-2-5 lyrs) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b1-b/b1-b.pt) | (B.1)-b
MR-HuBERT Base (6-4-2 lyrs) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b1-c/b1-c.pt) | (B.1)-c
MR-HuBERT Base (3res 3-2-2-2-3 lyrs) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b2-a/b2-a.pt) | (B.2)-a
MR-HuBERT Base (3res 2-2-4-2-2 lyrs) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b2-b/b2-b.pt) | (B.2)-b
MR-HuBERT Base (3res 2-2-2-2-2 lyrs) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b2-c/b2-c.pt) | (B.2)-c
MR-HuBERT Base (Simple sampling) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b3-a/b3-a.pt) | (B.3)-a
MR-HuBERT Base (Single target) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b4-a/b4-a.pt) | (B.4)-a
MR-HuBERT Base (Simple Sampling + single target) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b4-b/b4-b.pt) | (B.4)-b
MR-HuBERT Base (Mono-resolution 20ms) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b5-a/b5-a.pt) | (B.5)-a
MR-HuBERT Base (3-3-3 lyrs) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b6-a/b6-a.pt) | (B.6)-a
MR-HuBERT Base (Mono-resolution 20ms, 3-3-3 lyrs) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b6-b/b6-b.pt) | (B.6)-b
MR-HuBERT Base (HuBERT 20ms&40ms units) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b7-a/b7-a.pt) | (B.7)-a
MR-HuBERT Base (Encodec 50Hz unit) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b7-b/b7-b.pt) | (B.7)-b
MR-HuBERT Base (Encodec 50Hz units and 25Hz units) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b7-c/b7-c.pt) | (B.7)-c
MR-HuBERT Base (Encodec 50Hz units stream 0&1 ) | [Librispeech](http://www.openslr.org/12) 960 hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b7-d/b7-d.pt) | (B.7)-d
MR-HuBERT Large (no audio norm) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b8-a/b8-a.pt) | (B.8)-a
MR-HuBERT Large (check paper ) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b8-a/b8-b.pt) | (B.8)-b
MR-HuBERT Large (check paper ) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b8-a/b8-c.pt) | (B.8)-c
MR-HuBERT Large (check paper ) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b8-a/b8-d.pt) | (B.8)-d
MR-HuBERT Large (check paper ) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b8-a/b8-e.pt) | (B.8)-e
MR-HuBERT Large (check paper ) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b8-a/b8-f.pt) | (B.8)-f
MR-HuBERT Large (check paper ) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b8-a/b8-g.pt) | (B.8)-g
MR-HuBERT Large (check paper ) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b8-a/b8-h.pt) | (B.8)-h
MR-HuBERT Large (check paper ) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b8-a/b8-i.pt) | (B.8)-i
MR-HuBERT Large (check paper ) | [LibriLight](https://github.com/facebookresearch/libri-light) 60k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/b8-a/b8-j.pt) | (B.8)-j
Multilingual MR-HuBERT Large (Simple sampling) | [Voxpopuli]() 100k hr | [download](https://dl.fbaipublicfiles.com/mrhubert/multi_large_simple/multi_large_simple.pt)

## Load a model
```
ckpt_path = "/path/to/the/checkpoint.pt"
models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
model = models[0]
```

## Train a new model

### Data preparation

Follow the steps in `./simple_kmeans` to create:
- `{train,valid}.tsv` waveform list files with length information
```
/path/to/your/audio/files
file1.wav\t160000
file2.wav\t154600
...
filen.wav\t54362
```
- `{train,valid}.km` frame-aligned pseudo label files (the order is the same as wavefiles in the tsv file).
```
44 44 44 48 48 962 962 962 962 962 962 962 962 967 967 967 967 967 967 967 967 370 852 370 ... 18 18 745 745
44 44 44 48 48 962 962 962 147 147 147 147 147 147 147 147 147 147 147 147 176 176 271 271 ... 27 27 745 745
...
44 44 44 48 962 962 962 962 962 962 377 377 377 77 77 852 696 694 433 578 578 82 740 622 ... 27 27 745 745
```
- `dict.km.txt` a dummy dictionary (first column is id, the second is dummy one)
```
0 1
1 1
2 1
...
999 1
```

The `label_rate` is the same as the feature frame rate used for clustering,
which is 100Hz for MFCC features and 50Hz for HuBERT features by default.

### Pre-train a MR-HuBERT model

Suppose `{train,valid}.tsv` are saved at `/path/to/data`, `{train,valid}.km`
are saved at `/path/to/labels`, and the label rate is 100Hz.

To train a base model (12 layer transformer), run:
```sh
$ python fairseq_cli/hydra_train.py \
--config-dir /path/to/fairseq-py/examples/mr_hubert/config/pretrain \
--config-name mrhubert_base_librispeech \
task.data=/path/to/data task.label_dir=/path/to/labels \
task.labels='["km"]' model.label_rate=100 \
task.label_rate_ratios='[1, 2]' \
```

Please see sample pre-training scripts `train.sh` for an example script.

### Fine-tune a MR-HuBERT model with a CTC loss

Suppose `{train,valid}.tsv` are saved at `/path/to/data`, and their
corresponding character transcripts `{train,valid}.ltr` are saved at
`/path/to/trans`. A typical ltr file is with the same order of tsv waveform files as
```
HOW | ARE | YOU
...
THANK | YOU
```

To fine-tune a pre-trained MR-HuBERT model at `/path/to/checkpoint`, run
```sh
$ python fairseq_cli/hydra_train.py \
--config-dir /path/to/fairseq-py/examples/mr_hubert/config/finetune \
--config-name base_10h \
task.data=/path/to/data task.label_dir=/path/to/trans \
model.w2v_path=/path/to/checkpoint
```

Please see sample pre-training scripts `finetune.sh` for an example script.

### Decode a MR-HuBERT model

Suppose the `test.tsv` and `test.ltr` are the waveform list and transcripts of
the split to be decoded, saved at `/path/to/data`, and the fine-tuned model is
saved at `/path/to/checkpoint`.


We support three decoding modes:
- Viterbi decoding: greedy decoding without a language model
- KenLM decoding: decoding with an arpa-format KenLM n-gram language model
- Fairseq-LM deocding: decoding with a Fairseq neural language model (not fully tested)


#### Viterbi decoding

`task.normalize` needs to be consistent with the value used during fine-tuning.
Decoding results will be saved at
`/path/to/experiment/directory/decode/viterbi/test`.

```sh
$ python examples/speech_recognition/new/infer.py \
--config-dir /path/to/fairseq-py/examples/mr_hubert/config/decode \
--config-name infer \
task.data=/path/to/data \
task.normalize=[true|false] \
decoding.exp_dir=/path/to/experiment/directory \
common_eval.path=/path/to/checkpoint
dataset.gen_subset=test \
```

#### KenLM / Fairseq-LM decoding

Suppose the pronunciation lexicon and the n-gram LM are saved at
`/path/to/lexicon` and `/path/to/arpa`, respectively. Decoding results will be
saved at `/path/to/experiment/directory/decode/kenlm/test`.

```sh
$ python examples/speech_recognition/new/infer.py \
--config-dir /path/to/fairseq-py/examples/mr_hubert/config/decode \
--config-name infer_lm \
task.data=/path/to/data \
task.normalize=[true|false] \
decoding.exp_dir=/path/to/experiment/directory \
common_eval.path=/path/to/checkpoint
dataset.gen_subset=test \
decoding.decoder.lexicon=/path/to/lexicon \
decoding.decoder.lmpath=/path/to/arpa
```

The command above uses the default decoding hyperparameter, which can be found
in `examples/speech_recognition/hydra/decoder.py`. These parameters can be
configured from the command line. For example, to search with a beam size of
500, we can append the command above with `decoding.decoder.beam=500`.
Important parameters include:
- decoding.decoder.beam
- decoding.decoder.beamthreshold
- decoding.decoder.lmweight
- decoding.decoder.wordscore
- decoding.decoder.silweight

To decode with a Fairseq LM, you may check the usage examples in wav2vec2 or hubert examples.

Please see sample pre-training scripts `decode.sh` for an example script.
30 changes: 30 additions & 0 deletions examples/mr_hubert/config/decode/infer.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# @package _group_

defaults:
- model: null

hydra:
run:
dir: ${common_eval.results_path}/viterbi
sweep:
dir: ${common_eval.results_path}
subdir: viterbi

task:
_name: multires_hubert_pretraining
single_target: true
fine_tuning: true
label_rate_ratios: ???
data: ???
normalize: false

decoding:
type: viterbi
unique_wer_file: true
common_eval:
results_path: ???
path: ???
post_process: letter
dataset:
max_tokens: 1100000
gen_subset: ???
37 changes: 37 additions & 0 deletions examples/mr_hubert/config/decode/infer_lm.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# @package _group_

defaults:
- model: null

hydra:
run:
dir: ${common_eval.results_path}/beam${decoding.beam}_th${decoding.beamthreshold}_lmw${decoding.lmweight}_wrd${decoding.wordscore}_sil${decoding.silweight}
sweep:
dir: ${common_eval.results_path}
subdir: beam${decoding.beam}_th${decoding.beamthreshold}_lmw${decoding.lmweight}_wrd${decoding.wordscore}_sil${decoding.silweight}

task:
_name: multires_hubert_pretraining
single_target: true
fine_tuning: true
data: ???
label_rate_ratios: ???
normalize: ???

decoding:
type: kenlm
lexicon: ???
lmpath: ???
beamthreshold: 100
beam: 500
lmweight: 1.5
wordscore: -1
silweight: 0
unique_wer_file: true
common_eval:
results_path: ???
path: ???
post_process: letter
dataset:
max_tokens: 1100000
gen_subset: ???
17 changes: 17 additions & 0 deletions examples/mr_hubert/config/decode/run/submitit_slurm.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# @package _global_
hydra:
launcher:
cpus_per_task: ${distributed_training.distributed_world_size}
gpus_per_node: ${distributed_training.distributed_world_size}
tasks_per_node: ${hydra.launcher.gpus_per_node}
nodes: 1
mem_gb: 200
timeout_min: 4320
max_num_timeout: 50
name: ${hydra.job.config_name}
submitit_folder: ${hydra.sweep.dir}/submitit

distributed_training:
distributed_world_size: 1
distributed_no_spawn: true
distributed_port: 29761
17 changes: 17 additions & 0 deletions examples/mr_hubert/config/decode/run/submitit_slurm_8gpu.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# @package _global_
hydra:
launcher:
cpus_per_task: ${distributed_training.distributed_world_size}
gpus_per_node: ${distributed_training.distributed_world_size}
tasks_per_node: ${hydra.launcher.gpus_per_node}
nodes: 1
mem_gb: 200
timeout_min: 4320
max_num_timeout: 50
name: ${hydra.job.config_name}
submitit_folder: ${hydra.sweep.dir}/submitit

distributed_training:
distributed_world_size: 8
distributed_no_spawn: true
distributed_port: 29761
97 changes: 97 additions & 0 deletions examples/mr_hubert/config/finetune/base_100h.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# @package _group_

common:
fp16: true
log_format: json
log_interval: 200
tensorboard_logdir: tblog
seed: 1337

checkpoint:
no_epoch_checkpoints: true
best_checkpoint_metric: wer

distributed_training:
ddp_backend: c10d
find_unused_parameters: true
distributed_world_size: 8
distributed_port: 29671
nprocs_per_node: 8

task:
_name: multires_hubert_pretraining
data: ???
fine_tuning: true
label_dir: ???
label_rate_ratios: ???
normalize: false # must be consistent with pre-training
labels: ["ltr"]
single_target: true

dataset:
num_workers: 0
max_tokens: 3200000
validate_after_updates: ${model.freeze_finetune_updates}
validate_interval: 5
train_subset: train_100h
valid_subset: dev_other

criterion:
_name: ctc
zero_infinity: true

optimization:
max_update: 80000
lr: [3e-5]
sentence_avg: true
update_freq: [1]

optimizer:
_name: adam
adam_betas: (0.9,0.98)
adam_eps: 1e-08

lr_scheduler:
_name: tri_stage
phase_ratio: [0.1, 0.4, 0.5]
final_lr_scale: 0.05

model:
_name: multires_hubert_ctc
multires_hubert_path: ???
apply_mask: true
mask_selection: static
mask_length: 10
mask_other: 0
mask_prob: 0.75
mask_channel_selection: static
mask_channel_length: 64
mask_channel_other: 0
mask_channel_prob: 0.5
layerdrop: 0.1
dropout: 0.0
activation_dropout: 0.1
attention_dropout: 0.0
feature_grad_mult: 0.0
freeze_finetune_updates: 10000

hydra:
job:
config:
override_dirname:
kv_sep: '-'
item_sep: '__'
exclude_keys:
- run
- task.data
- task.label_dir
- model.multires_hubert_path
- dataset.train_subset
- dataset.valid_subset
- criterion.wer_kenlm_model
- criterion.wer_lexicon
run:
dir: ???
sweep:
dir: ???
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}