
[WIP] MLPSpeculator speculative decoding support #1850

Closed · wants to merge 2 commits

Conversation

@JRosenkranz (Contributor) commented May 2, 2024

What does this PR do?

1. What is the motivation behind this feature? Is it related to a problem or frustration with the library? Is it a feature related to something you need for a project? Is it something you worked on and think it could benefit the community?

This feature will significantly improve inference performance with minimal effort from the user, by plugging an improved speculator architecture into the existing speculative decoding framework.

2. Describe your requested feature in as much detail as possible. The more you can tell us about it, the better we'll be able to help you.

This feature adds a new speculator architecture (MLPSpeculator) to support flash_causal_lm models. The architecture is similar to Medusa, with the key difference that each head conditions its output on the prediction of the prior head. This lets it produce better-formed n-grams and increases the acceptance rate (a minimal sketch of the head-chaining idea follows this list).

3. Provide a code snippet that demonstrates the features usage.

TBD

4. If the feature is related to a paper, please include a link.

https://arxiv.org/abs/2404.19124
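To make the head-chaining concrete, here is a minimal, illustrative sketch of the idea. The names, sizes, and the plain ReLU/argmax wiring are simplifications for readability, not the actual implementation (which uses layer norms, weighted residuals, and top-k candidate trees):

```python
import torch
import torch.nn as nn

class TinyMLPSpeculator(nn.Module):
    """Illustrative only: each head conditions on the previous head's prediction."""

    def __init__(self, vocab_size, emb_dim, inner_dim, n_predict):
        super().__init__()
        self.n_predict = n_predict
        self.emb = nn.ModuleList(nn.Embedding(vocab_size, inner_dim) for _ in range(n_predict))
        self.proj = nn.ModuleList(
            nn.Linear(emb_dim if i == 0 else inner_dim, inner_dim, bias=False)
            for i in range(n_predict)
        )
        self.head = nn.ModuleList(nn.Linear(inner_dim, vocab_size, bias=False) for _ in range(n_predict))

    def forward(self, state, last_token):
        # state: (batch, emb_dim) hidden state from the base model
        # last_token: (batch,) token the base model just produced
        logits_per_head = []
        for i in range(self.n_predict):
            z = self.emb[i](last_token)                 # embed the previous prediction
            state = torch.relu(self.proj[i](state) + z)
            logits = self.head[i](state)                # (batch, vocab_size)
            last_token = logits.argmax(-1)              # feed the prediction into the next head
            logits_per_head.append(logits)
        return torch.stack(logits_per_head, dim=1)      # (batch, n_predict, vocab_size)
```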

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@Narsil

Note: More information will be added to this description as the PR becomes more mature.

Current Speculators:

The reference implementation can be found at: https://github.com/foundation-model-stack/fms-extras

@Narsil (Collaborator) left a comment


Do you have a link to the actual models so we can try this out and see how the performance is?

I made a lot of comments about the overall structure. I'm happy to help and make some of the modifications (either here, or keep them for follow-up PRs).
Some of the comments are there because we didn't previously have multiple speculators, so now is a good time to clean things up a bit, even though that isn't the main goal of this PR.

Comment on lines 468 to 471
elementwise_scale=True,
elementwise_shift=False,
use_mean=False,
use_high_precision_pow=False,
Collaborator

Any way we could settle on a definite set of values?

Having lots of if/else in the modeling code is great when tinkering, but super confusing when running actual models, and usually only one set of values is used in practice.
Having different pathways also makes optimization harder if we were to replace things with an actual kernel.


# Update candidate set with new predictions
out = out.unsqueeze(2).expand(-1, -1, top_k_tokens_per_head[i], -1) # b k k' h
out = torch.cat([out, preds.unsqueeze(3)], dim=3) # b k k' h+1
Collaborator

I think we can avoid all those cats by preallocating the tensor directly and indexing into it; that will save a lot of tensor allocations.
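A minimal sketch of the preallocation idea, with illustrative shapes rather than the PR's actual tensors: allocate the full candidate buffer once and write each head's predictions into its own slot instead of concatenating at every step.

```python
import torch
from math import prod

b, n_predict = 2, 3
top_k = [2, 2, 2]                      # tokens kept per head (illustrative)
total_k = prod(top_k)                  # number of candidate n-grams

# One allocation for the final (b, total_k, n_predict) candidate set.
out = torch.empty(b, total_k, n_predict, dtype=torch.long)

for i in range(n_predict):
    parents = prod(top_k[:i])          # candidates carried over from earlier heads
    stride = prod(top_k[i + 1:])       # how often each new token repeats among the leaves
    preds = torch.randint(0, 100, (b, parents, top_k[i]))  # stand-in for head i's top-k tokens
    out[:, :, i] = preds.reshape(b, -1).repeat_interleave(stride, dim=1)
```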

for i in range(self.n_predict):
    # Project and predict
    z = self.emb[i](ind)
    z = z.mul(self.emb_weight * math.sqrt(self.inner_dim / 2)) # b k d
Collaborator

Couldn't this multiplication be done at init time?

@daviswer commented May 6, 2024

We can rescale self.emb_weight at init time, but that's just a single scalar, so I assumed it wouldn't matter much. Rescaling self.emb itself is a little more problematic because we keep our layers initialized to the same scale for training purposes; we'd have to adjust our pretrained weights to accommodate a direct rescaling of the embedding layers. What we could do, though, is fuse this multiply with the add that comes afterwards, assuming I'm understanding your thought process correctly.

Collaborator

it wouldn't matter much.

It does, because it launches a separate kernel for each op. Not critical, but still something we try to avoid as much as possible (it's usually trivial to fix).
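A minimal sketch of what the two comments above are getting at, using made-up values: compute the scalar once, then fold it into the add via its alpha multiplier so the mul-then-add pair becomes a single op.

```python
import math
import torch

inner_dim, emb_weight = 4096, 0.5                    # illustrative values
emb_scale = emb_weight * math.sqrt(inner_dim / 2)    # computed once, not per forward pass

state = torch.randn(2, inner_dim)
z = torch.randn(2, inner_dim)                        # stand-in for self.emb[i](ind)

# Two kernels: an elementwise mul followed by an add.
out_unfused = state + z.mul(emb_scale)

# One kernel: fold the scale into the add via its alpha multiplier.
out_fused = torch.add(state, z, alpha=emb_scale)

assert torch.allclose(out_unfused, out_fused)
```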

) # b n h v

    def load(self, config, prefix, weights):
        self.emb = nn.ModuleList(
Collaborator

Let's remove all these torch allocations and use the tensors coming directly from weights.

Doing this saves a lot of VRAM (torch allocates in f32 by default) and also makes bugs easier to detect (you'll get a missing-tensor error instead of potentially silently keeping some tensors unset).
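A minimal sketch of the suggested loading pattern; the wrapper class, prefixes, and config.n_predict are assumptions for illustration, while weights.get_tensor is the accessor already used elsewhere in this file. Checkpoint tensors are wrapped directly instead of allocating modules and then overwriting them.

```python
import torch
import torch.nn as nn

class CheckpointLinear(nn.Module):
    """Hypothetical wrapper: holds a checkpoint tensor as-is, no fresh allocation."""

    def __init__(self, weight: torch.Tensor):
        super().__init__()
        self.weight = nn.Parameter(weight, requires_grad=False)

    def forward(self, x):
        return nn.functional.linear(x, self.weight)

def load_proj(config, prefix, weights):
    # A missing key raises immediately instead of silently keeping a random init,
    # and nothing is materialized in f32 only to be overwritten later.
    return nn.ModuleList(
        CheckpointLinear(weights.get_tensor(f"{prefix}.proj.{i}.weight"))
        for i in range(config.n_predict)
    )
```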

Comment on lines 486 to 490
    def reset_parameters(self):
        if self.elementwise_scale:
            self.weight.data.fill_(1)
        if self.elementwise_shift:
            self.bias.data.zero_()
Collaborator

Unused function, no?

)
        self.proj = nn.ModuleList(
            [
                nn.Linear((config.emb_dim if i == 0 else config.inner_dim), config.inner_dim, bias=False)
Collaborator

How many such layers are there?
We should probably TP-parallelize all this computation to make it efficient.

Contributor Author

I agree; the number of layers is based on the n_predict parameter of the model.
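A self-contained sketch of the column-parallel idea (process-group wiring omitted): each TP rank holds only a slice of a proj layer's output features. TGI's TensorParallelColumnLinear / TensorParallelRowLinear layers presumably handle this sharding at load time; the sketch below only shows what the split looks like.

```python
import torch
import torch.nn as nn

def shard_linear_column(full_weight: torch.Tensor, rank: int, world_size: int) -> nn.Linear:
    """Give one TP rank its slice of a linear layer's output features."""
    out_features, in_features = full_weight.shape
    assert out_features % world_size == 0
    shard = full_weight.chunk(world_size, dim=0)[rank]
    layer = nn.Linear(in_features, out_features // world_size, bias=False)
    with torch.no_grad():
        layer.weight.copy_(shard)
    return layer

# Each rank then computes its slice of the activation; an all-gather (column
# parallel) or all-reduce (row parallel) reassembles the full result.
```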

        except:
            medusa = MedusaHeadV2(config, prefix, weights)

        architecture = speculator_config["architectures"][0]
Collaborator

This will fail on medusa (which doesn't define an architecture). We need a try..except or a speculator_config.get("architectures")...
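A minimal sketch of the fallback being suggested, written as a standalone helper around the classes defined in this file; the default "MedusaModel" and the MLPSpeculator architecture string are illustrative assumptions, not necessarily the exact values in the checkpoints.

```python
def load_speculator(speculator_config, config, prefix, weights):
    # Old Medusa configs don't define "architectures", so default to Medusa.
    architecture = speculator_config.get("architectures", ["MedusaModel"])[0]
    if architecture == "MLPSpeculatorPreTrainedModel":
        return MLPSpeculatorHeadV1(config, prefix, weights)
    return MedusaHeadV2(config, prefix, weights)
```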

self.head[i].weight.data.copy_(weights.get_tensor(f"{prefix}.head.{i}.weight"))


class MLPSpeculatorHeadV1(nn.Module):
Collaborator

No need to specify V1 if there's no V2. But as you prefer.


speculator_path = speculator_config.use_speculator

filename = str(Path(speculator_path) / "*.safetensors")
Collaborator

Isn't there a filename already?

Contributor Author

This looked consistent with what medusa was doing:

medusa_config = str(Path(use_medusa) / "config.json")
filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")

Also, there could be multiple safetensors files. Let me know if there is another way I should be getting the path.

Collaborator

I meant, why is there a *? Shouldn't it be a real name for the file?

Contributor Author

Currently it looks something like this (depending on size):

model-00001-of-00002.safetensors
model-00002-of-00002.safetensors
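A minimal sketch of resolving the shards explicitly instead of passing a "*.safetensors" glob string downstream; speculator_config.use_speculator is the attribute used in the diff above, and the error handling is illustrative.

```python
from pathlib import Path

speculator_path = Path(speculator_config.use_speculator)

# Sharded checkpoints come as model-0000X-of-0000Y.safetensors, so collect
# every shard file explicitly.
filenames = sorted(speculator_path.glob("*.safetensors"))
if not filenames:
    raise FileNotFoundError(f"No .safetensors files found in {speculator_path}")
```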

use_medusa = config.use_medusa
if use_medusa:
use_speculator = config.use_speculator
if use_speculator:
Collaborator

I feel like this should switch from a bool to a true enum value directly (even if that makes the modifications more extensive).

Basically, config.speculator = {None, "medusa", "MLPS"}.

Contributor Author

I don't think it was ever a bool? It looked like a string path pointing to the speculator directory containing the config and weights:

medusa_config = str(Path(use_medusa) / "config.json")
filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")

We could make this an enum, or we could just keep it as a path and recover the class from the config, if that's possible with the current config.
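A minimal sketch of both options under discussion (an explicit enum vs. keeping only the path and recovering the kind from config.json); the kind names and the architecture string are illustrative assumptions.

```python
import json
from enum import Enum
from pathlib import Path

class SpeculatorKind(str, Enum):
    MEDUSA = "medusa"
    MLP_SPECULATOR = "mlp_speculator"

def resolve_speculator_kind(path: str) -> SpeculatorKind:
    # Keep the path as the single source of truth and recover the kind from its
    # config.json, so callers never have to pass the enum explicitly.
    spec_config = json.loads((Path(path) / "config.json").read_text())
    arch = spec_config.get("architectures", ["MedusaModel"])[0]
    if arch == "MLPSpeculatorPreTrainedModel":
        return SpeculatorKind.MLP_SPECULATOR
    return SpeculatorKind.MEDUSA
```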

@JRosenkranz (Contributor Author) commented May 2, 2024

Do you have a link to the actual models so we can try this out and see how the performance is?

I made a lot of comments about the overall structure. I'm happy to help and make some of the modifications (either here, or keep them for follow-up PRs). Some of the comments are there because we didn't previously have multiple speculators, so now is a good time to clean things up a bit, even though that isn't the main goal of this PR.

Sure! Feel free to make modifications in this PR. Here are the links to the current speculators we have:

https://huggingface.co/ibm-fms/llama3-8b-accelerator
https://huggingface.co/ibm-fms/codellama-13b-accelerator
https://huggingface.co/ibm-fms/llama-13b-accelerator
https://huggingface.co/ibm/granite-7b-lab-accelerator

…s from LayerNormParameterized and renamed to MLPSpeculatorLayerNorm; now using modules for tensor-parallel (this is work in progress; looking into whether this is the right approach); fixed issue with getting medusa model; fixed for more efficient loading
@JRosenkranz (Contributor Author)

Closing as this has been merged via #1865

daviswer added a commit to foundation-model-stack/fms-extras that referenced this pull request May 14, 2024
See comments [here](huggingface/text-generation-inference#1850), specifically [A](huggingface/text-generation-inference#1850 (comment)) and [B](huggingface/text-generation-inference#1850 (comment))

Makes `forward` and `generate_suffixes` more efficient by fusing ops and removing repeated allocations

Outputs confirmed the same up to 1e-6 error (due to very slightly different handling of LN epsilon)