Skip to content

Commit

Permalink
attempt a metaprogramming solution for the unwrapped model issue in accelerate
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jan 6, 2024
1 parent 290323c commit 80cde12
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 25 deletions.
63 changes: 39 additions & 24 deletions magvit2_pytorch/trainer.py
Expand Up @@ -53,9 +53,26 @@ def cycle(dl):
for data in dl:
yield data

# a forwarding wrapper, to take care of the unwrapped model issue in accelerate

class ForwardingWrapper:
    """Proxy that resolves attribute access on `parent` first, then `child`.

    Intended usage (per the surrounding diff): `parent` is the
    accelerate-prepared model and `child` is the unwrapped model, so one
    object can serve both roles — attributes present on the prepared model
    win, anything else falls back to the unwrapped model.
    """

    def __init__(self, parent, child):
        self.parent = parent
        self.child = child

    def __getattr__(self, key):
        # EAFP: single lookup on parent with an AttributeError fallback,
        # instead of hasattr + getattr (which performs the lookup twice).
        # hasattr() is itself defined as getattr-catching-AttributeError,
        # so this is behaviorally identical.
        try:
            return getattr(self.parent, key)
        except AttributeError:
            return getattr(self.child, key)

    def __call__(self, *args, **kwargs):
        # dunder methods are looked up on the type, not the instance, so
        # __call__ must be defined explicitly and routed through the same
        # forwarding logic to reach parent's (or child's) __call__
        call_fn = self.__getattr__('__call__')
        return call_fn(*args, **kwargs)

# class

