Commit

initial revision (#5328)
piyush-kansal committed Sep 15, 2023
1 parent b5d89cd commit e29f53b
Showing 3 changed files with 191 additions and 3 deletions.
20 changes: 19 additions & 1 deletion fairseq/checkpoint_utils.py
@@ -104,7 +104,20 @@ def is_better(a, b):
"checkpoint_last{}.pt".format(suffix)
] = not cfg.no_last_checkpoints

extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
extra_state = {
"train_iterator": epoch_itr.state_dict(),
"val_loss": val_loss,
}

# Going forward, different tasks could expose an API like this to dump all
# their checkpoint-worthy attributes into a dictionary, which is then merged
# with the parent dictionary to create "extra_state". This allows for an
# extensible yet simple design for checkpointing task-level attributes.
if hasattr(trainer.task, "get_checkpoint_dict"):
extra_state = {**extra_state, **trainer.task.get_checkpoint_dict()}
logger.info(f"{trainer.task.__class__} checkpoint-worthy attributes are ready to be persisted with the checkpoint")

if hasattr(save_checkpoint, "best"):
extra_state.update({"best": save_checkpoint.best})

@@ -275,6 +288,11 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
epoch=itr_state["epoch"], load_dataset=True, **passthrough_args
)
epoch_itr.load_state_dict(itr_state)

# Preload the observer stats for Supernet
supernet_cp_dict = extra_state.get("supernet", {})
if supernet_cp_dict and hasattr(trainer.task, "set_checkpoint_dict"):
trainer.task.set_checkpoint_dict(supernet_cp_dict)
else:
epoch_itr = trainer.get_train_iterator(
epoch=1, load_dataset=True, **passthrough_args
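For context, a task opts into this mechanism simply by implementing the two hooks that the hunks above probe with hasattr(). The sketch below is hypothetical and not part of this commit: the class name and the observer_stats attribute are made up, the "supernet" key mirrors the test added further down, and FairseqTask is assumed as the base class.

from fairseq.tasks import FairseqTask


class SupernetTranslationTask(FairseqTask):
    """Hypothetical task whose quantization observer stats survive restarts."""

    def __init__(self, cfg):
        super().__init__(cfg)
        self.observer_stats = {}  # task-level state worth persisting

    def get_checkpoint_dict(self):
        # Merged into extra_state by save_checkpoint (first hunk above).
        return {"supernet": {"observer_stats": self.observer_stats}}

    def set_checkpoint_dict(self, checkpoint_dict):
        # Receives extra_state["supernet"] from load_checkpoint (second hunk above).
        self.observer_stats = checkpoint_dict.get("observer_stats", {})

Because save_checkpoint and load_checkpoint only check for these methods via hasattr(), existing tasks that do not implement them are unaffected.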
2 changes: 0 additions & 2 deletions tests/test_checkpoint_utils.py
@@ -11,8 +11,6 @@
from io import StringIO
from unittest.mock import patch

from omegaconf import OmegaConf

from fairseq import checkpoint_utils
from tests.utils import (
create_dummy_data,
172 changes: 172 additions & 0 deletions tests/test_checkpoint_utils_for_task_level_attributes.py
@@ -0,0 +1,172 @@
#!/usr/bin/env fbpython
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

import contextlib
import logging
import unittest
from io import StringIO
from unittest.mock import MagicMock, patch

import torch
from fairseq import checkpoint_utils, data
from omegaconf import OmegaConf


def mock_trainer(epoch, num_updates, iterations_in_epoch):
trainer = MagicMock()
trainer.load_checkpoint.return_value = {
"train_iterator": {
"epoch": epoch,
"iterations_in_epoch": iterations_in_epoch,
"shuffle": False,
},
"supernet": checkpoint_dict()["supernet"],
}
trainer.get_num_updates.return_value = num_updates
trainer.task.get_checkpoint_dict.return_value = checkpoint_dict()
trainer.task.set_checkpoint_dict = MagicMock()

return trainer


def checkpoint_dict():
return {
"supernet": {
"observer_stats": {
(
4,
16,
"MovingAveragePerChannelMinMax",
"MovingAveragePerChannelMinMax",
): {"mod1": 1, "mod2": 2, "mod3": 3}
}
}
}


def mock_dict():
d = MagicMock()
d.pad.return_value = 1
d.eos.return_value = 2
d.unk.return_value = 3
return d


def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch):
tokens = torch.LongTensor(list(range(epoch_size))).view(1, -1)
tokens_ds = data.TokenBlockDataset(
tokens,
sizes=[tokens.size(-1)],
block_size=1,
pad=0,
eos=1,
include_targets=False,
)
trainer = mock_trainer(epoch, num_updates, iterations_in_epoch)
dataset = data.LanguagePairDataset(
tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False
)
epoch_itr = data.EpochBatchIterator(
dataset=dataset,
collate_fn=dataset.collater,
batch_sampler=[[i] for i in range(epoch_size)],
)
return trainer, epoch_itr


def get_mock_cfg(finetune_from_model):
cfg_mock = OmegaConf.create(
{
"checkpoint": {
"save_dir": None,
"optimizer_overrides": "{}",
"reset_dataloader": False,
"reset_meters": False,
"reset_optimizer": False,
"reset_lr_scheduler": False,
"finetune_from_model": finetune_from_model,
"model_parallel_size": 1,
"restore_file": "checkpoint_last.pt",
"no_save": False,
"save_interval_updates": 0,
"no_last_checkpoints": False,
"keep_interval_updates": 0,
"keep_last_epochs": 0,
"keep_best_checkpoints": 0,
},
"common": {
"model_parallel_size": 1,
},
}
)
return cfg_mock


class TestCheckpointsForTaskLevelAttributes(unittest.TestCase):
def setUp(self) -> None:
self.cfg_mock = get_mock_cfg(None)
self.patches = {
"os.makedirs": MagicMock(),
"os.path.join": MagicMock(),
"os.path.isfile": MagicMock(return_value=True),
"os.path.isabs": MagicMock(return_value=False),
"fairseq.file_io.PathManager.exists": MagicMock(return_value=False),
}
self.applied_patches = [patch(p, d) for p, d in self.patches.items()]
[p.start() for p in self.applied_patches]
logging.disable(logging.CRITICAL)

self.trainer, self.epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50)
self.trainer.get_train_iterator = MagicMock(return_value=self.epoch_itr)
self.epoch_itr.next_epoch_itr(shuffle=False)

checkpoint_utils.save_checkpoint(
self.cfg_mock.checkpoint, self.trainer, self.epoch_itr, None
)

def tearDown(self):
patch.stopall()
logging.disable(logging.NOTSET)

def test_verify_checkpoint(self) -> None:
cp_dict = self.trainer.task.get_checkpoint_dict()
self.assertTrue(len(cp_dict) == 1)
self.assertTrue("supernet" in cp_dict)
self.assertTrue("observer_stats" in cp_dict["supernet"])
self.assertTrue(len(cp_dict["supernet"]["observer_stats"]) == 1)
self.assertTrue(
(
4,
16,
"MovingAveragePerChannelMinMax",
"MovingAveragePerChannelMinMax",
)
in cp_dict["supernet"]["observer_stats"]
)
self.assertTrue(
cp_dict["supernet"]["observer_stats"][
(
4,
16,
"MovingAveragePerChannelMinMax",
"MovingAveragePerChannelMinMax",
)
]
== {"mod1": 1, "mod2": 2, "mod3": 3}
)

def test_load_checkpoint(self) -> None:
with contextlib.redirect_stdout(StringIO()):
# Now, load checkpoint to ensure the respective logic works as expected
_, epoch_itr = checkpoint_utils.load_checkpoint(
self.cfg_mock.checkpoint, self.trainer
)

self.trainer.task.set_checkpoint_dict.assert_called_once_with(
checkpoint_dict()["supernet"]
)


if __name__ == "__main__":
unittest.main()
