Skip to content

Commit

Permalink
fix torchvision import (#6796)
Browse files Browse the repository at this point in the history
  • Loading branch information
patrickvonplaten committed Jan 31, 2024
1 parent e7a1666 commit 674d43f
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion src/diffusers/training_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import numpy as np
import torch
from torchvision import transforms
from transformers import is_torchvision_available

from .models import UNet2DConditionModel
from .utils import (
Expand All @@ -23,6 +23,9 @@
if is_peft_available():
from peft import set_peft_model_state_dict

if is_torchvision_available():
from torchvision import transforms


def set_seed(seed: int):
"""
Expand Down Expand Up @@ -79,6 +82,11 @@ def resolve_interpolation_mode(interpolation_type: str):
`torchvision.transforms.InterpolationMode`: an `InterpolationMode` enum used by torchvision's `resize`
transform.
"""
if not is_torchvision_available():
raise ImportError(
"Please make sure to install `torchvision` to be able to use the `resolve_interpolation_mode()` function."
)

if interpolation_type == "bilinear":
interpolation_mode = transforms.InterpolationMode.BILINEAR
elif interpolation_type == "bicubic":
Expand Down

0 comments on commit 674d43f

Please sign in to comment.