
Run without quantization #22

Open
freQuensy23-coder opened this issue Jan 22, 2024 · 9 comments

freQuensy23-coder commented Jan 22, 2024

QuantConfig is a mandatory argument of the build_model function:

model = build_model(
    device=device,
    quant_config=quant_config,
    offload_config=offload_config,
    state_path=state_path,
)

Can I run Mixtral with layer offloading, but WITHOUT quantization, using this library?

dvmazur (Owner) commented Jan 22, 2024

What hardware do you plan to run the model on? Running the model without quantization would require quite a lot of combined RAM + VRAM.

freQuensy23-coder (Author) commented Jan 22, 2024

I'll use a Tesla A100 with 80 GB of VRAM + 512 GB of RAM.

dvmazur (Owner) commented Jan 22, 2024

Yeah, sounds like it'll fit :D

The current codebase doesn't support running the model without quantization, but you could try rewriting the expert wrapper class.

This class moves the expert's parameters into a single storage so that the expert can later be moved efficiently between GPU and CPU memory. Here's a snippet that does this for the original expert class:

import torch

# nested_flatten / nested_pack flatten an arbitrarily nested structure of tensors
# into a flat sequence and pack it back again. Adjust this import to wherever the
# helpers live in your checkout (the path below is a placeholder).
from .utils import nested_flatten, nested_pack


def replace_layer_storage(layer, device):
    state_dict = layer.state_dict()

    # First pass: total byte size of all tensors and the byte offset of each one
    # inside the future flat storage.
    storage_size = 0
    offsets = [0]

    for x in nested_flatten(state_dict):
        if not isinstance(x, torch.Tensor):
            continue
        storage_size += x.nbytes
        offsets.append(storage_size)

    # One contiguous untyped buffer that will back every parameter of the layer.
    storage = torch.UntypedStorage(storage_size, device=device)

    # Second pass: copy each tensor into its slice of the shared storage and keep
    # a correctly typed and shaped view of that slice.
    i = 0
    new_flattened_states = list()
    for x in nested_flatten(state_dict):
        if not isinstance(x, torch.Tensor):
            new_flattened_states.append(x)
            continue

        start = offsets[i]
        end = offsets[i + 1]
        a_view = torch.as_tensor(storage[start:end], dtype=x.dtype, device=device).view(x.shape)
        a_view[...] = x
        assert a_view.data_ptr() == storage.data_ptr() + start
        i += 1
        new_flattened_states.append(a_view)

    state_dict = nested_pack(new_flattened_states, state_dict)

    # Re-point the layer's parameters at the views backed by the shared storage.
    for name, param in layer.named_parameters():
        param.data = state_dict[name]

    return layer, storage
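
For illustration, a hedged usage sketch (the toy module here is just a stand-in for a real expert MLP, not an identifier from this repo):

import torch
import torch.nn as nn

# Pack a module's parameters into one contiguous CPU buffer so that the whole
# module can later be shuttled between CPU and GPU memory as a single block.
toy_expert = nn.Sequential(
    nn.Linear(16, 32, dtype=torch.float16),
    nn.SiLU(),
    nn.Linear(32, 16, dtype=torch.float16),
)
toy_expert, flat_storage = replace_layer_storage(toy_expert, device=torch.device("cpu"))
print(len(flat_storage))  # total number of bytes owned by the shared storage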

The rest of the codebase is still quite HQQ-specific, so offloading the unquantized model will require rewriting some code in the build_model.py file. Most of it boils down to replacing HQQ layers with default PyTorch ones, though.
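
Roughly, an unquantized expert is just the standard Mixtral MLP in fp16. A minimal sketch (my own illustration using Mixtral-8x7B's default sizes, not code from build_model.py):

import torch
import torch.nn as nn

class UnquantizedExpert(nn.Module):
    # Plain fp16 Mixtral expert MLP: w1/w3 project up, w2 projects back down,
    # gated by SiLU. This is the same structure that the repo quantizes with HQQ.
    def __init__(self, hidden_size=4096, intermediate_size=14336, dtype=torch.float16):
        super().__init__()
        self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False, dtype=dtype)
        self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False, dtype=dtype)
        self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False, dtype=dtype)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        return self.w2(self.act_fn(self.w1(x)) * self.w3(x))

You'd then load the corresponding fp16 weights into it and wrap it with replace_layer_storage, as above.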

If you decide to go down that path, I can help you out a bit in this issue :)

lavawolfiee (Collaborator) commented:

Seems like you'll be a little bit short on VRAM: the full fp16 model requires ~87 GB. The table is taken from our tech report.

[table of memory requirements from the tech report]
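
A rough back-of-envelope check (my own arithmetic, not a figure from the report): Mixtral-8x7B has about 46.7B parameters, so the fp16 weights alone take roughly 46.7e9 × 2 bytes ≈ 93 GB ≈ 87 GiB, which lines up with the ~87 GB above.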

freQuensy23-coder (Author) commented:

> Seems like you'll be a little bit short on VRAM: the full fp16 model requires ~87 GB. The table is taken from our tech report.

I'll offload some of the experts to RAM during inference, so it will use less GPU VRAM. That's the main idea of this lib. @dvmazur, am I right?

freQuensy23-coder (Author) commented Jan 22, 2024

> If you decide to go down that path, I can help you out a bit in this issue :)

Thanks, I'd appreciate your help with this. I'll also try to do it myself this evening.

dvmazur (Owner) commented Jan 22, 2024

@freQuensy23-coder, yes, you are right - @lavawolfiee must have misunderstood you.

freQuensy23-coder (Author) commented:

I've tried to rewrite your code to add fp16 support using your tips, but I've run into some difficulties: I don't understand where exactly replace_layer_storage uses quantization. As far as I can tell, it should work with 16-bit layers too? Can you help me with it?

dvmazur (Owner) commented Jan 27, 2024

> I've tried to rewrite your code to add fp16 support using your tips, but I've run into some difficulties: I don't understand where exactly replace_layer_storage uses quantization. As far as I can tell, it should work with 16-bit layers too? Can you help me with it?

The snippet I sent you doesn't use quantization. It simply moves a given layer's parameters into one single storage, so it should work for fp16 layers too.
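
If you want to convince yourself, here's a toy check you could run (my own sketch, not code from the repo, assuming the snippet above runs as-is on your torch version):

import torch
import torch.nn as nn

# Run the snippet on a plain fp16 linear layer and verify that its parameter
# now lives inside the single shared storage.
lin = nn.Linear(8, 8, bias=False, dtype=torch.float16)
lin, storage = replace_layer_storage(lin, device=torch.device("cpu"))

for name, p in lin.named_parameters():
    start, end = storage.data_ptr(), storage.data_ptr() + len(storage)
    assert start <= p.data_ptr() < end, name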
