Skip to content

Commit

Permalink
MMS TTS Romanian char fix + MPS support + full checkpoint (#5168)
Browse files Browse the repository at this point in the history
* Fix ț filtering in Romanian at inference

* mps support + full checkpoints (discriminator+optimizer)

---------

Co-authored-by: Bowen Shi <bshi@meta.com>
  • Loading branch information
chevalierNoir and Bowen Shi committed May 30, 2023
1 parent ae59bd6 commit 533644c
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 25 deletions.
6 changes: 6 additions & 0 deletions examples/mms/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ MMS-1B-all| 1162 | MMS-lab + FLEURS <br>+ CV + VP + MLS | [download](https://dl
wget https://dl.fbaipublicfiles.com/mms/tts/eng.tar.gz # English (eng)
wget https://dl.fbaipublicfiles.com/mms/tts/azj-script_latin.tar.gz # North Azerbaijani (azj-script_latin)
```
The above commands download the generator only, which is enough to run TTS inference. If you want the full model checkpoint, which also includes the discriminator (`D_100000.pth`) and the optimizer states, download it as follows.
```
# Example (full checkpoint: generator + discriminator + optimizer):
wget https://dl.fbaipublicfiles.com/mms/tts/full_model/eng.tar.gz # English (eng)
```


### LID

Expand Down
19 changes: 16 additions & 3 deletions examples/mms/tts/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,23 +69,36 @@ def get_text(self, text, hps):
text_norm = torch.LongTensor(text_norm)
return text_norm

def filter_oov(self, text):
def filter_oov(self, text, lang=None):
    """
    Remove every character that is not a key of the symbol table.

    The text is first run through ``preprocess_char`` with the given
    ``lang``; each remaining character is kept only if it is found in
    ``self._symbol_to_id``. The filtered string is printed and returned.
    """
    normalized = self.preprocess_char(text, lang=lang)
    known_symbols = self._symbol_to_id
    kept = [ch for ch in normalized if ch in known_symbols]
    txt_filt = "".join(kept)
    print(f"text after filtering OOV: {txt_filt}")
    return txt_filt

def preprocess_char(self, text, lang=None):
    """
    Apply language-specific character substitutions to the input text.

    Only Romanian (``lang == "ron"``) is handled: every "ț" is replaced
    with "ţ" and the converted text is printed. For any other ``lang``
    (including ``None``) the text is returned unchanged.
    NOTE(review): presumably the model's symbol table only contains the
    "ţ" form -- confirm against the checkpoint vocabulary.
    """
    if lang != "ron":
        return text
    converted = text.replace("ț", "ţ")
    print(f"{lang} (ț -> ţ): {converted}")
    return converted

def generate():
parser = argparse.ArgumentParser(description='TTS inference')
parser.add_argument('--model-dir', type=str, help='model checkpoint dir')
parser.add_argument('--wav', type=str, help='output wav path')
parser.add_argument('--txt', type=str, help='input text')
parser.add_argument('--uroman-dir', type=str, help='uroman lib dir (will download if not specified)')
parser.add_argument('--uroman-dir', type=str, default=None, help='uroman lib dir (will download if not specified)')
parser.add_argument('--lang', type=str, default=None, help='language iso code (required for Romanian)')
args = parser.parse_args()
ckpt_dir, wav_path, txt = args.model_dir, args.wav, args.txt

if torch.cuda.is_available():
device = torch.device("cuda")
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available() and torch.backends.mps.is_built():
device = torch.device("mps")
else:
device = torch.device("cpu")

Expand Down Expand Up @@ -122,7 +135,7 @@ def generate():
txt = text_mapper.uromanize(txt, uroman_pl)
print(f"uroman text: {txt}")
txt = txt.lower()
txt = text_mapper.filter_oov(txt)
txt = text_mapper.filter_oov(txt, lang=args.lang)
stn_tst = text_mapper.get_text(txt, hps)
with torch.no_grad():
x_tst = stn_tst.unsqueeze(0).to(device)
Expand Down
56 changes: 34 additions & 22 deletions examples/mms/tts/tutorial/MMS_TTS_Inference_Colab.ipynb

Large diffs are not rendered by default.

0 comments on commit 533644c

Please sign in to comment.