draft: add prototype for safer checkpoint format #3186

Open
wants to merge 1 commit into
base: dev

mbway
Contributor

@mbway mbway commented Apr 11, 2024

This is a proof of concept for using a 'safer' checkpoint format in composer, mentioned in #3181.
This PR is intended for inspiration and is not expected to be merged.

pickle, and therefore .pt and composer's .tar checkpoint formats, can all execute arbitrary code during loading, which is undesirable. The approach taken here to avoid this is to store data in 'inert' formats such as .npy, .safetensors and JSON.
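To illustrate the concern, here is a minimal, benign sketch (not taken from this PR) of how unpickling can invoke an arbitrary callable, whereas parsing JSON only ever yields plain data. The names `side_effect`, `Payload` and `CALLS` are hypothetical and exist only for this demonstration.

```python
import json
import pickle

CALLS = []  # records any side effect triggered while loading


def side_effect(message):
    """Harmless stand-in for code an attacker might want to run."""
    CALLS.append(message)
    return message


class Payload:
    def __reduce__(self):
        # pickle stores "call side_effect(...)" and runs it at load time
        return (side_effect, ("ran during unpickling",))


blob = pickle.dumps(Payload())
loaded = pickle.loads(blob)  # side_effect is invoked here

# The same idea expressed as JSON is just inert data; nothing is called.
data = json.loads('{"callable": "side_effect", "args": ["ignored"]}')
calls_after_json = list(CALLS)
```

After `pickle.loads`, `CALLS` contains one entry even though the caller never invoked `side_effect` directly; `json.loads` adds nothing to it.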

This implementation patches composer's saving and loading functionality to use the custom format when the checkpoint is given the .safe file extension.

When saving, state_dict is traversed and anything that is not JSON serializable is extracted and replaced with a placeholder that can be identified and restored during loading. For example, a timedelta object becomes {"__safe_storage_obj": "timedelta", "seconds": 123}, and a tensor becomes {"__safe_storage_obj": "tensor", "id": "foo"}, with the tensor data stored separately.
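The traversal described above can be sketched roughly as follows. This is an illustrative, stdlib-only approximation rather than the code in this PR: `extract` and `restore` are hypothetical names, `array.array` stands in for a tensor so the example runs without torch, and the sketch assumes string dict keys (tuples come back as lists).

```python
import json
from array import array
from datetime import timedelta

MARKER = "__safe_storage_obj"  # marker key from the description above


def extract(obj, blobs, path="root"):
    """Return a JSON-serializable copy of obj; bulk data is moved into blobs."""
    if isinstance(obj, dict):
        return {k: extract(v, blobs, f"{path}.{k}") for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [extract(v, blobs, f"{path}.{i}") for i, v in enumerate(obj)]
    if isinstance(obj, timedelta):
        return {MARKER: "timedelta", "seconds": obj.total_seconds()}
    if isinstance(obj, array):  # assumption: stand-in for a tensor
        blobs[path] = obj  # the data itself would be stored separately
        return {MARKER: "tensor", "id": path}
    return obj


def restore(obj, blobs):
    """Inverse of extract: swap placeholders back for the real objects."""
    if isinstance(obj, list):
        return [restore(v, blobs) for v in obj]
    if isinstance(obj, dict):
        kind = obj.get(MARKER)
        if kind == "timedelta":
            return timedelta(seconds=obj["seconds"])
        if kind == "tensor":
            return blobs[obj["id"]]
        return {k: restore(v, blobs) for k, v in obj.items()}
    return obj


state = {
    "epoch": 3,
    "elapsed": timedelta(seconds=123),
    "model": {"weight": array("d", [1.0, 2.0])},
}
blobs = {}
text = json.dumps(extract(state, blobs))
restored = restore(json.loads(text), blobs)
```

Only the placeholder dictionaries pass through JSON; the bulk data in `blobs` would be written with an inert array format in a real checkpoint.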

This proof of concept has lots of room for improvement:

  • tensors could be converted to numpy arrays so that a safe checkpoint has only two components, 'arrays' and 'other', instead of three: 'tensors', 'arrays' and 'other'. This would eliminate the safetensors dependency, and many of the safetensors benefits, such as memory mapping, are not utilised anyway since the file is put inside an archive. Note that the inverse would not work: not all numpy arrays can be converted to tensors because torch doesn't support all dtypes.
  • could save individual .npy files instead of using savez so that the arrays aren't zipped twice
  • json could be substituted with msgpack or similar
  • the implementation is not very efficient: the state dict is copied in order to replace the non-serializable parts, and extra disk space is required because the files have to be saved first and then zipped afterwards.
  • the supported data types are hard-coded in this implementation, but a production-quality implementation might support more types out of the box and allow the user to register more.
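As a rough idea of what user-registered types (the last point above) might look like, here is a hypothetical sketch. None of these names (`register`, `encode`, `decode`, `_REGISTRY`) come from this PR or from composer. The key property is that decoding only ever calls functions that were registered in the current process, never anything named by the file itself.

```python
from datetime import timedelta

MARKER = "__safe_storage_obj"
_REGISTRY = {}  # name -> (type, to_fields, from_fields)


def register(name, cls, to_fields, from_fields):
    """Register a type together with its JSON-friendly converters."""
    _REGISTRY[name] = (cls, to_fields, from_fields)


def encode(obj):
    """Turn a registered object into a marked, JSON-serializable dict."""
    for name, (cls, to_fields, _) in _REGISTRY.items():
        if isinstance(obj, cls):
            return {MARKER: name, **to_fields(obj)}
    raise TypeError(f"no safe encoder registered for {type(obj).__name__}")


def decode(record):
    """Rebuild an object from a marked dict; unknown names are rejected."""
    name = record[MARKER]
    if name not in _REGISTRY:
        raise ValueError(f"unknown safe storage type: {name!r}")
    _, _, from_fields = _REGISTRY[name]
    fields = {k: v for k, v in record.items() if k != MARKER}
    return from_fields(fields)


# A built-in type, plus one that a user might add.
register(
    "timedelta",
    timedelta,
    lambda td: {"seconds": td.total_seconds()},
    lambda f: timedelta(seconds=f["seconds"]),
)
register(
    "complex",
    complex,
    lambda c: {"real": c.real, "imag": c.imag},
    lambda f: complex(f["real"], f["imag"]),
)

roundtrip = decode(encode(3 + 4j))
```

A record whose marker names an unregistered type raises `ValueError` instead of being resolved dynamically, which keeps loading inert.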

Before submitting

  • Have you read the contributor guidelines?
  • [-] Is this change a documentation change or typo fix? If so, skip the rest of this checklist.
  • Was this change discussed/approved in a GitHub issue first? It is much more likely to be merged if so.
  • Did you update any related docs and document your change?
  • Did you update any related tests and add any new tests related to your change? (see testing)
  • Did you run the tests locally to make sure they pass?
  • Did you run pre-commit on your change? (see the pre-commit section of prerequisites)

@mvpatel2000
Contributor

CC: @eracah @dakinggg
