feat: add fbank_only property to WeSpeaker models (#1708)
hbredin committed May 8, 2024
1 parent 9a61ec2 commit 7a90137
Showing 3 changed files with 210 additions and 18 deletions.
7 changes: 4 additions & 3 deletions CHANGELOG.md
@@ -6,15 +6,16 @@

- feat(task): add option to cache task training metadata to speed up training (with [@clement-pages](https://github.com/clement-pages/))
- feat(model): add `receptive_field`, `num_frames` and `dimension` to models (with [@Bilal-Rahou](https://github.com/Bilal-Rahou))
- feat(model): add `fbank_only` property to `WeSpeaker` models
- feat(util): add `Powerset.permutation_mapping` to help with permutation in powerset space (with [@FrenchKrab](https://github.com/FrenchKrab))
- feat(sample): add sample file at `pyannote.audio.sample.SAMPLE_FILE`
- feat(sample): add sample file at `pyannote.audio.sample.SAMPLE_FILE`
- feat(metric): add `reduce` option to `diarization_error_rate` metric (with [@Bilal-Rahou](https://github.com/Bilal-Rahou))
- feat(pipeline): add `Waveform` and `SampleRate` preprocessors

### Fixes

- fix(task): fix random generators and their reproducibility (with [@FrenchKrab](https://github.com/FrenchKrab))
- fix(task): fix estimation of training set size (with [@FrenchKrab](https://github.com/FrenchKrab))
- fix(task): fix random generators and their reproducibility (with [@FrenchKrab](https://github.com/FrenchKrab))
- fix(task): fix estimation of training set size (with [@FrenchKrab](https://github.com/FrenchKrab))

### Improvements

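For context, here is a minimal usage sketch of the new `fbank_only` property. It is not part of the commit and assumes the public "pyannote/wespeaker-voxceleb-resnet34-LM" checkpoint plus a pyannote.audio version that includes this change (pass use_auth_token=... to from_pretrained if your setup requires it):

import torch
from pyannote.audio import Model

model = Model.from_pretrained("pyannote/wespeaker-voxceleb-resnet34-LM")
waveforms = torch.randn(2, 1, 3 * 16000)  # (batch, channel, sample) at 16 kHz

model.fbank_only = False                  # default: run the full model
embeddings = model(waveforms)             # (batch, dimension) speaker embeddings

model.fbank_only = True                   # stop after fbank extraction
fbanks = model(waveforms)                 # (batch, frames, 80) mean-centered fbank features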
158 changes: 149 additions & 9 deletions pyannote/audio/models/embedding/wespeaker/__init__.py
@@ -25,6 +25,7 @@
from typing import Optional

import torch
import torch.nn.functional as F
import torchaudio.compliance.kaldi as kaldi

from pyannote.audio.core.model import Model
@@ -39,16 +40,33 @@


class BaseWeSpeakerResNet(Model):
"""Base class for WeSpeaker's ResNet models
Parameters
----------
fbank_centering_span : float, optional
Span of the fbank centering window (in seconds).
Defaults (None) to use whole input.
See also
--------
torchaudio.compliance.kaldi.fbank
"""

def __init__(
self,
sample_rate: int = 16000,
num_channels: int = 1,
num_mel_bins: int = 80,
frame_length: int = 25,
frame_shift: int = 10,
frame_length: float = 25.0, # in milliseconds
frame_shift: float = 10.0, # in milliseconds
round_to_power_of_two: bool = True,
snip_edges: bool = True,
dither: float = 0.0,
window_type: str = "hamming",
use_energy: bool = False,
fbank_centering_span: Optional[float] = None,
task: Optional[Task] = None,
):
super().__init__(sample_rate=sample_rate, num_channels=num_channels, task=task)
@@ -60,21 +78,38 @@ def __init__(
"frame_length",
"frame_shift",
"dither",
"round_to_power_of_two",
"snip_edges",
"window_type",
"use_energy",
"fbank_centering_span",
)

self._fbank = partial(
kaldi.fbank,
num_mel_bins=self.hparams.num_mel_bins,
frame_length=self.hparams.frame_length,
round_to_power_of_two=self.hparams.round_to_power_of_two,
frame_shift=self.hparams.frame_shift,
snip_edges=self.hparams.snip_edges,
dither=self.hparams.dither,
sample_frequency=self.hparams.sample_rate,
window_type=self.hparams.window_type,
use_energy=self.hparams.use_energy,
)

@property
def fbank_only(self) -> bool:
"""Whether to only extract fbank features"""
return getattr(self, "_fbank_only", False)

@fbank_only.setter
def fbank_only(self, value: bool):
if hasattr(self, "receptive_field"):
del self.receptive_field

self._fbank_only = value

def compute_fbank(self, waveforms: torch.Tensor) -> torch.Tensor:
"""Extract fbank features
@@ -85,6 +120,7 @@ def compute_fbank(self, waveforms: torch.Tensor) -> torch.Tensor:
Returns
-------
fbank : (batch_size, num_frames, num_mel_bins)
fbank features
Source: https://github.com/wenet-e2e/wespeaker/blob/45941e7cba2c3ea99e232d02bedf617fc71b0dad/wespeaker/bin/infer_onnx.py#L30C1-L50
"""
@@ -98,11 +134,37 @@ def compute_fbank(self, waveforms: torch.Tensor) -> torch.Tensor:

features = torch.vmap(self._fbank)(waveforms.to(fft_device)).to(device)

return features - torch.mean(features, dim=1, keepdim=True)
# center features with global average
if self.hparams.fbank_centering_span is None:
return features - torch.mean(features, dim=1, keepdim=True)

# center features with running average
window_size = int(self.hparams.sample_rate * self.hparams.frame_length * 0.001)
step_size = int(self.hparams.sample_rate * self.hparams.frame_shift * 0.001)
kernel_size = conv1d_num_frames(
num_samples=int(
self.hparams.fbank_centering_span * self.hparams.sample_rate
),
kernel_size=window_size,
stride=step_size,
padding=0,
dilation=1,
)
return features - F.avg_pool1d(
features.transpose(1, 2),
kernel_size=2 * (kernel_size // 2) + 1,
stride=1,
padding=kernel_size // 2,
count_include_pad=False,
).transpose(1, 2)
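# Hedged aside (not part of this file): the arithmetic behind kernel_size above,
# assuming conv1d_num_frames implements the standard conv1d output-length formula
# floor((num_samples + 2 * padding - dilation * (kernel_size - 1) - 1) / stride) + 1.
# For example, with the default 16 kHz / 25 ms / 10 ms settings and a hypothetical
# fbank_centering_span of 2.0 seconds:
#     window_size = int(16000 * 25.0 * 0.001)           # 400 samples
#     step_size = int(16000 * 10.0 * 0.001)             # 160 samples
#     num_samples = int(2.0 * 16000)                    # 32000 samples
#     kernel_size = (32000 - (400 - 1) - 1) // 160 + 1  # 198 frames
#     2 * (kernel_size // 2) + 1                        # 199 frames, an odd ~2 s window
# so the running average is taken over roughly fbank_centering_span seconds of frames.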

@property
def dimension(self) -> int:
"""Dimension of output"""

if self.fbank_only:
return self.hparams.num_mel_bins

return self.resnet.embed_dim

@lru_cache
@@ -122,13 +184,19 @@ def num_frames(self, num_samples: int) -> int:
window_size = int(self.hparams.sample_rate * self.hparams.frame_length * 0.001)
step_size = int(self.hparams.sample_rate * self.hparams.frame_shift * 0.001)

# TODO: take round_to_power_of_two and snip_edges into account

num_frames = conv1d_num_frames(
num_samples=num_samples,
kernel_size=window_size,
stride=step_size,
padding=0,
dilation=1,
)

if self.fbank_only:
return num_frames

return self.resnet.num_frames(num_frames)

def receptive_field_size(self, num_frames: int = 1) -> int:
@@ -144,8 +212,13 @@ def receptive_field_size(self, num_frames: int = 1) -> int:
receptive_field_size : int
Receptive field size.
"""

receptive_field_size = num_frames
receptive_field_size = self.resnet.receptive_field_size(receptive_field_size)

if not self.fbank_only:
receptive_field_size = self.resnet.receptive_field_size(
receptive_field_size
)

window_size = int(self.hparams.sample_rate * self.hparams.frame_length * 0.001)
step_size = int(self.hparams.sample_rate * self.hparams.frame_shift * 0.001)
@@ -172,9 +245,11 @@ def receptive_field_center(self, frame: int = 0) -> int:
Index of receptive field center.
"""
receptive_field_center = frame
receptive_field_center = self.resnet.receptive_field_center(
frame=receptive_field_center
)

if not self.fbank_only:
receptive_field_center = self.resnet.receptive_field_center(
frame=receptive_field_center
)

window_size = int(self.hparams.sample_rate * self.hparams.frame_length * 0.001)
step_size = int(self.hparams.sample_rate * self.hparams.frame_shift * 0.001)
@@ -189,14 +264,79 @@ def receptive_field_center(self, frame: int = 0) -> int:
def forward(
self, waveforms: torch.Tensor, weights: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Extract speaker embeddings
Parameters
----------
waveforms : torch.Tensor
Batch of waveforms with shape (batch, channel, sample)
weights : (batch, frames) or (batch, speakers, frames) torch.Tensor, optional
Batch of weights passed to statistics pooling layer.
Returns
-------
embeddings : (batch, dimension) or (batch, speakers, dimension) torch.Tensor
Batch of embeddings.
"""

fbank = self.compute_fbank(waveforms)
if self.fbank_only:
return fbank

return self.resnet(fbank, weights=weights)[1]

def forward_frames(self, waveforms: torch.Tensor) -> torch.Tensor:
"""Extract frame-wise embeddings
Parameters
----------
waveforms : torch.Tensor
Batch of waveforms with shape (batch, channel, sample)
Returns
-------
embeddings : (batch, ..., embedding_frames) torch.Tensor
Batch of frame-wise embeddings.
"""
fbank = self.compute_fbank(waveforms)
return self.resnet.forward_frames(fbank)

def forward_embedding(
self, frames: torch.Tensor, weights: torch.Tensor = None
) -> torch.Tensor:
"""Extract speaker embeddings from frame-wise embeddings
Parameters
----------
frames : torch.Tensor
Batch of frames with shape (batch, ..., embedding_frames).
weights : (batch, frames) or (batch, speakers, frames) torch.Tensor, optional
Batch of weights passed to statistics pooling layer.
Returns
-------
embeddings : (batch, dimension) or (batch, speakers, dimension) torch.Tensor
Batch of embeddings.
"""
return self.resnet.forward_embedding(frames, weights=weights)[1]

def forward(
self, waveforms: torch.Tensor, weights: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Extract speaker embeddings
Parameters
----------
waveforms : torch.Tensor
Batch of waveforms with shape (batch, channel, sample)
weights : (batch, frames) or (batch, speakers, frames) torch.Tensor, optional
Batch of weights passed to statistics pooling layer.
Returns
-------
embeddings : (batch, dimension) or (batch, speakers, dimension) torch.Tensor
Batch of embeddings.
"""

fbank = self.compute_fbank(waveforms)
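The forward_frames / forward_embedding split added above lets callers compute frame-wise embeddings once and pool them separately, e.g. with different per-speaker weights. A minimal sketch, under the same checkpoint assumption as the earlier example:

import torch
from pyannote.audio import Model

model = Model.from_pretrained("pyannote/wespeaker-voxceleb-resnet34-LM")
waveforms = torch.randn(1, 1, 5 * 16000)     # (batch, channel, sample)

frames = model.forward_frames(waveforms)      # frame-wise (pre-pooling) embeddings
embedding = model.forward_embedding(frames)   # (batch, dimension); should match model(waveforms)
# per-speaker pooling weights of shape (batch, speakers, frames) may also be passed:
#   embeddings = model.forward_embedding(frames, weights=weights)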
63 changes: 57 additions & 6 deletions pyannote/audio/models/embedding/wespeaker/resnet.py
@@ -344,12 +344,64 @@ def receptive_field_center(self, frame: int = 0) -> int:

return receptive_field_center

def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor] = None):
def forward_frames(self, fbank: torch.Tensor) -> torch.Tensor:
"""Extract frame-wise embeddings
Parameters
----------
fbank : (batch, frames, features) torch.Tensor
Batch of fbank features
Returns
-------
embeddings : (batch, ..., embedding_frames) torch.Tensor
Batch of frame-wise embeddings.
"""
fbank = fbank.permute(0, 2, 1) # (B,T,F) => (B,F,T)
fbank = fbank.unsqueeze_(1)
out = F.relu(self.bn1(self.conv1(fbank)))
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
return out

def forward_embedding(
self, frames: torch.Tensor, weights: torch.Tensor = None
) -> torch.Tensor:
"""Extract speaker embeddings
Parameters
----------
x : (batch, frames, features) torch.Tensor
frames : torch.Tensor
Batch of frames with shape (batch, ..., embedding_frames).
weights : (batch, frames) or (batch, speakers, frames) torch.Tensor, optional
Batch of weights passed to statistics pooling layer.
Returns
-------
embeddings : (batch, dimension) or (batch, speakers, dimension) torch.Tensor
Batch of embeddings.
"""

stats = self.pool(frames, weights=weights)

embed_a = self.seg_1(stats)
if self.two_emb_layer:
out = F.relu(embed_a)
out = self.seg_bn_1(out)
embed_b = self.seg_2(out)
return embed_a, embed_b
else:
return torch.tensor(0.0), embed_a

def forward(self, fbank: torch.Tensor, weights: Optional[torch.Tensor] = None):
"""Extract speaker embeddings
Parameters
----------
fbank : (batch, frames, features) torch.Tensor
Batch of features
weights : (batch, frames) torch.Tensor, optional
Batch of weights
@@ -358,10 +410,9 @@ def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor] = None):
-------
embedding : (batch, embedding_dim) torch.Tensor
"""
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)

x = x.unsqueeze_(1)
out = F.relu(self.bn1(self.conv1(x)))
fbank = fbank.permute(0, 2, 1) # (B,T,F) => (B,F,T)
fbank = fbank.unsqueeze_(1)
out = F.relu(self.bn1(self.conv1(fbank)))
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
