Skip to content

Commit

Permalink
chore: remove use of vmap in stats-pooling layer (#1706)
Browse files Browse the repository at this point in the history
  • Loading branch information
hbredin committed May 7, 2024
1 parent 4407a66 commit 5d56a11
Showing 1 changed file with 40 additions and 37 deletions.
77 changes: 40 additions & 37 deletions pyannote/audio/models/blocks/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,53 +26,53 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class StatsPool(nn.Module):
"""Statistics pooling
def _pool(sequences: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
"""Helper function to compute statistics pooling
Compute temporal mean and (unbiased) standard deviation
and returns their concatenation.
Assumes that weights are already interpolated to match the number of frames
in sequences and that they encode the activation of only one speaker.
Reference
---------
https://en.wikipedia.org/wiki/Weighted_arithmetic_mean
Parameters
----------
sequences : (batch, features, frames) torch.Tensor
Sequences of features.
weights : (batch, frames) torch.Tensor
(Already interpolated) weights.
Returns
-------
output : (batch, 2 * features) torch.Tensor
Concatenation of mean and (unbiased) standard deviation.
"""

def _pool(self, sequences: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
"""Helper function to compute statistics pooling
weights = weights.unsqueeze(dim=1)
# (batch, 1, frames)

Assumes that weights are already interpolated to match the number of frames
in sequences and that they encode the activation of only one speaker.
v1 = weights.sum(dim=2) + 1e-8
mean = torch.sum(sequences * weights, dim=2) / v1

Parameters
----------
sequences : (batch, features, frames) torch.Tensor
Sequences of features.
weights : (batch, frames) torch.Tensor
(Already interpolated) weights.
dx2 = torch.square(sequences - mean.unsqueeze(2))
v2 = torch.square(weights).sum(dim=2)

Returns
-------
output : (batch, 2 * features) torch.Tensor
Concatenation of mean and (unbiased) standard deviation.
"""
var = torch.sum(dx2 * weights, dim=2) / (v1 - v2 / v1 + 1e-8)
std = torch.sqrt(var)

weights = weights.unsqueeze(dim=1)
# (batch, 1, frames)
return torch.cat([mean, std], dim=1)

v1 = weights.sum(dim=2) + 1e-8
mean = torch.sum(sequences * weights, dim=2) / v1

dx2 = torch.square(sequences - mean.unsqueeze(2))
v2 = torch.square(weights).sum(dim=2)
class StatsPool(nn.Module):
"""Statistics pooling
var = torch.sum(dx2 * weights, dim=2) / (v1 - v2 / v1 + 1e-8)
std = torch.sqrt(var)
Compute temporal mean and (unbiased) standard deviation
and returns their concatenation.
return torch.cat([mean, std], dim=1)
Reference
---------
https://en.wikipedia.org/wiki/Weighted_arithmetic_mean
"""

def forward(
self, sequences: torch.Tensor, weights: Optional[torch.Tensor] = None
Expand Down Expand Up @@ -112,17 +112,20 @@ def forward(
has_speaker_dimension = True

# interpolate weights if needed
_, _, num_frames = sequences.shape
_, _, num_weights = weights.shape
_, _, num_frames = sequences.size()
_, num_speakers, num_weights = weights.size()
if num_frames != num_weights:
warnings.warn(
f"Mismatch between frames ({num_frames}) and weights ({num_weights}) numbers."
)
weights = F.interpolate(weights, size=num_frames, mode="nearest")

output = rearrange(
torch.vmap(self._pool, in_dims=(None, 1))(sequences, weights),
"speakers batch features -> batch speakers features",
output = torch.stack(
[
_pool(sequences, weights[:, speaker, :])
for speaker in range(num_speakers)
],
dim=1,
)

if not has_speaker_dimension:
Expand Down

0 comments on commit 5d56a11

Please sign in to comment.