
[Bug]: Converting gguf to state_dict #411

Open
heungson opened this issue Apr 16, 2024 · 11 comments
Labels
bug Something isn't working

Comments

@heungson

Your current environment

PyTorch version: 2.2.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35
Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.4.0-171-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: Quadro RTX 8000
GPU 1: Quadro RTX 8000

Nvidia driver version: 525.147.05
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      46 bits physical, 48 bits virtual
Byte Order:                         Little Endian
CPU(s):                             16
On-line CPU(s) list:                0-15
Vendor ID:                          GenuineIntel
Model name:                         Intel(R) Xeon(R) CPU E5-1660 v4 @ 3.20GHz
CPU family:                         6
Model:                              79
Thread(s) per core:                 2
Core(s) per socket:                 8
Socket(s):                          1
Stepping:                           1
CPU max MHz:                        3800.0000
CPU min MHz:                        1200.0000
BogoMIPS:                           6396.14
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single pti intel_ppin ssbd ibrs ibpb stibp tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm rdt_a rdseed adx smap intel_pt xsaveopt cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts md_clear flush_l1d
Virtualization:                     VT-x
L1d cache:                          256 KiB (8 instances)
L1i cache:                          256 KiB (8 instances)
L2 cache:                           2 MiB (8 instances)
L3 cache:                           20 MiB (1 instance)
NUMA node(s):                       1
NUMA node0 CPU(s):                  0-15
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        KVM: Vulnerable
Vulnerability L1tf:                 Mitigation; PTE Inversion
Vulnerability Mds:                  Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Meltdown:             Mitigation; PTI
Vulnerability Mmio stale data:      Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Retbleed:             Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP conditional, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Mitigation; Clear CPU buffers; SMT vulnerable
Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==2.2.0
[pip3] triton==2.2.0
[conda] Could not collect ROCM Version: Could not collect
Aphrodite Version: 0.5.2
Aphrodite Build Flags:
CUDA Archs: 6.1 7.0 7.5 8.0 8.6 8.9 9.0+PTX; ROCm: Disabled

🐛 Describe the bug

I might be missing something, but right at the start of converting GGUF to a PyTorch state_dict, it fails to find the layer 'blk.0.ffn_gate_exps' in the dictionary 'mapping'.


I have no name!@a535e478c460:/tmp/hub/models--MaziyarPanahi--WizardLM-2-8x22B-GGUF$ python3 -m aphrodite.endpoints.openai.api_server --host 0.0.0.0 --port 7860 --download-dir /tmp/hub --model /tmp/hub/models--MaziyarPanahi--WizardLM-2-8x22B-GGUF/snapshots/e382348c70b7cbadc126025a60c2c9f7445fcddc/WizardLM-2-8x22B.IQ3_XS-00001-of-00005.gguf --dtype auto --max-model-len 32768 --tensor-parallel-size 2 --gpu-memory-utilization 0.95 --quantization gguf --enforce-eager --trust-remote-code


WARNING: gguf quantization is not fully optimized yet. The speed can be slower than non-quantized models.
2024-04-16 01:16:53,309 INFO worker.py:1724 -- Started a local Ray instance.
INFO: Initializing the Aphrodite Engine (v0.5.2) with the following config:
INFO: Model =
'/tmp/hub/models--MaziyarPanahi--WizardLM-2-8x22B-GGUF/snapshots/e382348c70b7cbadc126025a60c2c9f7445fcddc/WizardLM-2-8x22B.IQ3_XS-00001-of-00005.gguf'
INFO: DataType = torch.float16
INFO: Model Load Format = auto
INFO: Number of GPUs = 2
INFO: Disable Custom All-Reduce = False
INFO: Quantization Format = gguf
INFO: Context Length = 32768
INFO: Enforce Eager Mode = True
INFO: KV Cache Data Type = auto
INFO: KV Cache Params Path = None
INFO: Device = cuda
You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the legacy (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set legacy=False. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in huggingface/transformers#24565
Converting GGUF tensors to PyTorch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1% 1/128 -:--:--
(RayWorkerAphrodite pid=1148) Converting GGUF tensors to PyTorch... 1% 1/128 -:--:--
Traceback (most recent call last):
File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/app/aphrodite-engine/aphrodite/endpoints/openai/api_server.py", line 599, in <module>
engine = AsyncAphrodite.from_engine_args(engine_args)
File "/app/aphrodite-engine/aphrodite/engine/async_aphrodite.py", line 676, in from_engine_args
engine = cls(parallel_config.worker_use_ray,
File "/app/aphrodite-engine/aphrodite/engine/async_aphrodite.py", line 341, in __init__
self.engine = self._init_engine(*args, **kwargs)
File "/app/aphrodite-engine/aphrodite/engine/async_aphrodite.py", line 410, in _init_engine
return engine_class(*args, **kwargs)
File "/app/aphrodite-engine/aphrodite/engine/aphrodite_engine.py", line 113, in __init__
self._init_workers_ray(placement_group)
File "/app/aphrodite-engine/aphrodite/engine/aphrodite_engine.py", line 283, in _init_workers_ray
self._run_workers(
File "/app/aphrodite-engine/aphrodite/engine/aphrodite_engine.py", line 1028, in _run_workers
driver_worker_output = getattr(self.driver_worker,
File "/app/aphrodite-engine/aphrodite/task_handler/worker.py", line 112, in load_model
self.model_runner.load_model()
File "/app/aphrodite-engine/aphrodite/task_handler/model_runner.py", line 121, in load_model
self.model = get_model(self.model_config, self.device_config,
File "/app/aphrodite-engine/aphrodite/modeling/loader.py", line 91, in get_model
model.load_weights(model_config.model, model_config.download_dir,
File "/app/aphrodite-engine/aphrodite/modeling/models/mixtral_quant.py", line 450, in load_weights
for name, loaded_weight in hf_model_weights_iterator(
File "/app/aphrodite-engine/aphrodite/modeling/hf_downloader.py", line 293, in hf_model_weights_iterator
for name, param in convert_gguf_to_state_dict(model_name_or_path,
File "/app/aphrodite-engine/aphrodite/modeling/hf_downloader.py", line 271, in convert_gguf_to_state_dict
new_key, output_dim = mapping[layer]
KeyError: 'blk.0.ffn_gate_exps'
[rank0]:[W CudaIPCTypes.cpp:16] Producer process has been terminated before all shared CUDA tensors released. See Note [Sharing CUDA tensors]
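
For anyone hitting the same KeyError: `blk.0.ffn_gate_exps` looks like a single merged tensor holding every expert of layer 0, the layout that newer llama.cpp conversions appear to use for Mixtral-style MoE models. If the converter's `mapping` only lists per-expert names, the bare `mapping[layer]` lookup in `convert_gguf_to_state_dict` fails on the first such tensor. The sketch below is a hypothetical diagnostic, not Aphrodite's code: `MAPPING` and `map_tensor_names` are made-up names and the mapping entries are illustrative only. It collects every unmapped name instead of stopping at the first one.

```python
import re

# Hypothetical mapping in the spirit of the `mapping` dict in the traceback:
# GGUF layer name (without ".weight") -> (PyTorch name, output_dim).
# The entries are illustrative, not Aphrodite's actual table.
MAPPING = {
    "token_embd": ("model.embed_tokens", 0),
    "blk.0.attn_q": ("model.layers.0.self_attn.q_proj", 0),
    "blk.0.ffn_gate.0": ("model.layers.0.block_sparse_moe.experts.0.w1", 0),
}


def map_tensor_names(gguf_names, mapping):
    """Split GGUF tensor names into mapped and unmapped ones instead of
    failing on the first unknown name with a bare KeyError."""
    mapped, unmapped = {}, []
    for full_name in gguf_names:
        # Strip a trailing ".weight"/".bias" to get the layer key.
        layer = re.sub(r"\.(weight|bias)$", "", full_name)
        if layer in mapping:
            mapped[full_name] = mapping[layer]
        else:
            unmapped.append(full_name)
    return mapped, unmapped


mapped, unmapped = map_tensor_names(
    ["token_embd.weight", "blk.0.attn_q.weight", "blk.0.ffn_gate_exps.weight"],
    MAPPING,
)
print(unmapped)  # ['blk.0.ffn_gate_exps.weight']
```

Listing all unmapped names at once makes it easier to tell a single missing entry apart from a whole tensor layout the converter does not know about.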

@heungson heungson added the bug Something isn't working label Apr 16, 2024
@sgsdxzy
Collaborator

sgsdxzy commented Apr 16, 2024

It is recommended to use exl2, gptq or awq over gguf. The support for gguf (especially sharded gguf) is unfinished.

@heungson
Author

It is recommended to use exl2, gptq or awq over gguf. The support for gguf (especially sharded gguf) is unfinished.

Oh I see. Thank you for the reply!

@sgsdxzy
Collaborator

sgsdxzy commented Apr 30, 2024

Experimental support for multiple GGUF files has been added to the dev branch. Please test whether it works by following the documentation.

@heungson
Author

heungson commented May 10, 2024

@sgsdxzy Thank you for the update. I tried testing the dev branch, but while the documentation says

The dev branch extends support for GGUF to all available model architectures besides LLAMA, and sharded (multiple-file) GGUF.

the code contains a part that contradicts it:

# Only support llama so far
if architecture != "llama":
    raise RuntimeError(f"Unsupported architecture {architecture}, "
                       "only llama is supported.")

So when I tried to run the model 'dranger003/c4ai-command-r-plus-iMat.GGUF', it raised that error.

With a Llama 3 model, it raises a different error:

(/home/lhs1012/.conda/aphrodite-runtime) lhs1012@ubuntu:/mnt3/lhs1012/laboratory/aphrodite-engine$ python -m aphrodite.endpoints.openai.api_server --model /mnt3/.cache/huggingface/hub/models--QuantFactory--Meta-Llama-3-70B-Instruct-GGUF-v2/snapshots/7549d4063b18c5b0eb91e547a633245ee8fc4cdd/Meta-Llama-3-70B-Instruct-v2.Q5_1-00001-of-00002.gguf --enforce-eager true --tensor-parallel-size 2 --gpu-memory-utilization 0.95 --quantization gguf
INFO: Extracting config from GGUF...
WARNING: gguf quantization is not fully optimized yet. The speed can be slower than non-quantized models.
2024-05-10 11:09:38,502 INFO worker.py:1749 -- Started a local Ray instance.
INFO: Initializing the Aphrodite Engine (v0.5.2) with the following config:
INFO: Model =
'/mnt3/.cache/huggingface/hub/models--QuantFactory--Meta-Llama-3-70B-Instruct-GGUF-v2/snapshots/7549d4063b18c5b0eb91e547a633245ee8fc4cdd/Meta-Llama-3-70B-Instruct-v2.Q5_1-00001-of-00002.gguf'
INFO: Speculative Config = None
INFO: DataType = torch.float16
INFO: Model Load Format = auto
INFO: Number of GPUs = 2
INFO: Disable Custom All-Reduce = False
INFO: Quantization Format = gguf
INFO: Context Length = 8192
INFO: Enforce Eager Mode = True
INFO: KV Cache Data Type = auto
INFO: KV Cache Params Path = None
INFO: Device = cuda
INFO: Guided Decoding Backend = DecodingConfig(guided_decoding_backend='outlines')
INFO: Converting tokenizer from GGUF...
You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the legacy (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set legacy=False. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in huggingface/transformers#24565
Traceback (most recent call last):
File "<frozen runpy>", line 198, in _run_module_as_main
File "<frozen runpy>", line 88, in _run_code
File "/mnt3/lhs1012/laboratory/aphrodite-engine/aphrodite/endpoints/openai/api_server.py", line 562, in <module>
run_server(args)
File "/mnt3/lhs1012/laboratory/aphrodite-engine/aphrodite/endpoints/openai/api_server.py", line 519, in run_server
engine = AsyncAphrodite.from_engine_args(engine_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt3/lhs1012/laboratory/aphrodite-engine/aphrodite/engine/async_aphrodite.py", line 358, in from_engine_args
engine = cls(engine_config.parallel_config.worker_use_ray,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt3/lhs1012/laboratory/aphrodite-engine/aphrodite/engine/async_aphrodite.py", line 323, in __init__
self.engine = self._init_engine(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt3/lhs1012/laboratory/aphrodite-engine/aphrodite/engine/async_aphrodite.py", line 429, in _init_engine
return engine_class(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt3/lhs1012/laboratory/aphrodite-engine/aphrodite/engine/aphrodite_engine.py", line 125, in __init__
self._init_tokenizer()
File "/mnt3/lhs1012/laboratory/aphrodite-engine/aphrodite/engine/aphrodite_engine.py", line 247, in _init_tokenizer
self.tokenizer: BaseTokenizerGroup = get_tokenizer_group(
^^^^^^^^^^^^^^^^^^^^
File "/mnt3/lhs1012/laboratory/aphrodite-engine/aphrodite/transformers_utils/tokenizer_group/__init__.py", line 20, in get_tokenizer_group
return TokenizerGroup(**init_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt3/lhs1012/laboratory/aphrodite-engine/aphrodite/transformers_utils/tokenizer_group/tokenizer_group.py", line 23, in __init__
self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt3/lhs1012/laboratory/aphrodite-engine/aphrodite/transformers_utils/tokenizer.py", line 136, in get_tokenizer
return convert_gguf_to_tokenizer(tokenizer_name)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt3/lhs1012/laboratory/aphrodite-engine/aphrodite/transformers_utils/tokenizer.py", line 80, in convert_gguf_to_tokenizer
tokenizer = LlamaTokenizer(**tokenizer_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/lhs1012/.conda/aphrodite-runtime/lib/python3.11/site-packages/transformers/models/llama/tokenization_llama.py", line 169, in __init__
self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/lhs1012/.conda/aphrodite-runtime/lib/python3.11/site-packages/transformers/models/llama/tokenization_llama.py", line 196, in get_spm_processor
tokenizer.Load(self.vocab_file)
File "/home/lhs1012/.conda/aphrodite-runtime/lib/python3.11/site-packages/sentencepiece/__init__.py", line 961, in Load
return self.LoadFromFile(model_file)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/lhs1012/.conda/aphrodite-runtime/lib/python3.11/site-packages/sentencepiece/__init__.py", line 316, in LoadFromFile
return _sentencepiece.SentencePieceProcessor_LoadFromFile(self, arg)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Internal: unk is not defined.

@sgsdxzy
Collaborator

sgsdxzy commented May 10, 2024

Sharded GGUFs (you are using 00001-of-00002) and other architectures require pre-conversion. You also need to point --model to the directory containing all the GGUF shards, not to a single shard.
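
Before pointing anything at a shard directory, it may be worth confirming that every shard is actually present. A minimal sketch, assuming the `NAME-0000k-of-0000n.gguf` naming used by the files in this thread; `shard_set` and `missing_shards` are hypothetical helpers, not Aphrodite functions.

```python
import re
from pathlib import Path

# Assumed llama.cpp-style shard suffix, e.g. "-00001-of-00002.gguf".
SHARD_RE = re.compile(r"^(?P<stem>.+)-(?P<idx>\d{5})-of-(?P<total>\d{5})\.gguf$")


def shard_set(shard_path):
    """Given any one shard, return the expected paths of all shards in order.
    A file that does not match the pattern is treated as a single-file GGUF."""
    path = Path(shard_path)
    m = SHARD_RE.match(path.name)
    if m is None:
        return [path]
    total = int(m["total"])
    return [path.with_name(f"{m['stem']}-{i:05d}-of-{total:05d}.gguf")
            for i in range(1, total + 1)]


def missing_shards(shard_path):
    """Return the shards of the set that do not exist on disk."""
    return [p for p in shard_set(shard_path) if not p.exists()]


shards = shard_set("Meta-Llama-3-70B-Instruct-v2.Q5_1-00001-of-00002.gguf")
print([p.name for p in shards])
# ['Meta-Llama-3-70B-Instruct-v2.Q5_1-00001-of-00002.gguf',
#  'Meta-Llama-3-70B-Instruct-v2.Q5_1-00002-of-00002.gguf']
```

If `missing_shards(...)` returns an empty list, `shards[0].parent` is the directory that holds the complete set.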

The model must be of the LlamaForCausalLM architecture to be loaded directly from GGUF; otherwise, the original config.json and other JSON configs must be present in the directory. The tokenizer must be of the LlamaTokenizer architecture to be loaded directly from GGUF; otherwise, the original tokenizer must be present in the directory, or you can use --tokenizer to choose another tokenizer.
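
The two rules above can be written down as a pre-flight check. A minimal sketch under stated assumptions: the caller already knows the architecture and tokenizer class names, and the presence of any one of `tokenizer.json`, `tokenizer.model`, or `tokenizer_config.json` is taken to mean "the original tokenizer is present", which is a simplification. `loading_problems` is a hypothetical helper, not part of Aphrodite.

```python
from pathlib import Path

# Any one of these counts as "the original tokenizer is present";
# this file list is an assumption, not Aphrodite's actual check.
TOKENIZER_FILES = ("tokenizer.json", "tokenizer.model", "tokenizer_config.json")


def loading_problems(model_dir, architecture, tokenizer_class, tokenizer_arg=None):
    """Return the reasons model_dir cannot be loaded as-is under the rules
    described above (an empty list means no problems were found)."""
    model_dir = Path(model_dir)
    problems = []
    # Rule 1: non-Llama architectures need the original config.json.
    if architecture != "LlamaForCausalLM" and not (model_dir / "config.json").is_file():
        problems.append("config.json missing for a non-Llama architecture")
    # Rule 2: non-LlamaTokenizer vocabularies need the original tokenizer
    # files, unless another tokenizer is chosen with --tokenizer.
    has_tokenizer = any((model_dir / name).is_file() for name in TOKENIZER_FILES)
    if tokenizer_class != "LlamaTokenizer" and not has_tokenizer and tokenizer_arg is None:
        problems.append("original tokenizer missing; add it or pass --tokenizer")
    return problems
```

For a Command R+ directory, the caller would pass `"CohereForCausalLM"` and `"CohereTokenizerFast"` and expect an empty list only once the original JSON configs and tokenizer files have been copied in.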

Llama 3 doesn't use LlamaTokenizer.
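
Whether a GGUF's vocabulary is SentencePiece-based can be checked before LlamaTokenizer is ever constructed. The GGUF spec records this in the `tokenizer.ggml.model` metadata key (`llama` for SentencePiece, `gpt2` for GPT-2-style BPE), and Llama 3 files are, as far as I know, written with `gpt2`, which would fit the sentencepiece `unk is not defined` error above. The following is a minimal stdlib-only sketch of a GGUF v2/v3 metadata scanner; it is not Aphrodite's loader, and `read_gguf_metadata` and `describe_tokenizer` are hypothetical names.

```python
import struct
from typing import BinaryIO

# GGUF metadata value type ids (per the GGUF spec): little-endian scalars,
# plus STRING (8) and ARRAY (9).
_SCALARS = {0: "<B", 1: "<b", 2: "<H", 3: "<h", 4: "<I", 5: "<i",
            6: "<f", 7: "<?", 10: "<Q", 11: "<q", 12: "<d"}
_STRING, _ARRAY = 8, 9
TOKENIZER_KINDS = {"llama": "SentencePiece (LlamaTokenizer)", "gpt2": "BPE"}


def _read(f: BinaryIO, fmt: str):
    size = struct.calcsize(fmt)
    data = f.read(size)
    if len(data) != size:
        raise EOFError("truncated GGUF header")
    return struct.unpack(fmt, data)[0]


def _read_str(f: BinaryIO) -> str:
    # GGUF strings: uint64 byte length followed by UTF-8 bytes (no NUL).
    return f.read(_read(f, "<Q")).decode("utf-8")


def _read_value(f: BinaryIO, vtype: int):
    if vtype in _SCALARS:
        return _read(f, _SCALARS[vtype])
    if vtype == _STRING:
        return _read_str(f)
    if vtype == _ARRAY:
        item_type, count = _read(f, "<I"), _read(f, "<Q")
        return [_read_value(f, item_type) for _ in range(count)]
    raise ValueError(f"unknown GGUF value type {vtype}")


def read_gguf_metadata(f: BinaryIO, wanted: str):
    """Scan the key/value section of a GGUF v2/v3 stream and return the
    value stored under `wanted`, or None if the key is absent."""
    if f.read(4) != b"GGUF":
        raise ValueError("not a GGUF file")
    if _read(f, "<I") < 2:
        raise ValueError("GGUF v1 uses 32-bit counts; not handled here")
    _tensor_count, kv_count = _read(f, "<Q"), _read(f, "<Q")
    for _ in range(kv_count):
        key = _read_str(f)
        value = _read_value(f, _read(f, "<I"))
        if key == wanted:
            return value
    return None


def describe_tokenizer(f: BinaryIO) -> str:
    model = read_gguf_metadata(f, "tokenizer.ggml.model")
    return TOKENIZER_KINDS.get(model, f"other/unknown ({model!r})")
```

On a real file this would be used as `with open(path, "rb") as f: describe_tokenizer(f)`. The scan materializes vocabulary arrays as Python lists, so it is slow on large vocabularies; the official `gguf` Python package's reader is the better tool in practice.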

@heungson
Author

The conversion succeeded, but I got this error when running the model:

aphrodite run /mnt3/.cache/huggingface/hub/models--command-r-plus-gguf -tp 2
WARNING: gguf quantization is not fully optimized yet. The speed can be slower than non-quantized models.
2024-05-13 07:56:52,525 INFO worker.py:1749 -- Started a local Ray instance.
INFO: Initializing the Aphrodite Engine (v0.5.2) with the following config:
INFO: Model = '/mnt3/.cache/huggingface/hub/models--command-r-plus-gguf'
INFO: Speculative Config = None
INFO: DataType = torch.float16
INFO: Model Load Format = auto
INFO: Number of GPUs = 2
INFO: Disable Custom All-Reduce = False
INFO: Quantization Format = gguf
INFO: Context Length = 131072
INFO: Enforce Eager Mode = True
INFO: KV Cache Data Type = auto
INFO: KV Cache Params Path = None
INFO: Device = cuda
INFO: Guided Decoding Backend = DecodingConfig(guided_decoding_backend='outlines')
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
WARNING: The tokenizer's vocabulary size 255029 does not match the model's vocabulary size 256000.
INFO: Cannot use FlashAttention backend for Volta and Turing GPUs.
INFO: Using XFormers backend.
(RayWorkerAphrodite pid=2075297) INFO: Cannot use FlashAttention backend for Volta and Turing GPUs.
(RayWorkerAphrodite pid=2075297) INFO: Using XFormers backend.
INFO: Aphrodite is using nccl==2.21.5
(RayWorkerAphrodite pid=2075297) INFO: Aphrodite is using nccl==2.21.5
INFO: reading GPU P2P access cache from /home/lhs1012/.config/aphrodite/gpu_p2p_access_cache_for_0,1.json
(RayWorkerAphrodite pid=2075297) INFO: reading GPU P2P access cache from /home/lhs1012/.config/aphrodite/gpu_p2p_access_cache_for_0,1.json
(RayWorkerAphrodite pid=2075297) WARNING: GGUF tensor name for lm_head.weight not found, this is normal if the model uses tie word embeddings.
(RayWorkerAphrodite pid=2075297) Converting GGUF tensors to PyTorch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 0% 0/0 -:--:--
(RayWorkerAphrodite pid=2075297) ERROR: Error executing method load_model. This might cause deadlock in distributed execution.
WARNING: GGUF tensor name for lm_head.weight not found, this is normal if the model uses tie word embeddings.
Converting GGUF tensors to PyTorch... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 0% 0/0 -:--:--
Traceback (most recent call last):
File "/home/lhs1012/.conda/aphrodite-runtime/bin/aphrodite", line 8, in <module>
sys.exit(main())
^^^^^^
File "/mnt3/lhs1012/laboratory/aphrodite-engine/aphrodite/endpoints/cli.py", line 25, in main
args.func(args)
File "/mnt3/lhs1012/laboratory/aphrodite-engine/aphrodite/endpoints/openai/api_server.py", line 519, in run_server
engine = AsyncAphrodite.from_engine_args(engine_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt3/lhs1012/laboratory/aphrodite-engine/aphrodite/engine/async_aphrodite.py", line 358, in from_engine_args
engine = cls(engine_config.parallel_config.worker_use_ray,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt3/lhs1012/laboratory/aphrodite-engine/aphrodite/engine/async_aphrodite.py", line 323, in __init__
self.engine = self._init_engine(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt3/lhs1012/laboratory/aphrodite-engine/aphrodite/engine/async_aphrodite.py", line 429, in _init_engine
return engine_class(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt3/lhs1012/laboratory/aphrodite-engine/aphrodite/engine/aphrodite_engine.py", line 131, in __init__
self.model_executor = executor_class(
^^^^^^^^^^^^^^^
File "/mnt3/lhs1012/laboratory/aphrodite-engine/aphrodite/executor/executor_base.py", line 39, in __init__
self._init_executor()
File "/mnt3/lhs1012/laboratory/aphrodite-engine/aphrodite/executor/ray_gpu_executor.py", line 45, in _init_executor
self._init_workers_ray(placement_group)
File "/mnt3/lhs1012/laboratory/aphrodite-engine/aphrodite/executor/ray_gpu_executor.py", line 193, in _init_workers_ray
self._run_workers(
File "/mnt3/lhs1012/laboratory/aphrodite-engine/aphrodite/executor/ray_gpu_executor.py", line 309, in _run_workers
driver_worker_output = getattr(self.driver_worker,
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt3/lhs1012/laboratory/aphrodite-engine/aphrodite/task_handler/worker.py", line 125, in load_model
self.model_runner.load_model()
File "/mnt3/lhs1012/laboratory/aphrodite-engine/aphrodite/task_handler/model_runner.py", line 179, in load_model
self.model = get_model(
^^^^^^^^^^
File "/mnt3/lhs1012/laboratory/aphrodite-engine/aphrodite/modeling/loader.py", line 103, in get_model
model.load_weights(model_config.model, model_config.download_dir,
File "/mnt3/lhs1012/laboratory/aphrodite-engine/aphrodite/modeling/models/cohere.py", line 453, in load_weights
weight_loader(param, loaded_weight)
File "/mnt3/lhs1012/laboratory/aphrodite-engine/aphrodite/modeling/layers/vocab_parallel_embedding.py", line 115, in weight_loader
loaded_weight.shape[output_dim]).copy_(loaded_weight)
^^^^^^^^^^^^^^^^^^^^
NotImplementedError: Cannot copy out of meta tensor; no data!
[W CudaIPCTypes.cpp:16] Producer process has been terminated before all shared CUDA tensors released. See Note [Sharing CUDA tensors]

@sgsdxzy
Collaborator

sgsdxzy commented May 13, 2024

Can you test with the latest release, v0.5.3, and see if the issue persists?

@heungson
Author

Still the same error with v0.5.3 and with the current main branch.

@sgsdxzy
Collaborator

sgsdxzy commented May 19, 2024

I am unable to reproduce this issue on main with ggml-c4ai-command-r-plus-iq2_xxs.gguf
Please check if you followed https://github.com/PygmalionAI/aphrodite-engine/wiki/8.-Quantization#pre-convert-to-pytorch-state_dict-recommanded correctly.

@JJordanCCurnow

JJordanCCurnow commented May 22, 2024

Hi mate,

# Only support llama so far
if architecture != "llama":
    raise RuntimeError(f"Unsupported architecture {architecture}, "
                       "only llama is supported.")

This code is also still present in the 0.5.3 release and blocks conversion of GGUF to PyTorch in the official Docker container. As mentioned above, this contradicts the documentation.

Lines 49-52 of aphrodite/transformers_utils/config.py:
https://github.com/PygmalionAI/aphrodite-engine/blob/dev/aphrodite/transformers_utils/config.py#L49

I have no name!@aphrodite-engine:~/examples$ python gguf_to_torch.py --input /tmp/hub/models--microsoft--Phi-3-mini-4k-instruct-gguf/snapshots/c80d904a71b99a3eaeb8d3dbf164166384c09dc3/Phi-3-mini-4k-instruct-q4.gguf --output /tmp/hub/models--microsoft--Phi-3-mini-4k-instruct-gguf/                                                  
INFO:     Extracting config from GGUF...
Traceback (most recent call last):
  File "/app/aphrodite-engine/examples/gguf_to_torch.py", line 56, in <module>
    convert_save_model(args.input, args.unquantized_path, args.output,
  File "/app/aphrodite-engine/examples/gguf_to_torch.py", line 14, in convert_save_model
    config = extract_gguf_config(checkpoint)
  File "/app/aphrodite-engine/aphrodite/transformers_utils/config.py", line 48, in extract_gguf_config
    raise RuntimeError(f"Unsupported architecture {architecture}, "
RuntimeError: Unsupported architecture phi3, only llama is supported.
I have no name!@aphrodite-engine:~/aphrodite$ cat __init__.py 
from aphrodite.engine.args_tools import AsyncEngineArgs, EngineArgs
from aphrodite.engine.async_aphrodite import AsyncAphrodite
from aphrodite.engine.aphrodite_engine import AphroditeEngine
from aphrodite.engine.ray_tools import initialize_ray_cluster
from aphrodite.endpoints.llm import LLM
from aphrodite.modeling.models import ModelRegistry
from aphrodite.common.outputs import CompletionOutput, RequestOutput
from aphrodite.common.sampling_params import SamplingParams

__version__ = "0.5.3"

The same line is present in the dev branch:
https://github.com/PygmalionAI/aphrodite-engine/blob/dev/aphrodite/transformers_utils/config.py#L49

Cheers!

@sgsdxzy
Collaborator

sgsdxzy commented May 22, 2024

@JJordanCCurnow you need to pass --unquantized-path to gguf_to_torch.py. Without it, only Llama models are supported.

--unquantized-path: The path to the unquantized model to copy config and tokenizer. For llama 1&2 models this can be skipped to try extracting the config and tokenizer from the GGUF file, but it is recommended to always supply this because the tokenizer inside GGUF can sometimes be broken.
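
Put together, the invocation described above can be sketched as a small wrapper that builds the command without running it. This is hypothetical: the flag names (--input, --output, --unquantized-path) come from this thread, the early failure only mirrors the `architecture != "llama"` guard quoted earlier, and `gguf_to_torch_cmd` is a made-up name. The architecture string `command-r` is assumed to be the GGUF name for Command R models.

```python
import shlex


def gguf_to_torch_cmd(gguf_input, output_dir, architecture, unquantized_path=None):
    """Return argv for examples/gguf_to_torch.py (hypothetical wrapper)."""
    if unquantized_path is None and architecture != "llama":
        # Without an unquantized copy, config and tokenizer can only be
        # extracted from the GGUF for llama, so fail before converting.
        raise ValueError(f"{architecture}: pass an unquantized path; "
                         "only llama can be converted without one")
    cmd = ["python", "examples/gguf_to_torch.py",
           "--input", str(gguf_input), "--output", str(output_dir)]
    if unquantized_path is not None:
        # Config and tokenizer are then taken from the original model.
        cmd += ["--unquantized-path", str(unquantized_path)]
    return cmd


print(shlex.join(gguf_to_torch_cmd("model-q4.gguf", "out/", "command-r", "orig/")))
# python examples/gguf_to_torch.py --input model-q4.gguf --output out/ --unquantized-path orig/
```

Failing before the script starts keeps a non-Llama conversion from stopping partway with the `Unsupported architecture` error shown above.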

On the other hand, I don't think phi3 is supported in Aphrodite.
