Rewriting usage of torch.bucketize with more elementary functions #30839

Closed
EricLBuehler opened this issue May 15, 2024 · 0 comments
@EricLBuehler

Hello everyone,

Thank you for your excellent work here. We are implementing the Idefics 2 model in mistral.rs, which is built on Candle.

Candle has a PyTorch-like API, but it lacks some functions; in particular, we noticed that Candle is missing torch.bucketize. I have opened an issue on Candle, but I thought it would also be a good idea to ask here how to implement torch.bucketize using more elementary functions, or how to rewrite that part of the code so that it emulates the same behavior.
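For reference, with `right=True` and sorted 1-D `boundaries`, `torch.bucketize` returns for each value `v` the index `i` such that `boundaries[i-1] <= v < boundaries[i]`, which is simply the number of boundaries that are less than or equal to `v`. The minimal pure-Python model below illustrates that behavior; the helper name `bucketize_right_reference` is hypothetical, and the cross-check uses the standard library's `bisect.bisect_right`.

```python
from bisect import bisect_right


def bucketize_right_reference(values, boundaries):
    """Hypothetical helper modelling torch.bucketize(values, boundaries, right=True).

    Assumes `boundaries` is sorted in ascending order. Each output is the count of
    boundaries that are <= the value, i.e. the i with boundaries[i-1] <= v < boundaries[i].
    """
    return [sum(1 for b in boundaries if b <= v) for v in values]


boundaries = [0.25, 0.5, 0.75]
values = [0.0, 0.25, 0.6, 0.9]
print(bucketize_right_reference(values, boundaries))  # [0, 1, 2, 3]
# bisect_right computes the same index for a sorted list
print([bisect_right(boundaries, v) for v in values])  # [0, 1, 2, 3]
```

Because this formulation needs only an elementwise comparison and a sum, it should translate to any tensor library that supports broadcasting comparisons, which is the property that matters for a Candle port.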

This is the part of the code I am referencing:

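    # boundaries and position_ids are defined earlier in the enclosing method
    for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
        # Number of valid patches along the height and width of this image
        nb_patches_h = p_attn_mask[:, 0].sum()
        nb_patches_w = p_attn_mask[0].sum()
        # Evenly spaced fractional coordinates in [0, 1), one per valid patch
        fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
        fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
        # Map each fractional coordinate to a bucket index
        bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
        bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
        # Row-major position id: row bucket * num_patches_per_side + column bucket
        pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten()
        # Scatter the ids into the positions of the valid (unmasked) patches
        position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids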

Is there a way to rewrite this so that it does not need to call torch.bucketize? Thank you for any help; I would really appreciate it!
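One possible rewrite, sketched here under stated assumptions, replaces each `torch.bucketize(..., right=True)` call with a broadcast `>=` comparison against `boundaries` followed by a sum over the boundary axis. That counts the boundaries at or below each coordinate and therefore yields the same indices for sorted `boundaries`. The function names `bucketize_right` and `compute_position_ids` are hypothetical. Building `boundaries` as the evenly spaced grid `1/N, 2/N, ..., (N-1)/N` (with `N = num_patches_per_side`) and initializing `position_ids` to zeros are assumptions, since the snippet in the issue does not show how either is created.

```python
import torch


def bucketize_right(values: torch.Tensor, boundaries: torch.Tensor) -> torch.Tensor:
    """Same result as torch.bucketize(values, boundaries, right=True) for 1-D inputs.

    Assumes `boundaries` is sorted ascending. Counts, for each value, how many
    boundaries are <= it, using only a broadcast comparison and a sum.
    """
    # (len(values), 1) >= (1, len(boundaries)) -> boolean matrix; sum over boundaries
    return (values[:, None] >= boundaries[None, :]).sum(dim=1)


def compute_position_ids(patch_attention_mask: torch.Tensor, num_patches_per_side: int) -> torch.Tensor:
    """Hypothetical standalone version of the loop in the issue, without torch.bucketize.

    `patch_attention_mask` is a (batch, max_h, max_w) boolean mask of valid patches.
    """
    # Assumption: evenly spaced interior boundaries 1/N, 2/N, ..., (N-1)/N
    boundaries = torch.arange(1 / num_patches_per_side, 1.0, 1 / num_patches_per_side)
    batch_size, max_h, max_w = patch_attention_mask.shape
    # Assumption: positions of padded (masked-out) patches stay 0
    position_ids = torch.zeros(batch_size, max_h * max_w, dtype=torch.long)

    for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
        nb_patches_h = int(p_attn_mask[:, 0].sum())
        nb_patches_w = int(p_attn_mask[0].sum())
        fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
        fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
        bucket_coords_h = bucketize_right(fractional_coords_h, boundaries)
        bucket_coords_w = bucketize_right(fractional_coords_w, boundaries)
        pos_ids = (bucket_coords_h[:, None] * num_patches_per_side + bucket_coords_w).flatten()
        # Boolean-mask assignment on the row view writes into position_ids in place
        position_ids[batch_idx][p_attn_mask.reshape(-1)] = pos_ids

    return position_ids


# Example: two 4x4 patch grids, the second with only its top-left 2x2 patches valid
mask = torch.zeros(2, 4, 4, dtype=torch.bool)
mask[0] = True
mask[1, :2, :2] = True
print(compute_position_ids(mask, num_patches_per_side=4))
```

Porting this to Candle should only require a broadcasting `>=` comparison, a cast of the resulting mask to an integer dtype, and a sum over the boundary dimension. The intermediate comparison matrix has `len(values) * len(boundaries)` entries, which should be cheap for patch grids of this size.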
