Tutorial to export a SPLADE model to ONNX #47

Open
ntnq4 opened this issue Dec 1, 2023 · 6 comments

Comments

@ntnq4

ntnq4 commented Dec 1, 2023

Hello,

I trained a SPLADE model on my own recently. To reduce the inference time, I tried to export my model to ONNX with torch.onnx.export() but I encountered a few errors.

Is there a tutorial somewhere for this conversion?

@thibault-formal
Contributor

Hi @ntnq4

Not that I am aware of. I am not super familiar with ONNX - did you manage to make it work?

@ntnq4
Author

ntnq4 commented Dec 12, 2023

Hi @thibault-formal

I didn't manage to make it work unfortunately... I tried this tutorial but it didn't work for my SPLADE model.

I also found this recent paper that mentioned this conversion.

@risan-raja

risan-raja commented Feb 7, 2024

Hi @ntnq4,
I have managed to convert the SPLADE models to ONNX, although I used the pretrained checkpoints rather than a custom-trained model. I am aware this does not match your setup exactly, but if it helps, I am glad.
To reproduce:

  • Convert the model to TorchScript. The checkpoint has to be loaded with torchscript=True so that it can be traced.

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer  # type: ignore

MODEL_NAME = "naver/splade-cocondenser-ensembledistil"

class TransformerRep(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # torchscript=True makes the model return plain tuples, which torch.jit.trace needs
        self.model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME, torchscript=True)  # type: ignore
        self.model.eval()  # type: ignore
        self.fp16 = True

    def encode(self, input_ids, token_type_ids, attention_mask):
        # Pass by keyword: BERT's forward() expects attention_mask before token_type_ids,
        # so positional arguments in this order would swap the two.
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        return outputs[0]  # MLM logits, shape (batch, seq_len, vocab)


class SpladeModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = TransformerRep()
        self.agg = "max"
        self.model.eval()

    def forward(self, input_ids, token_type_ids, attention_mask):
        with torch.cuda.amp.autocast():  # type: ignore
            with torch.no_grad():
                lm_logits = self.model.encode(input_ids, token_type_ids, attention_mask)
                # SPLADE max pooling over the sequence, shape (batch, vocab)
                vec, _ = torch.max(torch.log(1 + torch.relu(lm_logits)) * attention_mask.unsqueeze(-1), dim=1)
                vec = vec[0]  # one text per call: keep the first row, shape (vocab,)
                indices = vec.nonzero().squeeze(-1)  # ids of the non-zero vocabulary terms
                weights = vec[indices]
        return indices, weights

# Convert the model to TorchScript
model = SpladeModel()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
sample = "the capital of france is paris"
inputs = tokenizer(sample, return_tensors="pt")
traced_model = torch.jit.trace(model, (inputs["input_ids"], inputs["token_type_ids"], inputs["attention_mask"]))
traced_model.save("splade_traced.pt")  # example path; the next step loads the model from this file
  • Later, load it from file and convert it using a dummy input. Make sure to adjust the above script to match your implementation.
import torch

# model_file: path of the saved TorchScript model; model_onnx_file: where to write the ONNX model.
# dummy_input: a tuple (input_ids, token_type_ids, attention_mask), e.g. the tensors used for tracing.
dyn_axis = {
    'input_ids': {0: 'batch_size', 1: 'sequence'},
    'attention_mask': {0: 'batch_size', 1: 'sequence'},
    'token_type_ids': {0: 'batch_size', 1: 'sequence'},
    # Both outputs are 1-D; their length is the number of non-zero terms.
    'indices': {0: 'num_terms'},
    'weights': {0: 'num_terms'},
}
model = torch.jit.load(model_file)
# With torch 2.2, torch.onnx.export writes the file to f and returns None
torch.onnx.export(
    model,
    dummy_input,  # type: ignore
    f=model_onnx_file,
    input_names=['input_ids', 'token_type_ids', 'attention_mask'],
    output_names=['indices', 'weights'],
    dynamic_axes=dyn_axis,
    do_constant_folding=True,
    opset_version=15,
    verbose=False,
)
  • Using this method I have managed to convert the following HF models successfully.
model_names= [
   "naver/splade_v2_max",
   "naver/splade_v2_distil",
   "naver/splade-cocondenser-ensembledistil",
   "naver/efficient-splade-VI-BT-large-query",
   "naver/efficient-splade-VI-BT-large-doc",
]
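
For anyone adapting the script to their own model, it may help to see what the pooling line in forward() computes: for each vocabulary term j, the weight is the maximum over unmasked positions i of log(1 + relu(logit[i][j])). Below is a minimal pure-Python sketch of that formula on toy numbers. The function name splade_max_pool and the toy inputs are illustrative assumptions, not part of the SPLADE codebase.

```python
import math

def splade_max_pool(logits, attention_mask):
    """Toy version of the SPLADE max pooling for a single text.

    logits: list of seq_len rows, each a list of vocab_size floats.
    attention_mask: list of seq_len ints (1 = real token, 0 = padding).
    Returns (indices, weights) of the non-zero vocabulary terms.
    """
    vocab_size = len(logits[0])
    pooled = [0.0] * vocab_size
    for row, keep in zip(logits, attention_mask):
        if not keep:
            continue  # a masked position contributes 0, same as multiplying by the mask
        for j, logit in enumerate(row):
            # log(1 + relu(logit)) is never negative, so starting from 0.0 is safe
            pooled[j] = max(pooled[j], math.log1p(max(logit, 0.0)))
    indices = [j for j, w in enumerate(pooled) if w != 0.0]
    weights = [pooled[j] for j in indices]
    return indices, weights

# Three positions, vocabulary of three terms; the last position is padding.
toy_logits = [[2.0, -1.0, 0.0], [0.5, -3.0, 1.0], [9.0, 9.0, 9.0]]
toy_mask = [1, 1, 0]
toy_indices, toy_weights = splade_max_pool(toy_logits, toy_mask)
print(toy_indices, toy_weights)
```

Term 1 only has negative logits at the real positions, so it drops out of the sparse vector, and the large logits at the padding position are ignored.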

requirements:

  • torch==2.2.0
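
To check the exported file, something along these lines could be used. It is only a sketch under assumptions: it needs numpy and onnxruntime (not listed in the requirements above), the helper names run_splade_onnx and to_term_weights are made up for illustration, and it assumes the ONNX model was exported with the input and output names used in the export step.

```python
def run_splade_onnx(model_onnx_file, encoded):
    """Run the exported model on one tokenized text.

    encoded: tokenizer output holding integer arrays of shape (1, seq_len),
    e.g. obtained with return_tensors="np" (assumption).
    """
    # Imported lazily so the rest of this snippet works without these packages.
    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession(model_onnx_file, providers=["CPUExecutionProvider"])
    feeds = {
        name: np.asarray(encoded[name], dtype=np.int64)
        for name in ("input_ids", "token_type_ids", "attention_mask")
    }
    indices, weights = session.run(["indices", "weights"], feeds)
    return indices.tolist(), weights.tolist()


def to_term_weights(indices, weights, id_to_token):
    """Turn the two output arrays into a {token: weight} dict, heaviest term first."""
    pairs = sorted(zip(indices, weights), key=lambda pair: pair[1], reverse=True)
    return {id_to_token[int(i)]: float(w) for i, w in pairs}


# Toy check with a hand-made vocabulary instead of a real tokenizer.
example = to_term_weights([3, 1], [0.5, 2.0], {1: "paris", 3: "france"})
print(example)
```

With a Hugging Face tokenizer, id_to_token can be built by inverting tokenizer.get_vocab(). Comparing the resulting terms against the output of the PyTorch model on the same sentence is a quick sanity check that the export preserved the behaviour.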

Hope this helps! :)

@ntnq4
Author

ntnq4 commented Feb 7, 2024

Hi @risan-raja,

Thank you for your help : )
I will try your solution on my side.

@sroussey

sroussey commented Feb 8, 2024

If an ONNX conversion were added to the model's Hugging Face repository in a folder called onnx, it would automatically become available to Hugging Face Transformers.js and be usable locally on the web.

@sroussey

sroussey commented Feb 8, 2024

Example: https://huggingface.co/Xenova/t5-small-awesome-text-to-sql/tree/main/
