
[Bug]: AssertionError: skip_special_tokens=False is not supported for Mistral tokenizers #16958

@chaunceyjiang

Your current environment

(The output of `python collect_env.py` was not provided.)

🐛 Describe the bug

vllm serve stelterlab/Mistral-Small-24B-Instruct-2501-AWQ --tool-call-parser mistral   --enable-auto-tool-choice  --tokenizer-mode mistral --guided-decoding-backend xgrammar

from openai import OpenAI

openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"

client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": "The city to find the weather for, e.g. 'Vienna'",
                        "default": "Vienna",
                    },
                    "country": {
                        "type": "string",
                        "description": "The country that the city is in, e.g. 'Austria'",
                    },
                    "unit": {
                        "type": "string",
                        "description": "The unit to fetch the temperature in",
                        "enum": ["celsius", "fahrenheit"],
                    },
                },
                "required": ["country", "unit"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "get_forecast",
            "description": "Get the weather forecast for a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": "The city to get the forecast for, e.g. 'Vienna'",
                        "default": "Vienna",
                    },
                    "country": {
                        "type": "string",
                        "description": "The country that the city is in, e.g. 'Austria'",
                    },
                    "days": {
                        "type": "integer",
                        "description": "Number of days to get the forecast for (1-7)",
                    },
                    "unit": {
                        "type": "string",
                        "description": "The unit to fetch the temperature in",
                        "enum": ["celsius", "fahrenheit"],
                    },
                },
                "required": ["country", "days", "unit"],
            },
        },
    },
]

messages = [
    {
        "role": "user",
        "content": "Hi! How are you doing today?",
    },
    {
        "role": "assistant",
        "content": "I'm doing well! How can I help you?",
    },
    {
        "role": "user",
        "content": "Can you tell me what the current weather is in Berlin "
        "and the forecast for the next 5 days, in fahrenheit?",
    },
]

# Non-streaming test
chat_completion = client.chat.completions.create(
    messages=messages,
    model="stelterlab/Mistral-Small-24B-Instruct-2501-AWQ",  # model name passed to vllm serve
    tools=tools,
    # tool_choice="required",
    tool_choice="auto",
)
print("Chat completion response:")
print(f"Chat completion: {chat_completion}")
for choice in chat_completion.choices:
    if choice.message.tool_calls:
        print(f"Tool calls: {choice.message.tool_calls}")
    else:
        print("No tool calls found.")
assert chat_completion.choices[0].message.tool_calls is not None
assert len(chat_completion.choices[0].message.tool_calls) > 0
Running the script above fails with the following traceback:

Traceback (most recent call last):
  File "/root/vllm/vllm/v1/engine/async_llm.py", line 277, in generate
    q = await self.add_request(
        ^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/vllm/vllm/v1/engine/async_llm.py", line 215, in add_request
    await self._add_request(request, None, 0, queue)
  File "/root/vllm/vllm/v1/engine/async_llm.py", line 233, in _add_request
    self.output_processor.add_request(request, parent_req, index, queue)
  File "/root/vllm/vllm/v1/engine/output_processor.py", line 281, in add_request
    req_state = RequestState.from_new_request(
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/vllm/vllm/v1/engine/output_processor.py", line 135, in from_new_request
    detokenizer=IncrementalDetokenizer.from_new_request(
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/vllm/vllm/v1/engine/detokenizer.py", line 51, in from_new_request
    return SlowIncrementalDetokenizer(tokenizer, request)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/vllm/vllm/v1/engine/detokenizer.py", line 219, in __init__
    convert_prompt_ids_to_tokens(
  File "/root/vllm/vllm/transformers_utils/detokenizer_utils.py", line 66, in convert_prompt_ids_to_tokens
    new_tokens = tokenizer.convert_ids_to_tokens(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/vllm/vllm/transformers_utils/tokenizers/mistral.py", line 461, in convert_ids_to_tokens
    skip_special_tokens
AssertionError: skip_special_tokens=False is not supported for Mistral tokenizers.


I'm using the master branch of the codebase, and I noticed that PR #15137 introduced an extra adjustment for Mistral models that sets request.skip_special_tokens = False:

def adjust_request(
        self, request: ChatCompletionRequest) -> ChatCompletionRequest:
    if request.tools and request.tool_choice != 'none':
        # do not skip special tokens because mistral uses the special
        # tokens to indicate the start and end of the tool calls
        # information.
        request.skip_special_tokens = False
    return request
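
For context, keeping the special tokens matters because Mistral wraps tool calls in them; the raw detokenized output is expected to look roughly like this (an illustration of the format, not captured output):

[TOOL_CALLS][{"name": "get_current_weather", "arguments": {"city": "Berlin", "country": "Germany", "unit": "fahrenheit"}}]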

However, MistralTokenizer.convert_ids_to_tokens only supports the case where skip_special_tokens is True:

assert (
    skip_special_tokens
), "skip_special_tokens=False is not supported for Mistral tokenizers."

Additionally, I noticed that PR #14094 already restricted users from setting skip_special_tokens=False for Mistral tokenizers, yet adjust_request now forces skip_special_tokens=False by default for tool calls, so the two changes conflict and tool-call requests with --tokenizer-mode mistral hit this assertion.
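
The conflict is reproducible without the server. A minimal sketch (assuming MistralTokenizer.from_pretrained resolves the tokenizer the same way the server does; the token ids are arbitrary, since the assert fires before any lookup):

from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer

tok = MistralTokenizer.from_pretrained(
    "stelterlab/Mistral-Small-24B-Instruct-2501-AWQ")

# The V1 detokenizer ends up calling this with skip_special_tokens=False
# for tool-call requests, which trips the assert shown above:
tok.convert_ids_to_tokens([1, 2, 3], skip_special_tokens=False)
# AssertionError: skip_special_tokens=False is not supported for Mistral tokenizers.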

