Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Weights sharding for Keras saving #19286

Draft
wants to merge 7 commits into
base: master
Choose a base branch
from
Draft

Conversation

nkovela1
Copy link
Contributor

This PR adds weights sharding initial functionality to the Keras saving/loading APIs, which are accessed by passing the sharded=True flag to the corresponding saving/loading calls.

@codecov-commenter
Copy link

codecov-commenter commented Mar 11, 2024

Codecov Report

Attention: Patch coverage is 67.76860% with 39 lines in your changes are missing coverage. Please review.

Project coverage is 75.61%. Comparing base (c8700f4) to head (cfbb761).
Report is 88 commits behind head on master.

Files Patch % Lines
keras/saving/saving_lib.py 66.66% 26 Missing and 13 partials ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #19286      +/-   ##
==========================================
- Coverage   80.14%   75.61%   -4.53%     
==========================================
  Files         341      365      +24     
  Lines       36163    39909    +3746     
  Branches     7116     7747     +631     
==========================================
+ Hits        28982    30177    +1195     
- Misses       5578     8054    +2476     
- Partials     1603     1678      +75     
Flag Coverage Δ
keras 75.46% <67.76%> (-4.53%) ⬇️
keras-jax 59.71% <67.76%> (-3.35%) ⬇️
keras-numpy 54.30% <66.11%> (-2.79%) ⬇️
keras-tensorflow 61.21% <67.76%> (-3.44%) ⬇️
keras-torch 60.29% <53.71%> (-3.58%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@@ -617,6 +662,138 @@ def close(self):
self.io_file.close()


class ShardedH5IOStore:
def __init__(self, root_path, max_size="10GB", archive=None, mode="r"):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't max_size be an int? e.g. in MB?

@@ -754,6 +754,70 @@ def call(self, inputs):
return self.first_layer(self.second_layer(inputs))


def _get_large_model():
model = keras.Sequential(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why pick a convnet for a large model?

@@ -217,7 +251,7 @@ def save_weights_only(model, filepath):
weights_store.close()


def load_weights_only(model, filepath, skip_mismatch=False):
def load_weights_only(model, filepath, sharded=False, skip_mismatch=False):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why should sharded be configurable here -- wouldn't it just depend on the file and the model?

@mattdangerw
Copy link
Member

Talked with Neel a bit about this, but one idea, building off the recent change Francois made with __getitem__/__setitem__...

  • Create h5 groups on variable write, instead of eagerly when creating a H5Entry. (Side note, this would clean up our h5 entires, no empty groups for layers without weights.)
  • H5Entry is just a dict like object that proxies calls get/set calls to parent H5Store.
  • On write, H5Store could just keep a running list of how big the shard is currently, and "roll over" to a new shard as soon as the next variable would be bigger than shard limit.
  • On read, H5Store could just check every shard file for the weight is try to load (as checking is cheap, reading is slow).

Pseudocode:

write(path, key, value):
    if self.current_shard_size + value.nbytes > self.shard_size:
        close current shard
        open new shard file
        self.current_shard_size = 0
    group = create parent groups if needed
    self.current_shard_size += value.nbytes
    group[key] = value

read(path, key):
    for file in shards:
        if path in file:
            group = file[path]
            if key in group:
                return group[key]

This could be fairly simple. Avoid the need for a separate class if we want (though we still could), allow splitting up individual layer weight across shards (important if you have one big layer).

This could even allow avoiding the json file entirely I think? Supporting something like this:

# If shard_size is set, pass a format string as path?
filenames = model.save_weights("./model_{}.weights.h5", shard_size="10GB")
# Load weights handles loading a list of files, and checking all files for the variables.
model.load_weights(filenames)

This last bit is optional, just though it was interesting. What do people think?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants