Support saving and loading 8-bit block weights #273

Open
mryab wants to merge 4 commits into main
Conversation

mryab (Member) commented Feb 25, 2023

This PR relies on TimDettmers/bitsandbytes#159 and makes it possible to call convert_model with the int8 data type and later download the 8-bit checkpoint instead of the 16-bit one when serving the model with load_in_8bit=True. This can save up to 2x bandwidth when starting a server, as shown by this comparison of model sizes for bloom-560m:

~/petals$ du -sh converted_model*
802M    converted_model
515M    converted_model_int8

The command used for conversion:

python -m petals.cli.convert_model --model bigscience/bloom-560m --output_path ./converted_model_int8 --torch_dtype int8 --resize_token_embeddings 50000 --block_branch_prefix int8_block

To test that the checkpoint loads correctly, install bitsandbytes from the branch in the PR above and run:

python -m petals.cli.run_server bigscience/test-bloomd --new_swarm --skip_reachability_check --throughput 100 --device cuda

(Note that I had to change BLOCK_BRANCH_PREFIX in this branch for the sake of testing.)
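For context, the conversion boils down to swapping each nn.Linear inside a block for a bitsandbytes Linear8bitLt layer, letting the move to GPU perform the actual quantization, and saving the resulting int8 state dict. Below is a minimal sketch of that pattern. It assumes a CUDA device and a bitsandbytes build with the serialization support from TimDettmers/bitsandbytes#159; the helper name quantize_linears, the toy block, and the block.pth path are illustrative, not the actual convert_model code:

import torch
import torch.nn as nn
import bitsandbytes as bnb


def quantize_linears(module: nn.Module, threshold: float = 6.0) -> nn.Module:
    """Recursively replace nn.Linear submodules with 8-bit Linear8bitLt layers."""
    for name, child in module.named_children():
        if len(list(child.children())) > 0:
            quantize_linears(child, threshold)
        if isinstance(child, nn.Linear):
            int8_linear = bnb.nn.Linear8bitLt(
                child.in_features,
                child.out_features,
                bias=child.bias is not None,
                has_fp16_weights=False,  # keep only the int8 weights after quantization
                threshold=threshold,
            )
            int8_linear.weight = bnb.nn.Int8Params(
                child.weight.data, requires_grad=False, has_fp16_weights=False
            )
            if child.bias is not None:
                int8_linear.bias = child.bias
            module._modules[name] = int8_linear
    return module


# Toy stand-in for a transformer block loaded in fp16.
block = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)).half()

block = quantize_linears(block)   # swap nn.Linear -> Linear8bitLt (data still fp16, on CPU)
block = block.cuda()              # moving to the GPU performs the int8 quantization
torch.save(block.state_dict(), "block.pth")  # int8 weights + scales, roughly half the fp16 size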


 logger = get_logger(__name__)

 CLIENT_BRANCH = "main"
-BLOCK_BRANCH_PREFIX = "block_"
+BLOCK_BRANCH_PREFIX = "int8_block"
mryab (Member, Author):

We'll roll that back before merging

Comment on lines +51 to +57

+if load_in_8bit:
+    block = replace_8bit_linear(block)
+block = block.to(device)
mryab (Member, Author):

I moved replace_8bit_linear here because it's not possible to correctly load the quantized Linear8bitLt checkpoint into the model before it's converted and quantized
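Put differently, the int8 checkpoint can only be loaded after the block's nn.Linear layers have been swapped out and quantized, because the quantized weights and their scaling factors need matching parameters to land in. A rough sketch of the resulting order, reusing the illustrative quantize_linears helper and block.pth checkpoint from the sketch in the description above (the real code uses petals' replace_8bit_linear and the downloaded block state dict):

import torch
import torch.nn as nn

load_in_8bit = True
device = "cuda"

# Fresh block skeleton with regular nn.Linear layers, fp16 on CPU.
block = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)).half()

if load_in_8bit:
    block = quantize_linears(block)  # 1) convert nn.Linear -> Linear8bitLt
block = block.to(device)             # 2) quantize by moving to the GPU
block.load_state_dict(torch.load("block.pth"))  # 3) only now do the int8 weights and scales fit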

Comment on lines 80 to 81
-from petals.utils.linear8bitlt_patch import CustomLinear8bitLt

 for n, module in model.named_children():
     if len(list(module.children())) > 0:
         replace_8bit_linear(module, threshold)

     if isinstance(module, torch.nn.Linear) and n not in ["lm_head", "score"]:
         assert module.weight.device.type == "cpu", f"expected linear layers on CPU, got {module.weight.device}"
-        model._modules[n] = CustomLinear8bitLt(
+        model._modules[n] = bnb.nn.Linear8bitLt(
mryab (Member, Author):

Not strictly necessary, but it'd be good to get rid of all bitsandbytes-related code that got into upstream before merging this

Collaborator:

Done in #297.

justheuristic (Collaborator) left a comment:

Gentle reminder: please update BNB before merging. This is not covered by tests
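If tests do get added for this, a round-trip check along these lines could cover it. This is only a sketch: it assumes a CUDA device and a bitsandbytes build with the serialization support from TimDettmers/bitsandbytes#159, the layer sizes and tolerance are arbitrary, and the load order mirrors this PR (convert and quantize before load_state_dict):

import torch
import bitsandbytes as bnb


def test_linear8bitlt_roundtrip(tmp_path):
    # Quantize a single bias-free layer on the GPU and save its int8 state dict.
    layer = bnb.nn.Linear8bitLt(64, 64, bias=False, has_fp16_weights=False, threshold=6.0).cuda()
    x = torch.randn(4, 64, dtype=torch.float16, device="cuda")
    ref = layer(x)
    torch.save(layer.state_dict(), tmp_path / "layer.pth")

    # Re-create and quantize the layer, then load the saved int8 weights on top.
    restored = bnb.nn.Linear8bitLt(64, 64, bias=False, has_fp16_weights=False, threshold=6.0).cuda()
    restored.load_state_dict(torch.load(tmp_path / "layer.pth"))

    # The restored layer should reproduce the original outputs.
    assert torch.allclose(ref, restored(x), atol=1e-2)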

justheuristic mentioned this pull request Feb 28, 2023
@@ -38,6 +39,8 @@ def load_pretrained_block(
     use_auth_token: Optional[str] = None,
     cache_dir: Optional[str] = None,
     max_disk_space: Optional[int] = None,
+    load_in_8bit=False,
Collaborator:

Suggested change
-    load_in_8bit=False,
+    load_in_8bit: bool = False,

borzunov (Collaborator) left a comment:

Please defer this until #323 is merged, since it changes block loading code.

borzunov (Collaborator) commented Aug 3, 2023

We discussed that we may revive this feature for loading NF4-pre-quantized weights for Llama 2 and Stable Beluga 2.
