[Bug] Error occurs when loading the gemma model in bitsandbytes format. #2556

@upskyy

Description

Checklist

  • 1. I have searched related issues but cannot get the expected help.
  • 2. The bug has not been fixed in the latest version.
  • 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
  • 4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/sgl-project/sglang/discussions/new/choose. Otherwise, it will be closed.
  • 5. Please use English; otherwise, the issue will be closed.

Describe the bug

An error occurs when loading the Gemma model with the command below.

python3 -m sglang.launch_server --model-path /models --tokenizer-path /models --port 30000 --tokenizer-mode auto --dtype bfloat16 --mem-fraction-static 0.5 --random-seed 0 --enable-torch-compile --disable-cuda-graph --schedule-conservativeness 1.3 --kv-cache-dtype fp8_e5m2 --load-format bitsandbytes --quantization bitsandbytes

First of all, bitsandbytes_stacked_params_mapping is not defined on the Gemma model, so I added it myself. (https://github.com/sgl-project/sglang/blob/v0.4.0.post1/python/sglang/srt/model_loader/loader.py#L908-L912)
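
For reference, here is a minimal sketch of the mapping I added, following the shard-name -> (stacked parameter name, shard index) convention that other models in sglang use; the exact entries are my assumption based on Gemma's fused QKV and gate/up projections.

# Sketch of the class attribute added to the Gemma model (my own
# reading of its fused projections, not necessarily the upstream fix).
bitsandbytes_stacked_params_mapping = {
    # shard name: (stacked parameter name, shard index)
    "q_proj": ("qkv_proj", 0),
    "k_proj": ("qkv_proj", 1),
    "v_proj": ("qkv_proj", 2),
    "gate_proj": ("gate_up_proj", 0),
    "up_proj": ("gate_up_proj", 1),
}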

When loading 4-bit weights with bitsandbytes, a KeyError occurs because the loader rewrites each parameter name from weight to qweight in the code below. Are there any cases where weight should actually be changed to qweight?
weight_name = weight_name.replace(".weight", ".qweight")

https://github.com/sgl-project/sglang/blob/v0.4.0.post1/python/sglang/srt/model_loader/loader.py#L839-L844
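
As a hypothetical workaround (my own sketch, not a verified fix), load_weights in gemma2.py could fall back to the original parameter name when the renamed one does not exist:

# Hypothetical fallback in gemma2.py's load_weights: if the ".qweight"
# name produced by the bitsandbytes loader is not a known parameter,
# fall back to the original ".weight" name instead of raising KeyError.
if name not in params_dict and name.endswith(".qweight"):
    name = name.replace(".qweight", ".weight")
param = params_dict[name]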

Reproduction

Run the launch command from the description; the server crashes while loading the checkpoint with the traceback below.

Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
[2024-12-23 00:00:07 TP0] Scheduler hit an exception: Traceback (most recent call last):
  File "/sglang/python/sglang/srt/managers/scheduler.py", line 1528, in run_scheduler_process
    scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
  File "/sglang/python/sglang/srt/managers/scheduler.py", line 192, in __init__
    self.tp_worker = TpWorkerClass(
  File "/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 62, in __init__
    self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
  File "/sglang/python/sglang/srt/managers/tp_worker.py", line 62, in __init__
    self.model_runner = ModelRunner(
  File "/sglang/python/sglang/srt/model_executor/model_runner.py", line 158, in __init__
    self.load_model()
  File "/sglang/python/sglang/srt/model_executor/model_runner.py", line 258, in load_model
    self.model = get_model(
  File "/sglang/python/sglang/srt/model_loader/__init__.py", line 22, in get_model
    return loader.load_model(
  File "/sglang/python/sglang/srt/model_loader/loader.py", line 1029, in load_model
    self._load_weights(model_config, model)
  File "/sglang/python/sglang/srt/model_loader/loader.py", line 960, in _load_weights
    model.load_weights(qweight_iterator)
  File "/sglang/python/sglang/srt/models/gemma2.py", line 445, in load_weights
    param = params_dict[name]
KeyError: 'model.layers.43.mlp.down_proj.qweight'

Environment

root@33e74a81f115:/sglang/python# python3 -m sglang.check_env                                                                                                                                                         

Python: 3.10.16 (main, Dec  4 2024, 08:53:37) [GCC 9.4.0]
CUDA available: True
GPU 0: NVIDIA A100-SXM4-80GB
GPU 0 Compute Capability: 8.0
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 12.4, V12.4.131
CUDA Driver Version: 550.127.05
PyTorch: 2.5.1+cu124
sglang: 0.4.0.post2
flashinfer: 0.1.6+cu124torch2.4
triton: 3.1.0
transformers: 4.47.0
torchao: 0.6.1
numpy: 1.26.4
aiohttp: 3.11.10
fastapi: 0.115.6
hf_transfer: 0.1.8
huggingface_hub: 0.26.3
interegular: 0.3.3
modelscope: 1.21.0
orjson: 3.10.12
packaging: 24.2
psutil: 6.1.0
pydantic: 2.10.3
multipart: 0.0.19
zmq: 26.2.0
uvicorn: 0.32.1
uvloop: 0.21.0
vllm: 0.6.4.post1
openai: 1.57.0
anthropic: 0.40.0
decord: 0.6.0
