-
Notifications
You must be signed in to change notification settings - Fork 4
/
generate.py
85 lines (66 loc) · 2.44 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
"""Extract vocals from waveform.
usage: generate.py [options] <checkpoint-path> <input-wav>
options:
--output-dir=<dir> Directory where to save output wav [default: generated].
--sr=<sr> Sample rate of generated waveform
-h, --help Show this help message and exit
"""
from docopt import docopt
import os
from os.path import dirname, join, expanduser
import random
from tqdm import tqdm
import numpy as np
import librosa
import librosa.display
import librosa.output
import torch
from audio import *
from model import build_model
from hparams import hparams as hp
from utils import resample
use_cuda = torch.cuda.is_available()
def _load(checkpoint_path):
if use_cuda:
checkpoint = torch.load(checkpoint_path)
else:
checkpoint = torch.load(checkpoint_path,
map_location=lambda storage, loc: storage)
return checkpoint
def load_checkpoint(path, model):
    """Restore trained weights into *model* from the checkpoint at *path*.

    Returns the same ``model`` instance with its parameters replaced by
    the checkpoint's "state_dict".
    """
    print("Load checkpoint from: {}".format(path))
    state = _load(path)["state_dict"]
    model.load_state_dict(state)
    return model
def generate(device, model, path, output_dir, target_sr):
    """Separate the wav at *path* into vocal and accompaniment stems.

    Args:
        device: torch device the model runs on.
        model: Trained separation model exposing ``generate_wav``.
        path: Path of the input wav file.
        output_dir: Directory where the two stem wavs are written.
        target_sr: Sample rate for the saved output files.

    Writes ``<stem>_vocals.wav`` and ``<stem>_accompaniment.wav`` into
    *output_dir*.
    """
    wav = load_wav(path)
    estimates = model.generate_wav(device, wav)
    if target_sr != hp.sample_rate:
        # NOTE(review): the return value is discarded, so resample is
        # assumed to modify `estimates` in place — confirm in utils.resample.
        resample(estimates, target_sr)
    # os.path handles both '/' and the OS-specific separator (the old
    # split('/') broke on Windows paths), and splitext strips only the
    # final extension, so dotted names like "take.2.wav" keep their full
    # stem instead of being truncated at the first dot.
    file_id = os.path.splitext(os.path.basename(path))[0]
    vox_outpath = os.path.join(output_dir, f'{file_id}_vocals.wav')
    bg_outpath = os.path.join(output_dir, f'{file_id}_accompaniment.wav')
    save_wav(estimates['vocals'], vox_outpath, sr=target_sr)
    save_wav(estimates['accompaniment'], bg_outpath, sr=target_sr)
if __name__ == "__main__":
    # Parse CLI arguments from the module docstring's usage section.
    args = docopt(__doc__)
    checkpoint_path = args["<checkpoint-path>"]
    input_path = args["<input-wav>"]

    # docopt already applies the [default: generated] from the usage
    # string; the fallback is kept as a belt-and-braces guard.
    output_dir = args["--output-dir"]
    if output_dir is None:
        output_dir = 'generated'

    # --sr has no default in the usage string: fall back to the training
    # sample rate, otherwise coerce the CLI string to an int.
    raw_sr = args["--sr"]
    target_sr = hp.sample_rate if raw_sr is None else int(raw_sr)

    # Prepare the output directory and pick the compute device.
    os.makedirs(output_dir, exist_ok=True)
    device = torch.device("cuda" if use_cuda else "cpu")
    print("using device:{}".format(device))

    # Build the network, restore trained weights, then run separation.
    model = build_model().to(device)
    model = load_checkpoint(checkpoint_path, model)
    print("loading model from checkpoint:{}".format(checkpoint_path))
    generate(device, model, input_path, output_dir, target_sr)