class VideoTokenizerTrainer(Module):
class VideoTokenizerTrainer:
@beartype
def __init__(
self,
Expand Down Expand Up @@ -87,8 +104,6 @@ def __init__(
optimizer_kwargs: dict = dict(),
dataset_kwargs: dict = dict()
):
super().__init__()

self.use_wandb_tracking = use_wandb_tracking

if use_wandb_tracking:
Expand Down Expand Up @@ -194,16 +209,20 @@ def __init__(
self.discr_optimizer
)

# wrap model with forwarding wrapper

self.model = ForwardingWrapper(self.model, self.accelerator.unwrap_model(self.model))

# only use adversarial training after a certain number of steps

self.discr_start_after_step = discr_start_after_step

# multiscale discr losses

self.has_multiscale_discrs = self.unwrapped_model.has_multiscale_discrs
self.has_multiscale_discrs = self.model.has_multiscale_discrs
self.multiscale_discr_optimizers = []

for ind, discr in enumerate(self.unwrapped_model.multiscale_discrs):
for ind, discr in enumerate(self.model.multiscale_discrs):
multiscale_optimizer = get_optimizer(discr.parameters(), lr = learning_rate, **optimizer_kwargs)

self.multiscale_discr_optimizers.append(multiscale_optimizer)
Expand All @@ -227,7 +246,7 @@ def __init__(

# keep track of train step

self.register_buffer('step', torch.tensor(0))
self.step = 0

# move ema to the proper device

Expand All @@ -252,20 +271,16 @@ def trackers(
self.accelerator.end_training()

def log(self, **data_kwargs):
self.accelerator.log(data_kwargs, step = self.step.item())
self.accelerator.log(data_kwargs, step = self.step)

@property
def device(self):
return self.unwrapped_model.device
return self.model.device

@property
def is_main(self):
    # convenience passthrough: whether this process is the accelerate
    # main process (presumably rank 0 across all nodes — per accelerate docs)
    return self.accelerator.is_main_process

@property
def unwrapped_model(self):
    # the model with any accelerate-added wrappers stripped off
    # (delegates entirely to accelerate's unwrap_model)
    return self.accelerator.unwrap_model(self.model)

@property
def is_local_main(self):
    # convenience passthrough: whether this process is the main process
    # on its local node (accelerate's local_main, as opposed to global main)
    return self.accelerator.is_local_main_process
Expand All @@ -288,15 +303,15 @@ def save(self, path, overwrite = True):
assert overwrite or not path.exists()

pkg = dict(
model = self.unwrapped_model.state_dict(),
model = self.model.state_dict(),
ema_model = self.ema_model.state_dict(),
optimizer = self.optimizer.state_dict(),
discr_optimizer = self.discr_optimizer.state_dict(),
warmup = self.warmup.state_dict(),
scheduler = self.scheduler.state_dict(),
discr_warmup = self.discr_warmup.state_dict(),
discr_scheduler = self.discr_scheduler.state_dict(),
step = self.step.item()
step = self.step
)

for ind, opt in enumerate(self.multiscale_discr_optimizers):
Expand All @@ -322,16 +337,16 @@ def load(self, path):
for ind, opt in enumerate(self.multiscale_discr_optimizers):
opt.load_state_dict(pkg[f'multiscale_discr_optimizer_{ind}'])

self.step.copy_(pkg['step'])
self.step = pkg['step']

def train_step(self, dl_iter):
self.model.train()

step = self.step.item()
step = self.step

# determine whether to train adversarially

train_adversarially = self.unwrapped_model.use_gan and (step + 1) > self.discr_start_after_step
train_adversarially = self.model.use_gan and (step + 1) > self.discr_start_after_step

adversarial_loss_weight = 0. if not train_adversarially else None
multiscale_adversarial_loss_weight = 0. if not train_adversarially else None
Expand Down Expand Up @@ -387,7 +402,7 @@ def train_step(self, dl_iter):
# if adversarial loss is turned off, continue

if not train_adversarially:
self.step.add_(1)
self.step += 1
return

# discriminator and multiscale discriminators
Expand Down Expand Up @@ -424,10 +439,10 @@ def train_step(self, dl_iter):
self.print(f'discr loss: {discr_loss_breakdown.discr_loss.item():.3f}')

if exists(self.max_grad_norm):
self.accelerator.clip_grad_norm_(self.unwrapped_model.discr_parameters(), self.max_grad_norm)
self.accelerator.clip_grad_norm_(self.model.discr_parameters(), self.max_grad_norm)

if self.has_multiscale_discrs:
for multiscale_discr in self.unwrapped_model.multiscale_discrs:
for multiscale_discr in self.model.multiscale_discrs:
self.accelerator.clip_grad_norm_(multiscale_discr.parameters(), self.max_grad_norm)

self.discr_optimizer.step()
Expand All @@ -442,7 +457,7 @@ def train_step(self, dl_iter):

# update train step

self.step.add_(1)
self.step += 1

@torch.no_grad()
def valid_step(
Expand All @@ -464,7 +479,7 @@ def valid_step(
valid_video = valid_video.to(self.device)

with self.accelerator.autocast():
loss, _ = self.unwrapped_model(valid_video, return_recon_loss_only = True)
loss, _ = self.model(valid_video, return_recon_loss_only = True)
ema_loss, ema_recon_video = self.ema_model(valid_video, return_recon_loss_only = True)

recon_loss += loss / self.grad_accum_every
Expand Down Expand Up @@ -496,7 +511,7 @@ def valid_step(

real_and_recon = rearrange([valid_videos, recon_videos], 'n b c f h w -> c f (b h) (n w)')

validate_step = self.step.item() // self.validate_every_step
validate_step = self.step // self.validate_every_step

sample_path = str(self.results_folder / f'sampled.{validate_step}.gif')

Expand All @@ -506,7 +521,7 @@ def valid_step(

def train(self):

step = self.step.item()
step = self.step

dl_iter = cycle(self.dataloader)
valid_dl_iter = cycle(self.valid_dataloader)
Expand Down
2 changes: 1 addition & 1 deletion magvit2_pytorch/version.py
@@ -1 +1 @@
__version__ = '0.2.2'
__version__ = '0.3.0'

0 comments on commit 80cde12

Please sign in to comment.