Add F1 score, precision, and recall metrics as MultilabelSegmentation default metrics #1336

Status: Open. 40 commits into base branch develop from multilabel_default_metrics. The diff shown below reflects changes from 2 of the 40 commits, not the full pull request.
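Before the commit history and diff, a quick orientation: the change makes torchmetrics' F1 score, precision, and recall the default validation metrics of the MultilabelSegmentation task. A minimal sketch (not code from this PR) of how these torchmetrics objects consume predictions, with an assumed class count of 3 and an assumed (frames, num_labels) prediction shape:

```python
import torch
from torchmetrics import F1Score, Precision, Recall

num_labels = 3  # assumed class count, for illustration only

# task="multilabel" selects the multilabel variant of each metric,
# exactly as in the default_metric implementation further down.
f1 = F1Score(task="multilabel", num_labels=num_labels)
precision = Precision(task="multilabel", num_labels=num_labels)
recall = Recall(task="multilabel", num_labels=num_labels)

# Scores in [0, 1] are thresholded at 0.5 by default; targets are {0, 1}.
y_pred = torch.rand(100, num_labels)             # (frames, num_labels)
y_true = torch.randint(0, 2, (100, num_labels))  # (frames, num_labels)

print(f1(y_pred, y_true), precision(y_pred, y_true), recall(y_pred, y_true))
```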
Commits (40)

- e4680c2 add f1,precision,recall metrics for multilabel task (FrenchKrab, Apr 19, 2023)
- d771729 black formatting (FrenchKrab, Apr 20, 2023)
- 458bebe simplify MultilabelSegmentation's default_metric logic (FrenchKrab, Apr 20, 2023)
- cd47969 better shape for tensors passed to MultilabelSegmentation metrics (FrenchKrab, Apr 20, 2023)
- 2d97b0b use macro avg for default MultilabelSegmentation metrics (FrenchKrab, Apr 20, 2023)
- 53e676d black format (FrenchKrab, Apr 20, 2023)
- 999252e add support for "per class" metric for MultilabelSegmentation (FrenchKrab, Apr 20, 2023)
- e698bb3 fix logic in multilabel setup_validation_metric (FrenchKrab, Apr 21, 2023)
- b7cf6f5 rename "_per_metric" -> "_classwise" (FrenchKrab, Apr 21, 2023)
- a890ef6 add Loggable and LoggableHistogram classes (FrenchKrab, Apr 21, 2023)
- 7ac8209 Revert "add Loggable and LoggableHistogram classes" (oops, wrong branch) (FrenchKrab, Apr 24, 2023)
- 68a4d47 fix wrong default_metric return value (FrenchKrab, Apr 25, 2023)
- 7e06f4b make MultilabelSegmentation global metric actually be multilabel (FrenchKrab, May 2, 2023)
- e008ad6 update comments (FrenchKrab, May 2, 2023)
- 1c1674b Merge branch 'develop' into multilabel_default_metrics (FrenchKrab, May 2, 2023)
- 9b43f41 small fix (FrenchKrab, May 2, 2023)
- 3f0bf90 Merge branch 'multilabel_default_metrics' of github.com:FrenchKrab/py… (FrenchKrab, May 2, 2023)
- 05638c7 Merge branch 'develop' into multilabel_default_metrics (hbredin, May 10, 2023)
- 5f51b50 add ignore_index to default_metric_classwise (FrenchKrab, May 11, 2023)
- 3107105 Revert "add ignore_index to default_metric_classwise" (FrenchKrab, May 11, 2023)
- 0551070 fix: raise TypeError on wrong device type in Pipeline.to and Inferenc… (chai3, Jun 8, 2023)
- 30ddb0b feat(task): add support for multi-task models (#1374) (hbredin, Jun 12, 2023)
- 4eb7190 fix(inference): fix multi-task inference (hbredin, Jun 12, 2023)
- dcdfc15 feat: update FAQtory default answer (hbredin, Jun 15, 2023)
- 3363be6 improve(test): use pyannote.database.registry (#1413) (hbredin, Jun 22, 2023)
- 017c910 feat(pipeline): add `return_embeddings` option to `SpeakerDiarization… (flyingleafe, Jun 23, 2023)
- cf0e3b3 fix: fix missed speech at the very beginning/end (hbredin, Jun 27, 2023)
- f393546 doc: add note to self regarding cluster reassignment (#1419) (hbredin, Jun 28, 2023)
- 35be745 fix(doc): fix typo in diarization docstring (DiaaAj, Jul 9, 2023)
- bc0920f ci: update suggest.md (#1435) (hbredin, Jul 16, 2023)
- 7194929 feat: add support for WeSpeaker embeddings (#1444) (hbredin, Aug 2, 2023)
- 37b39b0 fix: fix security issue in FAQtory bot (aashish-19, Aug 7, 2023)
- 5a7df38 Update README.md (hbredin, Aug 30, 2023)
- 2af703d Update README.md (hbredin, Aug 30, 2023)
- b660b1e fix(task): fix MultiLabelSegmentation.val_monitor (FrenchKrab, Sep 15, 2023)
- 11e8e6c Merge branch 'develop' into multilabel_default_metrics (hbredin, Sep 15, 2023)
- 9df6944 fix(core): fix Model.example_output for embedding models (hbredin, Sep 16, 2023)
- 4be15a4 update multilabel classwise metrics naming (FrenchKrab, Sep 18, 2023)
- 9b9c966 update multilabel default metric + docstring + black formatting (FrenchKrab, Sep 18, 2023)
- 10a51f9 Merge branch 'develop' into multilabel_default_metrics (hbredin, Sep 18, 2023)
pyannote/audio/tasks/segmentation/multilabel.py (32 changes: 31 additions and 1 deletion)
@@ -29,7 +29,7 @@
 from pyannote.database import Protocol
 from pyannote.database.protocol import SegmentationProtocol
 from torch_audiomentations.core.transforms_interface import BaseWaveformTransform
-from torchmetrics import Metric
+from torchmetrics import F1Score, Metric, Precision, Recall

 from pyannote.audio.core.task import Problem, Resolution, Specifications, Task
 from pyannote.audio.tasks.segmentation.mixins import SegmentationTaskMixin
@@ -255,6 +255,16 @@ def validation_step(self, batch, batch_idx: int):
         y_true = y_true[mask]
         loss = F.binary_cross_entropy(y_pred, y_true.type(torch.float))

+        self.model.validation_metric(y_pred, y_true)
+
+        self.model.log_dict(
+            self.model.validation_metric,
+            on_step=False,
+            on_epoch=True,
+            prog_bar=True,
+            logger=True,
+        )
+
         self.model.log(
             f"{self.logging_prefix}ValLoss",
             loss,
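Passing self.model.validation_metric straight to log_dict follows the standard PyTorch Lightning pattern for torchmetrics objects: each validation_step call updates internal state, and the epoch-level value is computed and reset automatically at epoch end (hence on_step=False, on_epoch=True). A sketch of that accumulation outside Lightning, under the assumption that validation_metric is a torchmetrics MetricCollection built from the metrics returned by default_metric:

```python
import torch
from torchmetrics import F1Score, MetricCollection, Precision, Recall

# Hypothetical stand-in for self.model.validation_metric
validation_metric = MetricCollection(
    [
        F1Score(task="multilabel", num_labels=2),
        Precision(task="multilabel", num_labels=2),
        Recall(task="multilabel", num_labels=2),
    ]
)

for _ in range(3):  # simulate three validation batches
    y_pred = torch.rand(16, 2)             # (frames, num_labels) scores in [0, 1]
    y_true = torch.randint(0, 2, (16, 2))  # binary ground-truth labels
    validation_metric.update(y_pred, y_true)

print(validation_metric.compute())  # epoch-level values, i.e. what gets logged
validation_metric.reset()           # Lightning does this at epoch end
```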
@@ -265,6 +275,26 @@ def validation_step(self, batch, batch_idx: int):
         )
         return {"loss": loss}

+    def default_metric(
+        self,
+    ) -> Union[Metric, Sequence[Metric], Dict[str, Metric]]:
+        classes = None
+        if self.classes is not None:
+            classes = self.classes
+        else:
+            classes = self.protocol.stats()["labels"].keys()
+
+        if classes is not None:
+            class_count = len(classes)
+            classification_type = "multilabel" if class_count > 1 else "binary"
+            return [
+                F1Score(task=classification_type, num_labels=class_count),
+                Precision(task=classification_type, num_labels=class_count),
+                Recall(task=classification_type, num_labels=class_count),
+            ]
+        else:
+            return []
+
     @property
     def val_monitor(self):
         """Quantity (and direction) to monitor
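The branching in default_metric can be exercised in isolation. A standalone sketch of the same dispatch, using hypothetical class lists (a single class collapses to task="binary", in which case torchmetrics ignores the num_labels argument):

```python
from torchmetrics import F1Score, Precision, Recall

def build_default_metrics(classes):
    # Mirrors default_metric above: no classes -> no metrics,
    # one class -> binary task, several classes -> multilabel task.
    if not classes:
        return []
    class_count = len(classes)
    classification_type = "multilabel" if class_count > 1 else "binary"
    return [
        F1Score(task=classification_type, num_labels=class_count),
        Precision(task=classification_type, num_labels=class_count),
        Recall(task=classification_type, num_labels=class_count),
    ]

print(build_default_metrics(["speech", "music", "noise"]))  # multilabel variants
print(build_default_metrics(["speech"]))                    # binary variants
```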