Skip to content

Commit

Permalink
Keep task level checkpoint key name generic (#5330)
Browse files Browse the repository at this point in the history
  • Loading branch information
piyush-kansal committed Sep 15, 2023
1 parent e29f53b commit 7409af7
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 14 deletions.
10 changes: 5 additions & 5 deletions fairseq/checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def is_better(a, b):
# attributes
if hasattr(trainer.task, "get_checkpoint_dict"):
extra_state = {**extra_state, **trainer.task.get_checkpoint_dict()}
logger.info(f"{trainer.task.__class__} checkpoint worthy attributes are ready to be persisted with the checkpoint")
logger.info(f"State of {trainer.task.__class__.__name__} is ready to be persisted with the checkpoint")

if hasattr(save_checkpoint, "best"):
extra_state.update({"best": save_checkpoint.best})
Expand Down Expand Up @@ -289,10 +289,10 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
)
epoch_itr.load_state_dict(itr_state)

# Preload the observer stats for Supernet
supernet_cp_dict = extra_state.get("supernet", {})
if supernet_cp_dict and hasattr(trainer.task, "set_checkpoint_dict"):
trainer.task.set_checkpoint_dict(supernet_cp_dict)
# Preload the checkpoint for the task
task_cp_dict = extra_state.get(trainer.task.__class__.__name__, {})
if task_cp_dict and hasattr(trainer.task, "set_checkpoint_dict"):
trainer.task.set_checkpoint_dict(task_cp_dict)
else:
epoch_itr = trainer.get_train_iterator(
epoch=1, load_dataset=True, **passthrough_args
Expand Down
18 changes: 9 additions & 9 deletions tests/test_checkpoint_utils_for_task_level_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ def mock_trainer(epoch, num_updates, iterations_in_epoch):
"iterations_in_epoch": iterations_in_epoch,
"shuffle": False,
},
"supernet": checkpoint_dict()["supernet"],
"FakeTask": checkpoint_dict()["FakeTask"],
}
trainer.get_num_updates.return_value = num_updates
trainer.task.__class__.__name__ = "FakeTask"
trainer.task.get_checkpoint_dict.return_value = checkpoint_dict()
trainer.task.set_checkpoint_dict = MagicMock()

Expand All @@ -31,7 +32,7 @@ def mock_trainer(epoch, num_updates, iterations_in_epoch):

def checkpoint_dict():
return {
"supernet": {
"FakeTask": {
"observer_stats": {
(
4,
Expand Down Expand Up @@ -131,20 +132,20 @@ def tearDown(self):
def test_verify_checkpoint(self) -> None:
cp_dict = self.trainer.task.get_checkpoint_dict()
self.assertTrue(len(cp_dict) == 1)
self.assertTrue("supernet" in cp_dict)
self.assertTrue("observer_stats" in cp_dict["supernet"])
self.assertTrue(len(cp_dict["supernet"]["observer_stats"]) == 1)
self.assertTrue("FakeTask" in cp_dict)
self.assertTrue("observer_stats" in cp_dict["FakeTask"])
self.assertTrue(len(cp_dict["FakeTask"]["observer_stats"]) == 1)
self.assertTrue(
(
4,
16,
"MovingAveragePerChannelMinMax",
"MovingAveragePerChannelMinMax",
)
in cp_dict["supernet"]["observer_stats"]
in cp_dict["FakeTask"]["observer_stats"]
)
self.assertTrue(
cp_dict["supernet"]["observer_stats"][
cp_dict["FakeTask"]["observer_stats"][
(
4,
16,
Expand All @@ -163,10 +164,9 @@ def test_load_checkpoint(self) -> None:
)

self.trainer.task.set_checkpoint_dict.assert_called_once_with(
checkpoint_dict()["supernet"]
checkpoint_dict()["FakeTask"]
)


if __name__ == "__main__":
unittest.main()

0 comments on commit 7409af7

Please sign in to comment.