-
Notifications
You must be signed in to change notification settings - Fork 6.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
16 changed files
with
1,111 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
# MR-HuBERT | ||
|
||
## Pre-trained and fine-tuned (ASR) models | ||
Model | Pretraining Data | Finetuning Dataset | Model | Quantizer | ||
|---|---|---|---|--- | ||
|
||
## Load a model | ||
``` | ||
ckpt_path = "/path/to/the/checkpoint.pt" | ||
models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path]) | ||
model = models[0] | ||
``` | ||
|
||
## Train a new model | ||
|
||
### Data preparation | ||
|
||
Follow the steps in `./simple_kmeans` to create: | ||
- `{train,valid}.tsv` waveform list files with length information | ||
``` | ||
/path/to/your/audio/files | ||
file1.wav\t160000 | ||
file2.wav\t154600 | ||
... | ||
filen.wav\t54362 | ||
``` | ||
- `{train,valid}.km` frame-aligned pseudo label files (the order is the same as wavefiles in the tsv file). | ||
``` | ||
44 44 44 48 48 962 962 962 962 962 962 962 962 967 967 967 967 967 967 967 967 370 852 370 ... 18 18 745 745 | ||
44 44 44 48 48 962 962 962 147 147 147 147 147 147 147 147 147 147 147 147 176 176 271 271 ... 27 27 745 745 | ||
... | ||
44 44 44 48 962 962 962 962 962 962 377 377 377 77 77 852 696 694 433 578 578 82 740 622 ... 27 27 745 745 | ||
``` | ||
- `dict.km.txt` a dummy dictionary (first column is id, the second is dummy one) | ||
``` | ||
0 1 | ||
1 1 | ||
2 1 | ||
... | ||
999 1 | ||
``` | ||
|
||
The `label_rate` is the same as the feature frame rate used for clustering, | ||
which is 100Hz for MFCC features and 50Hz for HuBERT features by default. | ||
|
||
### Pre-train a MR-HuBERT model | ||
|
||
Suppose `{train,valid}.tsv` are saved at `/path/to/data`, `{train,valid}.km` | ||
are saved at `/path/to/labels`, and the label rate is 100Hz. | ||
|
||
To train a base model (12 layer transformer), run: | ||
```sh | ||
$ python fairseq_cli/hydra_train.py \ | ||
--config-dir /path/to/fairseq-py/examples/mr_hubert/config/pretrain \ | ||
--config-name mrhubert_base_librispeech \ | ||
task.data=/path/to/data task.label_dir=/path/to/labels \ | ||
task.labels='["km"]' model.label_rate=100 \ | ||
task.label_rate_ratios='[1, 2]' \ | ||
``` | ||
|
||
Please see sample pre-training scripts `train.sh` for an example script. | ||
|
||
### Fine-tune a MR-HuBERT model with a CTC loss | ||
|
||
Suppose `{train,valid}.tsv` are saved at `/path/to/data`, and their | ||
corresponding character transcripts `{train,valid}.ltr` are saved at | ||
`/path/to/trans`. A typical ltr file is with the same order of tsv waveform files as | ||
``` | ||
HOW | ARE | YOU | ||
... | ||
THANK | YOU | ||
``` | ||
|
||
To fine-tune a pre-trained MR-HuBERT model at `/path/to/checkpoint`, run | ||
```sh | ||
$ python fairseq_cli/hydra_train.py \ | ||
--config-dir /path/to/fairseq-py/examples/mr_hubert/config/finetune \ | ||
--config-name base_10h \ | ||
task.data=/path/to/data task.label_dir=/path/to/trans \ | ||
model.w2v_path=/path/to/checkpoint | ||
``` | ||
|
||
Please see sample pre-training scripts `finetune.sh` for an example script. | ||
|
||
### Decode a MR-HuBERT model | ||
|
||
Suppose the `test.tsv` and `test.ltr` are the waveform list and transcripts of | ||
the split to be decoded, saved at `/path/to/data`, and the fine-tuned model is | ||
saved at `/path/to/checkpoint`. | ||
|
||
|
||
We support three decoding modes: | ||
- Viterbi decoding: greedy decoding without a language model | ||
- KenLM decoding: decoding with an arpa-format KenLM n-gram language model | ||
- Fairseq-LM deocding: decoding with a Fairseq neural language model (not fully tested) | ||
|
||
|
||
#### Viterbi decoding | ||
|
||
`task.normalize` needs to be consistent with the value used during fine-tuning. | ||
Decoding results will be saved at | ||
`/path/to/experiment/directory/decode/viterbi/test`. | ||
|
||
```sh | ||
$ python examples/speech_recognition/new/infer.py \ | ||
--config-dir /path/to/fairseq-py/examples/mr_hubert/config/decode \ | ||
--config-name infer \ | ||
task.data=/path/to/data \ | ||
task.normalize=[true|false] \ | ||
decoding.exp_dir=/path/to/experiment/directory \ | ||
common_eval.path=/path/to/checkpoint | ||
dataset.gen_subset=test \ | ||
``` | ||
|
||
#### KenLM / Fairseq-LM decoding | ||
|
||
Suppose the pronunciation lexicon and the n-gram LM are saved at | ||
`/path/to/lexicon` and `/path/to/arpa`, respectively. Decoding results will be | ||
saved at `/path/to/experiment/directory/decode/kenlm/test`. | ||
|
||
```sh | ||
$ python examples/speech_recognition/new/infer.py \ | ||
--config-dir /path/to/fairseq-py/examples/mr_hubert/config/decode \ | ||
--config-name infer_lm \ | ||
task.data=/path/to/data \ | ||
task.normalize=[true|false] \ | ||
decoding.exp_dir=/path/to/experiment/directory \ | ||
common_eval.path=/path/to/checkpoint | ||
dataset.gen_subset=test \ | ||
decoding.decoder.lexicon=/path/to/lexicon \ | ||
decoding.decoder.lmpath=/path/to/arpa | ||
``` | ||
|
||
The command above uses the default decoding hyperparameter, which can be found | ||
in `examples/speech_recognition/hydra/decoder.py`. These parameters can be | ||
configured from the command line. For example, to search with a beam size of | ||
500, we can append the command above with `decoding.decoder.beam=500`. | ||
Important parameters include: | ||
- decoding.decoder.beam | ||
- decoding.decoder.beamthreshold | ||
- decoding.decoder.lmweight | ||
- decoding.decoder.wordscore | ||
- decoding.decoder.silweight | ||
|
||
To decode with a Fairseq LM, you may check the usage examples in wav2vec2 or hubert examples. | ||
|
||
Please see sample pre-training scripts `decode.sh` for an example script. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# @package _group_ | ||
|
||
defaults: | ||
- model: null | ||
|
||
hydra: | ||
run: | ||
dir: ${common_eval.results_path}/viterbi | ||
sweep: | ||
dir: ${common_eval.results_path} | ||
subdir: viterbi | ||
|
||
task: | ||
_name: multires_hubert_pretraining | ||
single_target: true | ||
fine_tuning: true | ||
label_rate_ratios: ??? | ||
data: ??? | ||
normalize: false | ||
|
||
decoding: | ||
type: viterbi | ||
unique_wer_file: true | ||
common_eval: | ||
results_path: ??? | ||
path: ??? | ||
post_process: letter | ||
dataset: | ||
max_tokens: 1100000 | ||
gen_subset: ??? |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# @package _group_ | ||
|
||
defaults: | ||
- model: null | ||
|
||
hydra: | ||
run: | ||
dir: ${common_eval.results_path}/beam${decoding.beam}_th${decoding.beamthreshold}_lmw${decoding.lmweight}_wrd${decoding.wordscore}_sil${decoding.silweight} | ||
sweep: | ||
dir: ${common_eval.results_path} | ||
subdir: beam${decoding.beam}_th${decoding.beamthreshold}_lmw${decoding.lmweight}_wrd${decoding.wordscore}_sil${decoding.silweight} | ||
|
||
task: | ||
_name: multires_hubert_pretraining | ||
single_target: true | ||
fine_tuning: true | ||
data: ??? | ||
label_rate_ratios: ??? | ||
normalize: ??? | ||
|
||
decoding: | ||
type: kenlm | ||
lexicon: ??? | ||
lmpath: ??? | ||
beamthreshold: 100 | ||
beam: 500 | ||
lmweight: 1.5 | ||
wordscore: -1 | ||
silweight: 0 | ||
unique_wer_file: true | ||
common_eval: | ||
results_path: ??? | ||
path: ??? | ||
post_process: letter | ||
dataset: | ||
max_tokens: 1100000 | ||
gen_subset: ??? |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# @package _global_ | ||
hydra: | ||
launcher: | ||
cpus_per_task: ${distributed_training.distributed_world_size} | ||
gpus_per_node: ${distributed_training.distributed_world_size} | ||
tasks_per_node: ${hydra.launcher.gpus_per_node} | ||
nodes: 1 | ||
mem_gb: 200 | ||
timeout_min: 4320 | ||
max_num_timeout: 50 | ||
name: ${hydra.job.config_name} | ||
submitit_folder: ${hydra.sweep.dir}/submitit | ||
|
||
distributed_training: | ||
distributed_world_size: 1 | ||
distributed_no_spawn: true | ||
distributed_port: 29761 |
17 changes: 17 additions & 0 deletions
17
examples/mr_hubert/config/decode/run/submitit_slurm_8gpu.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# @package _global_ | ||
hydra: | ||
launcher: | ||
cpus_per_task: ${distributed_training.distributed_world_size} | ||
gpus_per_node: ${distributed_training.distributed_world_size} | ||
tasks_per_node: ${hydra.launcher.gpus_per_node} | ||
nodes: 1 | ||
mem_gb: 200 | ||
timeout_min: 4320 | ||
max_num_timeout: 50 | ||
name: ${hydra.job.config_name} | ||
submitit_folder: ${hydra.sweep.dir}/submitit | ||
|
||
distributed_training: | ||
distributed_world_size: 8 | ||
distributed_no_spawn: true | ||
distributed_port: 29761 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
# @package _group_ | ||
|
||
common: | ||
fp16: true | ||
log_format: json | ||
log_interval: 200 | ||
tensorboard_logdir: tblog | ||
seed: 1337 | ||
|
||
checkpoint: | ||
no_epoch_checkpoints: true | ||
best_checkpoint_metric: wer | ||
|
||
distributed_training: | ||
ddp_backend: c10d | ||
find_unused_parameters: true | ||
distributed_world_size: 8 | ||
distributed_port: 29671 | ||
nprocs_per_node: 8 | ||
|
||
task: | ||
_name: multires_hubert_pretraining | ||
data: ??? | ||
fine_tuning: true | ||
label_dir: ??? | ||
label_rate_ratios: ??? | ||
normalize: false # must be consistent with pre-training | ||
labels: ["ltr"] | ||
single_target: true | ||
|
||
dataset: | ||
num_workers: 0 | ||
max_tokens: 3200000 | ||
validate_after_updates: ${model.freeze_finetune_updates} | ||
validate_interval: 5 | ||
train_subset: train_100h | ||
valid_subset: dev_other | ||
|
||
criterion: | ||
_name: ctc | ||
zero_infinity: true | ||
|
||
optimization: | ||
max_update: 80000 | ||
lr: [3e-5] | ||
sentence_avg: true | ||
update_freq: [1] | ||
|
||
optimizer: | ||
_name: adam | ||
adam_betas: (0.9,0.98) | ||
adam_eps: 1e-08 | ||
|
||
lr_scheduler: | ||
_name: tri_stage | ||
phase_ratio: [0.1, 0.4, 0.5] | ||
final_lr_scale: 0.05 | ||
|
||
model: | ||
_name: multires_hubert_ctc | ||
multires_hubert_path: ??? | ||
apply_mask: true | ||
mask_selection: static | ||
mask_length: 10 | ||
mask_other: 0 | ||
mask_prob: 0.75 | ||
mask_channel_selection: static | ||
mask_channel_length: 64 | ||
mask_channel_other: 0 | ||
mask_channel_prob: 0.5 | ||
layerdrop: 0.1 | ||
dropout: 0.0 | ||
activation_dropout: 0.1 | ||
attention_dropout: 0.0 | ||
feature_grad_mult: 0.0 | ||
freeze_finetune_updates: 10000 | ||
|
||
hydra: | ||
job: | ||
config: | ||
override_dirname: | ||
kv_sep: '-' | ||
item_sep: '__' | ||
exclude_keys: | ||
- run | ||
- task.data | ||
- task.label_dir | ||
- model.multires_hubert_path | ||
- dataset.train_subset | ||
- dataset.valid_subset | ||
- criterion.wer_kenlm_model | ||
- criterion.wer_lexicon | ||
run: | ||
dir: ??? | ||
sweep: | ||
dir: ??? | ||
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname} |
Oops, something went wrong.