Skip to content

Commit

Permalink
fix(hook): fix torch.Tensor support in ArtifactHook
Browse files Browse the repository at this point in the history
  • Loading branch information
hbredin committed May 8, 2024
1 parent 7a90137 commit 4615808
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

- fix(task): fix random generators and their reproducibility (with [@FrenchKrab](https://github.com/FrenchKrab))
- fix(task): fix estimation of training set size (with [@FrenchKrab](https://github.com/FrenchKrab))
- fix(hook): fix `torch.Tensor` support in `ArtifactHook`

### Improvements

Expand Down
4 changes: 4 additions & 0 deletions pyannote/audio/pipelines/utils/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from copy import deepcopy
from typing import Any, Mapping, Optional, Text

import torch
from rich.progress import (
BarColumn,
Progress,
Expand Down Expand Up @@ -75,6 +76,9 @@ def __call__(
):
return

if isinstance(step_artifact, torch.Tensor):
step_artifact = step_artifact.numpy(force=True)

file.setdefault(self.file_key, dict())[step_name] = deepcopy(step_artifact)


Expand Down

0 comments on commit 4615808

Please sign in to comment.