Automatically resumable training in learner class #4020

Open · wants to merge 2 commits into base: master
Conversation

PalaashAgrawal

TLDR

This PR introduces granular resumable training in the Learner class. Specifically, the Learner can resume training from the exact epoch and iteration at which a checkpoint was saved. The checkpoint file stores this information along with the model and optimizer states, so that when a user invokes Learner.load, the Learner automatically resumes training from the last saved epoch and iteration.

In Summary

  1. Learner.save saves the iteration info along with the model and optimizer states: the current epoch (Learner.epoch) and the iteration within that epoch (Learner.iter). Correspondingly, Learner.load checks for the saved epoch and iter.
  2. If Learner.load has been invoked, Learner.fit retrieves the epoch and iteration at which to resume training.
  3. SkipToEpoch has been generalized to SkipToIter, which skips training up to iteration `iter` of epoch `epoch`.
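To make the intent concrete, here is a minimal, framework-free sketch of the proposed checkpoint round-trip. The function names and dictionary layout are illustrative, not fastai's actual save_model/load_model API, which would use torch.save under the hood:

```python
# Hypothetical sketch of the proposed checkpoint layout. In the real PR,
# fastai's save_model/load_model handle serialization via torch.save/torch.load.

def save_checkpoint(model_state, opt_state, epoch, iteration):
    # Iteration info is stored alongside the model and optimizer states.
    return {'model': model_state,
            'opt': opt_state,
            'iter': {'epoch': epoch, 'iter': iteration}}

def load_checkpoint(ckpt):
    # Returns the resume info so fit() can skip to the saved epoch/iteration.
    return ckpt['model'], ckpt['opt'], ckpt.get('iter')

ckpt = save_checkpoint({'w': [0.1]}, {'lr': 1e-3}, epoch=3, iteration=1500)
model_state, opt_state, resume = load_checkpoint(ckpt)
# resume == {'epoch': 3, 'iter': 1500}
```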

Problem Statement

I think the Learner class is designed for small-scale training on a local GPU, and does not support large-scale training without a lot of custom housekeeping code.

(FYI, I am currently training an LLM using fastai)

  1. One big problem I faced was that I could not resume training after a sudden hardware failure. There is the start_epoch argument in Learner.fit, but for large-scale training (e.g. LLMs) a single epoch is extremely long, so it makes more sense to be able to resume from a specific iteration within a specific epoch.

  2. Ideally, the Learner should automatically resume training from the last saved epoch and iteration. The checkpoint file should therefore save the current epoch and iter along with the model and optimizer states, and Learner.fit should resume from that point automatically, unless the user explicitly passes start_epoch (and, by extension, start_iter) to Learner.fit.

Code

Note: Please suggest improvements to the code structure to make it cleaner and style-compliant; I have flagged the points that may be problematic below. Also, let me know which unit tests, Jupyter-notebook experiments, and documentation need to be written. I'd be happy to spend more time on this.

  1. The Learner.save patch function passes an additional dictionary containing the current epoch and iter to save_model if with_iter is True. If with_iter is True, with_opt must also be True.
  2. save_model saves (via torch.save) a dictionary containing model, opt, and iter, where iter is itself a dictionary holding the epoch and iteration info.
  3. The Learner.load patch function stores the loaded iter info in a new Learner attribute, Learner.resumeIter.
  4. load_model not only loads the model and opt states, but also returns the iter info to the load function above.
    NOTE: I'm not sure how to modify Learner.epoch and Learner.iter by reference, so I had to return the values instead. Please suggest a better way.
  5. Learner.load checks the self.resumeIter attribute and initializes the SkipToIter (modified SkipToEpoch) callback. However, start_epoch and a new argument start_iter override the loaded epoch and iter values.
  6. SkipToIter adds a before_batch method alongside the before_epoch method, which ensures that training is skipped until the desired iteration is reached.
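The skip logic of steps 5–6 could be sketched, stripped of fastai's callback machinery, roughly as follows. The class and method names are illustrative only; the actual callback would hook into fastai's before_epoch/before_batch events and raise the appropriate skip exceptions:

```python
class SkipToIter:
    """Hypothetical sketch of the generalized SkipToEpoch callback:
    skip whole epochs before `epoch`, then skip batches before `iteration`."""

    def __init__(self, epoch, iteration):
        self.skip_epoch, self.skip_iter = epoch, iteration

    def should_skip_epoch(self, current_epoch):
        # before_epoch analogue: skip entire epochs preceding the checkpointed one
        return current_epoch < self.skip_epoch

    def should_skip_batch(self, current_epoch, current_iter):
        # before_batch analogue: within the checkpointed epoch,
        # skip batches that were already seen before the checkpoint
        return current_epoch == self.skip_epoch and current_iter < self.skip_iter

cb = SkipToIter(epoch=2, iteration=100)
# Epoch 1 is skipped outright; in epoch 2, batches 0..99 are skipped.
```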

@PalaashAgrawal PalaashAgrawal changed the title PR: automatically resumable learner Automatically resumable training in learner class Apr 5, 2024