
Commit f015b94
use torch_utils.float64
w-e-w committed May 16, 2024
1 parent 41f6684 commit f015b94
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions modules/sd_samplers_timesteps_impl.py
@@ -5,13 +5,14 @@

 from modules import shared
 from modules.models.diffusion.uni_pc import uni_pc
+from modules.torch_utils import float64


 @torch.no_grad()
 def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
     alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
     alphas = alphas_cumprod[timesteps]
-    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' and x.device.type != 'xpu' else torch.float32)
+    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))
     sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
     sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))

@@ -43,7 +44,7 @@ def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
 def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
     alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
     alphas = alphas_cumprod[timesteps]
-    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' and x.device.type != 'xpu' else torch.float32)
+    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))
     sqrt_one_minus_alphas = torch.sqrt(1 - alphas)

     extra_args = {} if extra_args is None else extra_args
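For context: the commit deduplicates the inline dtype selection into a helper imported from modules.torch_utils. Judging from the expression it replaces, float64(x) returns torch.float64 unless x lives on a backend the code treats as lacking float64 support (mps or xpu), in which case it falls back to torch.float32. A minimal sketch of such a helper, assuming this signature; the actual implementation in modules/torch_utils.py may differ:

import torch

def float64(t: torch.Tensor) -> torch.dtype:
    # PyTorch's MPS backend has no float64 support, and the replaced
    # expression excludes xpu the same way, so fall back to float32 there.
    if t.device.type in ('mps', 'xpu'):
        return torch.float32
    return torch.float64

Call sites then read .to(float64(x)), keeping the high-precision alphas_prev computation on CUDA and CPU while staying compatible with mps and xpu, which is exactly what the two duplicated conditionals did before.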
