Skip to content

Commit

Permalink
fix: raise TypeError on wrong device type in Pipeline.to and Inferenc…
Browse files Browse the repository at this point in the history
…e.to

Fixes 1397
  • Loading branch information
chai3 committed Jun 8, 2023
1 parent 7379f1c commit 0551070
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 1 deletion.
5 changes: 5 additions & 0 deletions pyannote/audio/core/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,11 @@ def __init__(
def to(self, device: torch.device):
"""Send internal model to `device`"""

if not isinstance(device, torch.device):
raise TypeError(
f"`device` must be an instance of `torch.device`, got `{type(device).__name__}`"
)

self.model.to(device)
if self.model.specifications.powerset and not self.skip_conversion:
self._powerset.to(device)
Expand Down
7 changes: 6 additions & 1 deletion pyannote/audio/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,9 +324,14 @@ def __call__(self, file: AudioFile, **kwargs):

return self.apply(file, **kwargs)

def to(self, device):
def to(self, device: torch.device):
"""Send pipeline to `device`"""

if not isinstance(device, torch.device):
raise TypeError(
f"`device` must be an instance of `torch.device`, got `{type(device).__name__}`"
)

for _, pipeline in self._pipelines.items():
if hasattr(pipeline, "to"):
_ = pipeline.to(device)
Expand Down
15 changes: 15 additions & 0 deletions pyannote/audio/pipelines/speaker_verification.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ def __init__(
self.model_.to(self.device)

def to(self, device: torch.device):
if not isinstance(device, torch.device):
raise TypeError(
f"`device` must be an instance of `torch.device`, got `{type(device).__name__}`"
)

self.model_.to(device)
self.device = device
return self
Expand Down Expand Up @@ -255,6 +260,11 @@ def __init__(
)

def to(self, device: torch.device):
if not isinstance(device, torch.device):
raise TypeError(
f"`device` must be an instance of `torch.device`, got `{type(device).__name__}`"
)

self.classifier_ = SpeechBrain_EncoderClassifier.from_hparams(
source=self.embedding,
savedir=f"{CACHE_DIR}/speechbrain",
Expand Down Expand Up @@ -415,6 +425,11 @@ def __init__(
self.model_.to(self.device)

def to(self, device: torch.device):
if not isinstance(device, torch.device):
raise TypeError(
f"`device` must be an instance of `torch.device`, got `{type(device).__name__}`"
)

self.model_.to(device)
self.device = device
return self
Expand Down

0 comments on commit 0551070

Please sign in to comment.