Fix k_dpm_2 & k_dpm_2_a on MPS (open-mmlab#2241)
Needed to convert `timesteps` to `float32` a bit sooner.

Fixes open-mmlab#1537
psychedelicious committed Feb 5, 2023
1 parent 7386e77 commit 22c1ba5
Showing 2 changed files with 13 additions and 14 deletions.
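Background for the fix, as a minimal sketch (not part of the commit itself): the MPS backend cannot hold float64 tensors, and these schedulers build their timestep grid as a float64 NumPy array, so the old ordering failed as soon as that array was moved to the device. The array values below are illustrative only.

import numpy as np
import torch

# The schedulers build their timestep grid as float64 (NumPy's default for floats).
timesteps = np.linspace(0, 999, 30)[::-1].copy()

device = "mps"  # assumes a machine where the MPS backend is available

# Old ordering: move to the device first, cast to float32 only at the end.
# The very first step fails on MPS because float64 tensors are unsupported:
#     t = torch.from_numpy(timesteps).to(device)   # raises on MPS
#     ...
#     self.timesteps = t.to(torch.float32)         # never reached
#
# Fixed ordering: cast to float32 at the moment the tensor is created.
if str(device).startswith("mps"):
    t = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
else:
    t = torch.from_numpy(timesteps).to(device)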
14 changes: 7 additions & 7 deletions src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py
@@ -161,16 +161,16 @@ def set_timesteps(
         # standard deviation of the initial noise distribution
         self.init_noise_sigma = self.sigmas.max()
 
-        timesteps = torch.from_numpy(timesteps).to(device)
-        timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device)
-        interleaved_timesteps = torch.stack((timesteps_interpol[:-2, None], timesteps[1:, None]), dim=-1).flatten()
-        timesteps = torch.cat([timesteps[:1], interleaved_timesteps])
-
         if str(device).startswith("mps"):
             # mps does not support float64
-            self.timesteps = timesteps.to(device, dtype=torch.float32)
+            timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
         else:
-            self.timesteps = timesteps
+            timesteps = torch.from_numpy(timesteps).to(device)
+
+        timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device)
+        interleaved_timesteps = torch.stack((timesteps_interpol[:-2, None], timesteps[1:, None]), dim=-1).flatten()
+
+        self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps])
 
         self.sample = None

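With the cast moved ahead of the interpolation, set_timesteps can be pointed at MPS directly. A rough usage sketch for the ancestral scheduler, assuming the diffusers top-level exports of this era; the step count is arbitrary.

import torch
from diffusers import KDPM2AncestralDiscreteScheduler

scheduler = KDPM2AncestralDiscreteScheduler()
# Before this commit, the call below raised on Apple Silicon because the
# float64 timesteps could not be moved onto the MPS device.
scheduler.set_timesteps(30, device="mps")
print(scheduler.timesteps.dtype)   # expected: torch.float32
print(scheduler.timesteps.device)  # expected: mps:0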
13 changes: 6 additions & 7 deletions src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py
@@ -149,18 +149,17 @@ def set_timesteps(
         # standard deviation of the initial noise distribution
         self.init_noise_sigma = self.sigmas.max()
 
-        timesteps = torch.from_numpy(timesteps).to(device)
+        if str(device).startswith("mps"):
+            # mps does not support float64
+            timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
+        else:
+            timesteps = torch.from_numpy(timesteps).to(device)
 
         # interpolate timesteps
         timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device)
         interleaved_timesteps = torch.stack((timesteps_interpol[1:-1, None], timesteps[1:, None]), dim=-1).flatten()
-        timesteps = torch.cat([timesteps[:1], interleaved_timesteps])
 
-        if str(device).startswith("mps"):
-            # mps does not support float64
-            self.timesteps = timesteps.to(torch.float32)
-        else:
-            self.timesteps = timesteps
+        self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps])
 
         self.sample = None

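The schedulers are usually exercised through a pipeline, which is where the original failure surfaced. A sketch of swapping the fixed discrete scheduler into a pipeline on MPS; the model id and prompt are placeholders, not taken from the issue.

import torch
from diffusers import StableDiffusionPipeline, KDPM2DiscreteScheduler

# Placeholder checkpoint; any Stable Diffusion model the pipeline accepts works.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.scheduler = KDPM2DiscreteScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("mps")

# The pipeline calls set_timesteps(..., device="mps") internally; before this
# commit that call failed when converting the float64 timesteps.
image = pipe("a photo of an astronaut riding a horse", num_inference_steps=30).images[0]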