Skip to content

Commit

Permalink
Merge branch 'release/3.1.0'
Browse files Browse the repository at this point in the history
  • Loading branch information
hbredin committed Nov 16, 2023
2 parents 28fcf50 + eecc634 commit f45da71
Show file tree
Hide file tree
Showing 20 changed files with 5,045 additions and 4,193 deletions.
197 changes: 116 additions & 81 deletions CHANGELOG.md
@@ -1,132 +1,167 @@
# Changelog

## `develop` branch

## Version 3.1.0 (2023-11-16)

### TL;DR

[`pyannote/speaker-diarization-3.1`](https://hf.co/pyannote/speaker-diarization-3.1) no longer requires [unpopular](https://github.com/pyannote/pyannote-audio/issues/1537) ONNX runtime

### New features

- feat(model): add WeSpeaker embedding wrapper based on PyTorch
- feat(model): add support for multi-speaker statistics pooling
- feat(pipeline): add `TimingHook` for profiling processing time
- feat(pipeline): add `ArtifactHook` for saving internal steps
- feat(pipeline): add support for list of hooks with `Hooks`
- feat(utils): add `"soft"` option to `Powerset.to_multilabel`

### Fixes

- fix(pipeline): add missing "embedding" hook call in `SpeakerDiarization`
- fix(pipeline): fix `AgglomerativeClustering` to honor `num_clusters` when provided
- fix(pipeline): fix frame-wise speaker count exceeding `max_speakers` or detected `num_speakers` in `SpeakerDiarization` pipeline

### Improvements

- improve(pipeline): compute `fbank` on GPU when requested

### Breaking changes

- BREAKING(pipeline): rename `WeSpeakerPretrainedSpeakerEmbedding` to `ONNXWeSpeakerPretrainedSpeakerEmbedding`
- BREAKING(setup): remove `onnxruntime` dependency.
You can still use ONNX `hbredin/wespeaker-voxceleb-resnet34-LM` but you will have to install `onnxruntime` yourself.
- BREAKING(pipeline): remove `logging_hook` (use `ArtifactHook` instead)
- BREAKING(pipeline): remove `onset` and `offset` parameter in `SpeakerDiarizationMixin.speaker_count`
You should now binarize segmentations before passing them to `speaker_count`

## Version 3.0.1 (2023-09-28)

- fix(pipeline): fix WeSpeaker GPU support
- fix(pipeline): fix WeSpeaker GPU support

## Version 3.0.0 (2023-09-26)

### Features and improvements

- feat(pipeline): send pipeline to device with `pipeline.to(device)`
- feat(pipeline): add `return_embeddings` option to `SpeakerDiarization` pipeline
- feat(pipeline): make `segmentation_batch_size` and `embedding_batch_size` mutable in `SpeakerDiarization` pipeline (they now default to `1`)
- feat(pipeline): add progress hook to pipelines
- feat(task): add [powerset](https://www.isca-speech.org/archive/interspeech_2023/plaquet23_interspeech.html) support to `SpeakerDiarization` task
- feat(task): add support for multi-task models
- feat(task): add support for label scope in speaker diarization task
- feat(task): add support for missing classes in multi-label segmentation task
- feat(model): add segmentation model based on torchaudio self-supervised representation
- feat(pipeline): check version compatibility at load time
- improve(task): load metadata as tensors rather than pyannote.core instances
- improve(task): improve error message on missing specifications
- feat(pipeline): send pipeline to device with `pipeline.to(device)`
- feat(pipeline): add `return_embeddings` option to `SpeakerDiarization` pipeline
- feat(pipeline): make `segmentation_batch_size` and `embedding_batch_size` mutable in `SpeakerDiarization` pipeline (they now default to `1`)
- feat(pipeline): add progress hook to pipelines
- feat(task): add [powerset](https://www.isca-speech.org/archive/interspeech_2023/plaquet23_interspeech.html) support to `SpeakerDiarization` task
- feat(task): add support for multi-task models
- feat(task): add support for label scope in speaker diarization task
- feat(task): add support for missing classes in multi-label segmentation task
- feat(model): add segmentation model based on torchaudio self-supervised representation
- feat(pipeline): check version compatibility at load time
- improve(task): load metadata as tensors rather than pyannote.core instances
- improve(task): improve error message on missing specifications

### Breaking changes

- BREAKING(task): rename `Segmentation` task to `SpeakerDiarization`
- BREAKING(pipeline): pipeline defaults to CPU (use `pipeline.to(device)`)
- BREAKING(pipeline): remove `SpeakerSegmentation` pipeline (use `SpeakerDiarization` pipeline)
- BREAKING(pipeline): remove `segmentation_duration` parameter from `SpeakerDiarization` pipeline (defaults to `duration` of segmentation model)
- BREAKING(task): remove support for variable chunk duration for segmentation tasks
- BREAKING(pipeline): remove support for `FINCHClustering` and `HiddenMarkovModelClustering`
- BREAKING(setup): drop support for Python 3.7
- BREAKING(io): channels are now 0-indexed (used to be 1-indexed)
- BREAKING(io): multi-channel audio is no longer downmixed to mono by default.
You should update how `pyannote.audio.core.io.Audio` is instantiated:
* replace `Audio()` by `Audio(mono="downmix")`;
* replace `Audio(mono=True)` by `Audio(mono="downmix")`;
* replace `Audio(mono=False)` by `Audio()`.
- BREAKING(model): get rid of (flaky) `Model.introspection`
If, for some weird reason, you wrote some custom code based on that,
you should instead rely on `Model.example_output`.
- BREAKING(interactive): remove support for Prodigy recipes

- BREAKING(task): rename `Segmentation` task to `SpeakerDiarization`
- BREAKING(pipeline): pipeline defaults to CPU (use `pipeline.to(device)`)
- BREAKING(pipeline): remove `SpeakerSegmentation` pipeline (use `SpeakerDiarization` pipeline)
- BREAKING(pipeline): remove `segmentation_duration` parameter from `SpeakerDiarization` pipeline (defaults to `duration` of segmentation model)
- BREAKING(task): remove support for variable chunk duration for segmentation tasks
- BREAKING(pipeline): remove support for `FINCHClustering` and `HiddenMarkovModelClustering`
- BREAKING(setup): drop support for Python 3.7
- BREAKING(io): channels are now 0-indexed (used to be 1-indexed)
- BREAKING(io): multi-channel audio is no longer downmixed to mono by default.
You should update how `pyannote.audio.core.io.Audio` is instantiated:
- replace `Audio()` by `Audio(mono="downmix")`;
- replace `Audio(mono=True)` by `Audio(mono="downmix")`;
- replace `Audio(mono=False)` by `Audio()`.
- BREAKING(model): get rid of (flaky) `Model.introspection`
If, for some weird reason, you wrote some custom code based on that,
you should instead rely on `Model.example_output`.
- BREAKING(interactive): remove support for Prodigy recipes

### Fixes and improvements

- fix(pipeline): fix reproducibility issue with Ampere CUDA devices
- fix(pipeline): fix support for IOBase audio
- fix(pipeline): fix corner case with no speaker
- fix(train): prevent metadata preparation to happen twice
- fix(task): fix support for "balance" option
- improve(task): shorten and improve structure of Tensorboard tags
- fix(pipeline): fix reproducibility issue with Ampere CUDA devices
- fix(pipeline): fix support for IOBase audio
- fix(pipeline): fix corner case with no speaker
- fix(train): prevent metadata preparation to happen twice
- fix(task): fix support for "balance" option
- improve(task): shorten and improve structure of Tensorboard tags

### Dependencies update

- setup: switch to torch 2.0+, torchaudio 2.0+, soundfile 0.12+, lightning 2.0+, torchmetrics 0.11+
- setup: switch to pyannote.core 5.0+, pyannote.database 5.0+, and pyannote.pipeline 3.0+
- setup: switch to speechbrain 0.5.14+
- setup: switch to torch 2.0+, torchaudio 2.0+, soundfile 0.12+, lightning 2.0+, torchmetrics 0.11+
- setup: switch to pyannote.core 5.0+, pyannote.database 5.0+, and pyannote.pipeline 3.0+
- setup: switch to speechbrain 0.5.14+

## Version 2.1.1 (2022-10-27)

- BREAKING(pipeline): rewrite speaker diarization pipeline
- feat(pipeline): add option to optimize for DER variant
- feat(clustering): add support for NeMo speaker embedding
- feat(clustering): add FINCH clustering
- feat(clustering): add min_cluster_size hparams to AgglomerativeClustering
- feat(hub): add support for private/gated models
- setup(hub): switch to latest hugginface_hub API
- fix(pipeline): fix support for missing reference in Resegmentation pipeline
- fix(clustering) fix corner case where HMM.fit finds too little states
- BREAKING(pipeline): rewrite speaker diarization pipeline
- feat(pipeline): add option to optimize for DER variant
- feat(clustering): add support for NeMo speaker embedding
- feat(clustering): add FINCH clustering
- feat(clustering): add min_cluster_size hparams to AgglomerativeClustering
- feat(hub): add support for private/gated models
- setup(hub): switch to latest hugginface_hub API
- fix(pipeline): fix support for missing reference in Resegmentation pipeline
- fix(clustering) fix corner case where HMM.fit finds too little states

## Version 2.0.1 (2022-07-20)

- BREAKING: complete rewrite
- feat: much better performance
- feat: Python-first API
- feat: pretrained pipelines (and models) on Huggingface model hub
- feat: multi-GPU training with pytorch-lightning
- feat: data augmentation with torch-audiomentations
- feat: Prodigy recipe for model-assisted audio annotation
- BREAKING: complete rewrite
- feat: much better performance
- feat: Python-first API
- feat: pretrained pipelines (and models) on Huggingface model hub
- feat: multi-GPU training with pytorch-lightning
- feat: data augmentation with torch-audiomentations
- feat: Prodigy recipe for model-assisted audio annotation

## Version 1.1.2 (2021-01-28)

- fix: make sure master branch is used to load pretrained models (#599)
- fix: make sure master branch is used to load pretrained models (#599)

## Version 1.1 (2020-11-08)

- last release before complete rewriting
- last release before complete rewriting

## Version 1.0.1 (2018-07-19)

- fix: fix regression in Precomputed.__call__ (#110, #105)
- fix: fix regression in Precomputed.**call** (#110, #105)

## Version 1.0 (2018-07-03)

- chore: switch from keras to pytorch (with tensorboard support)
- improve: faster & better traning (`AutoLR`, advanced learning rate schedulers, improved batch generators)
- feat: add tunable speaker diarization pipeline (with its own tutorial)
- chore: drop support for Python 2 (use Python 3.6 or later)
- chore: switch from keras to pytorch (with tensorboard support)
- improve: faster & better traning (`AutoLR`, advanced learning rate schedulers, improved batch generators)
- feat: add tunable speaker diarization pipeline (with its own tutorial)
- chore: drop support for Python 2 (use Python 3.6 or later)

## Version 0.3.1 (2017-07-06)

- feat: add python 3 support
- chore: rewrite neural speaker embedding using autograd
- feat: add new embedding architectures
- feat: add new embedding losses
- chore: switch to Keras 2
- doc: add tutorial for (MFCC) feature extraction
- doc: add tutorial for (LSTM-based) speech activity detection
- doc: add tutorial for (LSTM-based) speaker change detection
- doc: add tutorial for (TristouNet) neural speaker embedding
- feat: add python 3 support
- chore: rewrite neural speaker embedding using autograd
- feat: add new embedding architectures
- feat: add new embedding losses
- chore: switch to Keras 2
- doc: add tutorial for (MFCC) feature extraction
- doc: add tutorial for (LSTM-based) speech activity detection
- doc: add tutorial for (LSTM-based) speaker change detection
- doc: add tutorial for (TristouNet) neural speaker embedding

## Version 0.2.1 (2017-03-28)

- feat: add LSTM-based speech activity detection
- feat: add LSTM-based speaker change detection
- improve: refactor LSTM-based speaker embedding
- feat: add librosa basic support
- feat: add SMORMS3 optimizer
- feat: add LSTM-based speech activity detection
- feat: add LSTM-based speaker change detection
- improve: refactor LSTM-based speaker embedding
- feat: add librosa basic support
- feat: add SMORMS3 optimizer

## Version 0.1.4 (2016-09-26)

- feat: add 'covariance_type' option to BIC segmentation
- feat: add 'covariance_type' option to BIC segmentation

## Version 0.1.3 (2016-09-23)

- chore: rename sequence generator in preparation of the release of
TristouNet reproducible research package.
- chore: rename sequence generator in preparation of the release of
TristouNet reproducible research package.

## Version 0.1.2 (2016-09-22)

- first public version
- first public version
2 changes: 1 addition & 1 deletion pyannote/audio/cli/train.py
Expand Up @@ -115,7 +115,7 @@ def configure_optimizers(self):
checkpoint = ModelCheckpoint(
monitor=monitor,
mode=direction,
save_top_k=None if monitor is None else 1,
save_top_k=None if monitor is None else 10,
every_n_epochs=1,
save_last=True,
save_weights_only=False,
Expand Down
107 changes: 75 additions & 32 deletions pyannote/audio/models/blocks/pooling.py
@@ -1,6 +1,6 @@
# MIT License
#
# Copyright (c) 2020 CNRS
# Copyright (c) 2020- CNRS
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
Expand All @@ -26,6 +26,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class StatsPool(nn.Module):
Expand All @@ -40,49 +41,91 @@ class StatsPool(nn.Module):
"""

def forward(
self, sequences: torch.Tensor, weights: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Forward pass
def _pool(self, sequences: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
"""Helper function to compute statistics pooling
Assumes that weights are already interpolated to match the number of frames
in sequences and that they encode the activation of only one speaker.
Parameters
----------
sequences : (batch, channel, frames) torch.Tensor
Sequences.
weights : (batch, frames) torch.Tensor, optional
When provided, compute weighted mean and standard deviation.
sequences : (batch, features, frames) torch.Tensor
Sequences of features.
weights : (batch, frames) torch.Tensor
(Already interpolated) weights.
Returns
-------
output : (batch, 2 * channel) torch.Tensor
output : (batch, 2 * features) torch.Tensor
Concatenation of mean and (unbiased) standard deviation.
"""

if weights is None:
mean = sequences.mean(dim=2)
std = sequences.std(dim=2, unbiased=True)
weights = weights.unsqueeze(dim=1)
# (batch, 1, frames)

else:
weights = weights.unsqueeze(dim=1)
# (batch, 1, frames)
v1 = weights.sum(dim=2) + 1e-8
mean = torch.sum(sequences * weights, dim=2) / v1

dx2 = torch.square(sequences - mean.unsqueeze(2))
v2 = torch.square(weights).sum(dim=2)

var = torch.sum(dx2 * weights, dim=2) / (v1 - v2 / v1 + 1e-8)
std = torch.sqrt(var)

return torch.cat([mean, std], dim=1)

def forward(
self, sequences: torch.Tensor, weights: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Forward pass
num_frames = sequences.shape[2]
num_weights = weights.shape[2]
if num_frames != num_weights:
warnings.warn(
f"Mismatch between frames ({num_frames}) and weights ({num_weights}) numbers."
)
weights = F.interpolate(
weights, size=num_frames, mode="linear", align_corners=False
)
Parameters
----------
sequences : (batch, features, frames) torch.Tensor
Sequences of features.
weights : (batch, frames) or (batch, speakers, frames) torch.Tensor, optional
Compute weighted mean and standard deviation, using provided `weights`.
v1 = weights.sum(dim=2)
mean = torch.sum(sequences * weights, dim=2) / v1
Note
----
`sequences` and `weights` might use a different number of frames, in which case `weights`
are interpolated linearly to reach the number of frames in `sequences`.
dx2 = torch.square(sequences - mean.unsqueeze(2))
v2 = torch.square(weights).sum(dim=2)
Returns
-------
output : (batch, 2 * features) or (batch, speakers, 2 * features) torch.Tensor
Concatenation of mean and (unbiased) standard deviation. When `weights` are
provided with the `speakers` dimension, `output` is computed for each speaker
separately and returned as (batch, speakers, 2 * channel)-shaped tensor.
"""

var = torch.sum(dx2 * weights, dim=2) / (v1 - v2 / v1)
std = torch.sqrt(var)
if weights is None:
mean = sequences.mean(dim=-1)
std = sequences.std(dim=-1, correction=1)
return torch.cat([mean, std], dim=-1)

return torch.cat([mean, std], dim=1)
if weights.dim() == 2:
has_speaker_dimension = False
weights = weights.unsqueeze(dim=1)
# (batch, frames) -> (batch, 1, frames)
else:
has_speaker_dimension = True

# interpolate weights if needed
_, _, num_frames = sequences.shape
_, _, num_weights = weights.shape
if num_frames != num_weights:
warnings.warn(
f"Mismatch between frames ({num_frames}) and weights ({num_weights}) numbers."
)
weights = F.interpolate(weights, size=num_frames, mode="nearest")

output = rearrange(
torch.vmap(self._pool, in_dims=(None, 1))(sequences, weights),
"speakers batch features -> batch speakers features",
)

if not has_speaker_dimension:
return output.squeeze(dim=1)

return output

0 comments on commit f45da71

Please sign in to comment.