Feature Request: load_state_dict should take filenames #1686

Open
soumith opened this issue May 30, 2017 · 3 comments
Labels
feature A request for a proper, new feature. module: nn Related to torch.nn triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments

@soumith
Member

soumith commented May 30, 2017

In high memory pressure situations, the following is a common occurrence:

  1. create model
  2. read state_dict from checkpoint file (loads on GPU)
  3. model.load_state_dict(s)

Because of memory pressure, a common workaround is to first do:

s = torch.load('my_file.pt', map_location=lambda storage, loc: storage)

And then load s into model.

This is a very common scenario that we should be able to avoid, and it has potential pitfalls: what happens with part-GPU, part-CPU models? What happens with multi-GPU models?

If load_state_dict took a filename directly, it could delete its existing parameter storages and set them to the new ones on the fly, thereby requiring no extra memory.
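A minimal sketch of what such a filename-based loader might look like from user code. The helper name `load_state_dict_from_file` is hypothetical and not part of PyTorch. It only approximates the proposal: `torch.load` still materializes the whole checkpoint in CPU memory, so this avoids the second GPU copy rather than all extra memory.

```python
import os
import tempfile

import torch
import torch.nn as nn


def load_state_dict_from_file(model: nn.Module, path: str) -> None:
    """Hypothetical helper: copy a checkpoint into `model` one tensor at a time."""
    # Deserialize onto the CPU so no second copy of the weights lands on the GPU.
    state = torch.load(path, map_location="cpu")
    # The tensors returned here share storage with the model's live parameters.
    own = model.state_dict()
    missing = [name for name in own if name not in state]
    if missing:
        raise KeyError(f"checkpoint is missing keys: {missing}")
    with torch.no_grad():
        for name, target in own.items():
            # copy_ writes in place and handles the CPU-to-device transfer.
            # pop() drops the CPU tensor so it can be freed before the next copy.
            target.copy_(state.pop(name))


# Usage: round-trip a small model through a temporary checkpoint file.
src, dst = nn.Linear(4, 2), nn.Linear(4, 2)
with tempfile.TemporaryDirectory() as tmp:
    ckpt = os.path.join(tmp, "my_file.pt")
    torch.save(src.state_dict(), ckpt)
    load_state_dict_from_file(dst, ckpt)
```

Copying in place keeps the existing parameter storages instead of replacing them, which sidesteps the question of where each parameter lives: every tensor ends up on whatever device the model already placed it on.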

@soumith
Member Author

soumith commented May 30, 2017

The same applies to optimizer state_dicts. For some optimizers, like Adagrad, the checkpoints are large, and we can hit the same memory pressure. Optimizers don't even have a .cuda(), so we first have to load the state_dict onto the CPU manually, and then manually copy parts of it over to the GPU.
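The two manual steps for an optimizer could look roughly like the sketch below. The file name `optimizer.pt` and the tiny `nn.Linear` model are illustrative assumptions. Recent PyTorch versions may already move state tensors to each parameter's device inside `Optimizer.load_state_dict`, in which case the final loop is a no-op.

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
opt = torch.optim.Adagrad(model.parameters(), lr=0.1)
model(torch.randn(8, 4)).sum().backward()
opt.step()  # updates the per-parameter "sum" accumulators

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "optimizer.pt")
    torch.save(opt.state_dict(), path)
    # Step 1: deserialize the optimizer state onto the CPU.
    opt_state = torch.load(path, map_location="cpu")

new_opt = torch.optim.Adagrad(model.parameters(), lr=0.1)
new_opt.load_state_dict(opt_state)

# Step 2: move each state tensor next to the parameter it belongs to.
for param, state in new_opt.state.items():
    for key, value in state.items():
        # "step" counters are left wherever the optimizer put them.
        if torch.is_tensor(value) and key != "step":
            state[key] = value.to(param.device)
```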

I ran into this while helping @aszlam today.

@alykhantejani
Contributor

If load_state_dict takes a filename, we should also allow the map_location param. A common situation for me is to save a checkpoint on a cluster machine and then load it on my MacBook, so I need to load the params onto the CPU.
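For reference, `torch.load` already accepts `map_location` in several forms, and a filename-based `load_state_dict` could plausibly forward the same argument. A small sketch follows; the checkpoint here is written on the CPU purely so that the example runs anywhere, whereas in the cluster-to-laptop case it would contain CUDA tensors.

```python
import os
import tempfile

import torch

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "my_file.pt")
    torch.save({"weight": torch.ones(2, 3)}, path)

    # Three ways to ask for every tensor to be placed on the CPU.
    by_string = torch.load(path, map_location="cpu")
    by_device = torch.load(path, map_location=torch.device("cpu"))
    by_callable = torch.load(path, map_location=lambda storage, loc: storage)
```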

@vadimkantorov
Contributor

@szagoruyko and I are fans of the HDF5 format for serialized models; it would be good if it could work nicely with this proposal.

@soumith soumith added this to Uncategorized in Issue Status Aug 23, 2017
@soumith soumith added this to usability / simple-fixes in Issue Categories Aug 30, 2017
houseroad added a commit to houseroad/pytorch that referenced this issue Jan 4, 2019
…b18ba1 (pytorch#15739)

Summary:
Pull Request resolved: pytorch#15739

Previous import was 765f5ee823a67a866f4bd28a9860e81f3c811ce8

Included changes:
- **[8384c78](onnx/onnx@8384c78)**: add constantofshape (pytorch#1582) <Rui Zhu>
- **[9afc06c](onnx/onnx@9afc06c)**: Set symbol visibility to hidden for non-Windows (pytorch#1707) <Paul Jesse Hellemn>
- **[6f8a9f0](onnx/onnx@6f8a9f0)**: Revert "Add NonMaxSupression operator (pytorch#1695)" (pytorch#1702) <Lu Fang>
- **[8b89544](onnx/onnx@8b89544)**: Add NonMaxSupression operator (pytorch#1695) <Hector Li>
- **[0a7cc48](onnx/onnx@0a7cc48)**: Add bfloat16 support. (pytorch#1699) <Dmitri Smirnov>
- **[da7c50c](onnx/onnx@da7c50c)**: ONNX does not maintain versions for experimental ops (pytorch#1696) <Ke Zhang>
- **[0c8d857](onnx/onnx@0c8d857)**: Correct type of value_info in Graph (pytorch#1694) <Maik Riechert>
- **[f612532](onnx/onnx@f612532)**: Fix typos (pytorch#1686) <Eundoo Song>

Reviewed By: zrphercule

Differential Revision: D13581674

fbshipit-source-id: a961667184b09d2822815ba5d3fa4198a4c57e88
@albanD albanD added feature A request for a proper, new feature. module: nn Related to torch.nn triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module labels Nov 27, 2019
@mruberry mruberry changed the title load_state_dict should take filenames Feature Request: load_state_dict should take filenames May 9, 2020