
How to use w transformers? #30

Open

thistleknot opened this issue Apr 12, 2024 · 4 comments

Comments

@thistleknot

I use transformers with a custom script. I see you show how to use this with a custom FastChat script.

Do you have boilerplate code showing how to wrap a transformers pipeline to use with this?

@guyan364
Collaborator

guyan364 commented Apr 12, 2024

You can use patch_hf for transformers.
For this usage, you can refer to the integration in chat.py:
load the configuration as a dict and pass it to patch_hf along with your model.

from inf_llm.utils import patch_hf
config = load_yaml_config()['model']
model = patch_hf(model, config['type'], **config)
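Spelled out a bit more, a minimal end-to-end sketch of that snippet could look like the following. The config path, the model id, the dtype/device choices, and the assumption that load_yaml_config simply wraps yaml.safe_load over one of the repository's config files are illustrative, not a documented API:

import yaml
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from inf_llm.utils import patch_hf

# Hypothetical helper: load one of the repo-style YAML configs (path assumed)
def load_yaml_config(path="config/mistral-inf-llm.yaml"):
    with open(path, "r") as f:
        return yaml.safe_load(f)

config = load_yaml_config()["model"]

# Load the base HF model first, then patch its attention in place
tokenizer = AutoTokenizer.from_pretrained(config["path"])
model = AutoModelForCausalLM.from_pretrained(
    config["path"], torch_dtype=torch.float16, device_map="cuda"
)
model = patch_hf(model, config["type"], **config)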

@thistleknot
Author

thistleknot commented Apr 13, 2024

"load_yaml_config"?

I tried

import yaml
from inf_llm.utils import patch_hf

# Simulated YAML configuration as a string
config_string = """
model:
  type: mistral  # Assuming 'mistral' is a valid type for your use case
  path: mistralai/Mistral-7B-Instruct-v0.2
  block_size: 128
  n_init: 128
  n_local: 4096
  topk: 16
  repr_topk: 4
  max_cached_block: 32
  exc_block_size: 512
  fattn: false
  base: 1000000
  distance_scale: 1.0
max_len: 2147483647
chunk_size: 8192
conv_type: mistral-inst

"""
# Parsing the YAML string into a dictionary
config = yaml.safe_load(config_string)

# Assuming 'model' is a pre-initialized model object
model = None  # Replace this with your actual model initialization logic

# Extracting type and the rest of the model configuration
model_type = config['model'].pop('type')

# Applying the configuration to the model
patched_model = patch_hf(model, model_type, **config['model'])

# Output for debugging
print("Model patched with the following configuration:")
print(config['model'])

and I get an error about only certain models being supported (mistral, llama, etc.):

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[21], line 35
     32 model_type = config['model'].pop('type')
     34 # Applying the configuration to the model
---> 35 patched_model = patch_hf(model, model_type, **config['model'])
     37 # Output for debugging
     38 print("Model patched with the following configuration:")

File /data/InfLLM/inf_llm/utils/patch.py:133, in patch_hf(model, attn_type, attn_kwargs, base, distance_scale, **kwargs)
    125         return tuple(v for v in [hidden_states, pkv, all_hidden_states, all_self_attns] if v is not None)
    126     return BaseModelOutputWithPast(
    127         last_hidden_state=hidden_states,
    128         past_key_values=pkv,
    129         hidden_states=all_hidden_states,
    130         attentions=all_self_attns,
    131     )
--> 133 forward = huggingface_forward(ATTN_FORWRAD[attn_type](**attn_kwargs))
    135 if isinstance(model, LlamaForCausalLM):
    136     Attention = LlamaAttention

KeyError: 'mistral'
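For what it's worth, the KeyError suggests that the config's type field is looked up in ATTN_FORWRAD, i.e. it names the attention implementation rather than the model family, so "mistral" is not a valid key there. The repository's example configs appear to use an attention-type value such as inf-llm instead (treat that exact key as an assumption and verify against the shipped YAML files); under that assumption, the only change needed to the snippet above would be:

# Reuse the configuration above, but point 'type' at the attention
# implementation rather than the model family. 'inf-llm' is assumed
# from the repository's example configs; verify against config/*.yaml.
config = yaml.safe_load(config_string)
config['model']['type'] = 'inf-llm'

model_type = config['model'].pop('type')   # now an attention type, not 'mistral'
patched_model = patch_hf(model, model_type, **config['model'])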

@chris-aeviator

chris-aeviator commented Apr 21, 2024

While I can patch the model, it won't work with the standard HF tools. Transformers assumes past_key_values is subscriptable, but here past_key_values is a ContextManager.


model_inputs = tokenizer([prompt], return_tensors="pt").to("cuda")

generated_ids = patched_model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
tokenizer.batch_decode(generated_ids)[0]

in MistralForCausalLM.prepare_inputs_for_generation(self, input_ids, past_key_values, attention_mask, inputs_embeds, **kwargs)
   1206     max_cache_length = past_key_values.get_max_length()
   1207 else:
-> 1208     cache_length = past_length = past_key_values[0][0].shape[2]
   1209     max_cache_length = None
   1211 # Keep only the unprocessed tokens:
   1212 # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
   1213 # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
   1214 # input)

TypeError: 'ContextManager' object is not subscriptable


@thistleknot
Author

import yaml
from inf_llm.utils import patch_hf
from transformers import AutoModel

def load_yaml_config(file_path='path_to_your_config_file.yaml'):
    """Load a YAML configuration file."""
    with open(file_path, 'r') as file:
        return yaml.safe_load(file)

# Load the configuration for infinite context
config_path = 'minicpm-inf-llm.yaml'
with open(config_path, 'r') as file:
    inf_llm_config = yaml.safe_load(file)
inf_llm_config

from inf_llm.utils import patch_hf
config = load_yaml_config(file_path=config_path)['model']
model = patch_hf(model, config['type'], **config)

produces

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[26], line 3
      1 from inf_llm.utils import patch_hf
      2 config = load_yaml_config(file_path=config_path)['model']
----> 3 model = patch_hf(model, config['type'], **config)

File /home/user/mamba/InfLLM/inf_llm/utils/patch.py:150, in patch_hf(model, attn_type, attn_kwargs, base, distance_scale, **kwargs)
    148     Model = model.model.__class__
    149 else:
--> 150     raise ValueError("Only supports llama, mistral and qwen2 models.")
    152 hf_rope = model.model.layers[0].self_attn.rotary_emb
    153 base = base if base is not None else hf_rope.base

ValueError: Only supports llama, mistral and qwen2 models.
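One possible cause, offered as an assumption rather than a confirmed diagnosis: the earlier traceback shows patch_hf dispatching on the causal-LM classes (isinstance(model, LlamaForCausalLM), etc.), so a checkpoint loaded via the bare AutoModel imported above would fall through to this ValueError even for a supported architecture. Whether the MiniCPM remote-code class is recognized at all also depends on the installed patch.py. A minimal change under that assumption, with the checkpoint id and trust_remote_code flag as illustrative placeholders:

# Load with a causal-LM head instead of the bare AutoModel so patch_hf's
# isinstance checks can recognize the architecture. The MiniCPM checkpoint
# id and trust_remote_code flag are assumptions for illustration.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "openbmb/MiniCPM-2B-sft-bf16",   # assumed checkpoint matching minicpm-inf-llm.yaml
    trust_remote_code=True,          # MiniCPM ships custom modeling code
)

config = load_yaml_config(file_path=config_path)['model']
model = patch_hf(model, config['type'], **config)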
