improve(pipeline): do not extract embeddings in SpeakerDiarization pipeline when max_speakers is 1 #1686

Merged 5 commits on May 19, 2024
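For context, this change only kicks in when the pipeline is told there is at most one speaker. A minimal usage sketch, assuming the pretrained `pyannote/speaker-diarization-3.1` checkpoint, a Hugging Face token, and a local `audio.wav` (all placeholders, not part of this PR):

```python
# Hypothetical usage sketch: with max_speakers=1 (or num_speakers=1),
# the pipeline can skip embedding extraction and clustering entirely.
from pyannote.audio import Pipeline

pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1",  # placeholder checkpoint name
    use_auth_token="HF_TOKEN",           # replace with a valid Hugging Face token
)

# Every detected speech region is attributed to the same single speaker,
# so no speaker embeddings need to be extracted.
diarization = pipeline("audio.wav", max_speakers=1)

for turn, _, speaker in diarization.itertracks(yield_label=True):
    print(f"{turn.start:.1f}s - {turn.end:.1f}s: {speaker}")
```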
CHANGELOG.md (1 change: 1 addition & 0 deletions)
@@ -13,6 +13,7 @@
 ### Improvements
 
 - improve(io): when available, default to using `soundfile` backend
+- improve(pipeline): do not extract embeddings when `max_speakers` is set to 1
 
 ## Version 3.2.0 (2024-05-08)

pyannote/audio/pipelines/speaker_diarization.py (51 changes: 31 additions & 20 deletions)
@@ -478,6 +478,7 @@ def apply(
         segmentations = self.get_segmentations(file, hook=hook)
         hook("segmentation", segmentations)
         # shape: (num_chunks, num_frames, local_num_speakers)
+        num_chunks, num_frames, local_num_speakers = segmentations.data.shape
 
         # binarize segmentation
         if self._segmentation.model.specifications.powerset:
@@ -507,29 +508,39 @@
 
             return diarization
 
-        if self.klustering == "OracleClustering" and not return_embeddings:
-            embeddings = None
-        else:
-            embeddings = self.get_embeddings(
-                file,
-                binarized_segmentations,
-                exclude_overlap=self.embedding_exclude_overlap,
-                hook=hook,
-            )
-            hook("embeddings", embeddings)
-            # shape: (num_chunks, local_num_speakers, dimension)
-
-        hard_clusters, _, centroids = self.clustering(
-            embeddings=embeddings,
-            segmentations=binarized_segmentations,
-            num_clusters=num_speakers,
-            min_clusters=min_speakers,
-            max_clusters=max_speakers,
-            file=file,  # <== for oracle clustering
-            frames=self._segmentation.model.receptive_field,  # <== for oracle clustering
-        )
-        # hard_clusters: (num_chunks, num_speakers)
-        # centroids: (num_speakers, dimension)
+        # skip speaker embedding extraction and clustering when only one speaker
+        if not return_embeddings and max_speakers < 2:
+            hard_clusters = np.zeros((num_chunks, local_num_speakers), dtype=np.int8)
+            embeddings = None
+            centroids = None
+
+        else:
+            # skip speaker embedding extraction with oracle clustering
+            if self.klustering == "OracleClustering" and not return_embeddings:
+                embeddings = None
+
+            else:
+                embeddings = self.get_embeddings(
+                    file,
+                    binarized_segmentations,
+                    exclude_overlap=self.embedding_exclude_overlap,
+                    hook=hook,
+                )
+                hook("embeddings", embeddings)
+                # shape: (num_chunks, local_num_speakers, dimension)
+
+            hard_clusters, _, centroids = self.clustering(
+                embeddings=embeddings,
+                segmentations=binarized_segmentations,
+                num_clusters=num_speakers,
+                min_clusters=min_speakers,
+                max_clusters=max_speakers,
+                file=file,  # <== for oracle clustering
+                frames=self._segmentation.model.receptive_field,  # <== for oracle clustering
+            )
+            # hard_clusters: (num_chunks, num_speakers)
+            # centroids: (num_speakers, dimension)
 
         # number of detected clusters is the number of different speakers
         num_different_speakers = np.max(hard_clusters) + 1
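As a side note on the single-speaker shortcut above: the all-zeros `hard_clusters` array assigns every local speaker of every chunk to the same global cluster, so the downstream `num_different_speakers` computation still works without any embeddings. A small standalone sketch (shapes are illustrative, not taken from the PR):

```python
import numpy as np

# Illustrative shapes only: 4 chunks, 3 local speakers per chunk.
num_chunks, local_num_speakers = 4, 3

# What the shortcut produces: every (chunk, local speaker) slot
# is mapped to global cluster 0, i.e. a single speaker.
hard_clusters = np.zeros((num_chunks, local_num_speakers), dtype=np.int8)

# Same formula as in the pipeline: number of detected speakers.
num_different_speakers = np.max(hard_clusters) + 1
print(num_different_speakers)  # 1
```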