Skip to content

Commit

Permalink
address #18
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed May 7, 2024
1 parent c4d073a commit d266deb
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
6 changes: 5 additions & 1 deletion magvit2_pytorch/trainer.py
Expand Up @@ -267,6 +267,10 @@ def is_main(self):
def is_local_main(self):
return self.accelerator.is_local_main_process

@property
def unwrapped_model(self):
return self.accelerator.unwrap_model(self.model)

def wait(self):
return self.accelerator.wait_for_everyone()

Expand Down Expand Up @@ -461,7 +465,7 @@ def valid_step(
valid_video = valid_video.to(self.device)

with self.accelerator.autocast():
loss, _ = self.model(valid_video, return_recon_loss_only = True)
loss, _ = self.unwrapped_model(valid_video, return_recon_loss_only = True)
ema_loss, ema_recon_video = self.ema_model(valid_video, return_recon_loss_only = True)

recon_loss += loss / self.grad_accum_every
Expand Down
2 changes: 1 addition & 1 deletion magvit2_pytorch/version.py
@@ -1 +1 @@
__version__ = '0.4.2'
__version__ = '0.4.3'

0 comments on commit d266deb

Please sign in to comment.