Skip to content

Commit

Permalink
Support phonetic transcriptions
Browse files Browse the repository at this point in the history
  • Loading branch information
neonbjb committed May 10, 2022
1 parent 1a2fee9 commit 260bdf7
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 16 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Expand Up @@ -103,3 +103,5 @@ venv.bak/
/torchscript/traced_model*
/ocotillo.onnx
/upload_notes.txt
/torchscript/traced_wav2vec2_large_robust_ft_libritts_voxpopuli_cuda.pth
/torchscript/traced_wav2vec2_lv_60_espeak_cv_ft_cuda.pth
12 changes: 6 additions & 6 deletions ocotillo/api.py
Expand Up @@ -12,12 +12,12 @@ def _append_with_at_least_one_space(text, new_text):
return text + new_text

class Transcriber:
def __init__(self, on_cuda=True, cuda_device=0):
def __init__(self, phonetic=False, on_cuda=True, cuda_device=0):
if on_cuda:
self.device = 'cuda'
self.device = f'cuda:{cuda_device}'
else:
self.device = 'cpu'
self.model, self.processor = load_model(self.device)
self.model, self.processor = load_model(self.device, phonetic=phonetic)

def transcribe(self, audio_data, sample_rate):
"""
Expand Down Expand Up @@ -94,10 +94,10 @@ def _process_large_clip(self, audio_data, sample_rate):


if __name__ == '__main__':
transcriber = Transcriber(on_cuda=True)
audio = load_audio('data/obama.mp3', 44100)
transcriber = Transcriber(on_cuda=True, phonetic=True)
audio = load_audio('../data/obama.mp3', 44100)
print(transcriber.transcribe(audio, 44100))
start = time()
audio = load_audio('data/obama_long.mp3', 16000)
audio = load_audio('../data/obama_long.mp3', 16000)
print(transcriber.transcribe(audio, 16000))
print(f"Elapsed: {time() - start}")
21 changes: 12 additions & 9 deletions ocotillo/model_loader.py
Expand Up @@ -5,26 +5,30 @@
from transformers import Wav2Vec2ForCTC, Wav2Vec2FeatureExtractor, Wav2Vec2CTCTokenizer, Wav2Vec2Processor


def load_model(device, use_torchscript=False):
def load_model(device, phonetic=False, use_torchscript=False):
"""
Utility function to load the model and corresponding processor to the specified device. Supports loading
torchscript models when they have been pre-built (which is accomplished by running this file.)
"""
model_name = "facebook/wav2vec2-lv-60-espeak-cv-ft" if phonetic else "jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli"
if use_torchscript:
model = trace_torchscript_model('cuda' if 'cuda' in device else 'cpu')
model = trace_torchscript_model(model_name.split('/')[-1].replace('-', '_'), 'cuda' if 'cuda' in device else 'cpu')
model = model.to(device)
else:
model = Wav2Vec2ForCTC.from_pretrained("jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli").to(device)
model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
model.config.return_dict = False
model.eval()
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"facebook/wav2vec2-large-960h")
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('jbetker/tacotron_symbols')
processor = Wav2Vec2Processor(feature_extractor, tokenizer)
if phonetic:
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-lv-60-espeak-cv-ft")
else:
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"facebook/wav2vec2-large-960h")
tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('jbetker/tacotron-symbols')
processor = Wav2Vec2Processor(feature_extractor, tokenizer)
return model, processor


def trace_torchscript_model(dev_type='cpu', load_from_cache=True):
output_trace_cache_file = f'torchscript/traced_model_{dev_type}.pth'
def trace_torchscript_model(model_name, dev_type='cpu', load_from_cache=True):
output_trace_cache_file = f'torchscript/traced_{model_name}_{dev_type}.pth'
if load_from_cache and os.path.exists(output_trace_cache_file):
return torch.jit.load(output_trace_cache_file)

Expand Down Expand Up @@ -79,4 +83,3 @@ def test_onnx_model():
if __name__ == '__main__':
trace_onnx_model()
test_onnx_model()
#trace_torchscript_model('cuda', load_from_cache=False)
3 changes: 2 additions & 1 deletion ocotillo/transcribe.py
Expand Up @@ -18,13 +18,14 @@
parser = argparse.ArgumentParser()
parser.add_argument('--path', help='Input folder containing audio files you want transcribed.')
parser.add_argument('--output_file', default='results.tsv', help='Where transcriptions will be placed.')
parser.add_argument('--phonetic', default=False, help='Whether or not to output phonetic symbols.')
parser.add_argument('--resume', default=0, type=int, help='Skip the first <n> audio tracks.')
parser.add_argument('--batch_size', default=8, type=int, help='Number of audio files to process at a time. Larger batches are more efficient on a GPU.')
parser.add_argument('--cuda', default=-1, type=int, help='The cuda device to perform inference on. -1 (or default) means use the CPU.')
parser.add_argument('--output_tokens', default=False, type=bool, help='Whether or not to output the CTC codes. Useful for text alignment.')
args = parser.parse_args()

model, processor = load_model(f'cuda:{args.cuda}' if args.cuda != -1 else 'cpu', use_torchscript=True)
model, processor = load_model(f'cuda:{args.cuda}' if args.cuda != -1 else 'cpu', use_torchscript=True, phonetic=args.phonetic)
dataset = AudioFolderDataset(args.path, sampling_rate=16000, pad_to=566400, skip=args.resume)
dataloader = DataLoader(dataset, args.batch_size, num_workers=2)

Expand Down

0 comments on commit 260bdf7

Please sign in to comment.