Conversation

@desaxce commented Feb 15, 2025

Fixes #36210.

Changes:

In the heal_tokens function, we now:

  • allow tokenizers that don't have a bos_token_id
  • skip the whitespace replacement when the tokenizer doesn't define the space character as a token

This just keeps the code from raising exceptions; there may be something more clever to do (see the sketch below).
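A rough sketch of the two guards in isolation; the helper names are invented for illustration, and this is not the actual diff applied to heal_tokens:

def safe_bos_token_id(tokenizer):
    # Sketch only: tokenizers without a bos_token_id (e.g. Qwen) simply yield None
    # here instead of letting downstream code fail on the missing id.
    return getattr(tokenizer, "bos_token_id", None)

def safe_space_token(tokenizer):
    # Sketch only: if the vocabulary has no token for a bare " ", return None so the
    # caller can skip the whitespace replacement instead of raising.
    space_id = tokenizer.convert_tokens_to_ids(" ")
    if space_id is None or space_id == getattr(tokenizer, "unk_token_id", None):
        return None
    return tokenizer.convert_ids_to_tokens(space_id)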

To test:

Use token healing on https://huggingface.co/Qwen/Qwen2.5-Coder-7B-Instruct:

from transformers import Qwen2ForCausalLM, Qwen2Tokenizer

model_id = "Qwen/Qwen2.5-Coder-7B-Instruct"
model = Qwen2ForCausalLM.from_pretrained(model_id)
tokenizer = Qwen2Tokenizer.from_pretrained(model_id)

# The prompt ends with "import " so the last token is a candidate for healing
prompt = "Complete the following Lean 4 code:\n\n```lean4\nimport "
inputs = tokenizer(prompt, return_tensors="pt")

# Before the fix this raised, because the tokenizer defines no bos_token_id
# and has no bare " " token in its vocabulary
generate_ids = model.generate(inputs.input_ids, tokenizer=tokenizer, max_new_tokens=1, token_healing=True)
tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]

@Rocketknight1 (Member)

cc @ArthurZucker @itazap for tokenizers!

@Rocketknight1 (Member)

gentle ping @ArthurZucker @itazap

@itazap (Collaborator) commented Apr 3, 2025

Thanks for the catch and fix! 🤗 I took a look at the original motivation for 'token healing' (#30081) and found that there might be a different root cause of the error. For models like Qwen, the space token is actually "Ġ" instead of " ", which could be handled with:

space_tok = tokenizer.convert_ids_to_tokens(tokenizer.convert_tokens_to_ids(tokenizer.tokenize(" ")))[0]

vs. the current:

space_tok = tokenizer.convert_ids_to_tokens(tokenizer.convert_tokens_to_ids(" "))[0]

cc @ArthurZucker lmk what you think 😊
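For reference, a quick check along these lines shows the difference on the Qwen tokenizer (it is downloaded from the hub; the exact ids printed depend on the checkpoint):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-7B-Instruct")

# A bare " " is not in the byte-level BPE vocabulary, so the direct lookup
# does not yield a usable space token.
print(tok.convert_tokens_to_ids(" "))

# Tokenizing the space first produces the byte-level space token "Ġ",
# which does have an id.
print(tok.tokenize(" "))
print(tok.convert_tokens_to_ids(tok.tokenize(" ")))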

desaxce and others added 2 commits April 6, 2025 18:14
@itazap (Collaborator) commented Apr 15, 2025

cc @gante for generation

@gante (Member) commented Apr 18, 2025

+1 to Ita's comment -- tokenizing the space to get the space token seems more intelligent and robust

The other change, regarding bos_token_id, looks good to me :)

# tail tokens are used for a prefix search, thus, whitespaces are replaced with
# their tokenization (e.g. 'Ġ') to enable search for tokens prefixed with a whitespace
tail_toks = (tokenizer.decode(t).replace(" ", space_tok) for t in tail_ids)
space_tok_id = tokenizer.convert_tokens_to_ids(" ")

Suggested change:
- space_tok_id = tokenizer.convert_tokens_to_ids(" ")
+ space_tok_id = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(" "))
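As a side note, the two calls also differ in return type (assuming the same Qwen tokenizer loaded as tokenizer), which the surrounding code would need to account for:

tokenizer.convert_tokens_to_ids(" ")                      # single id or None; a bare " " is not in Qwen's vocab
tokenizer.convert_tokens_to_ids(tokenizer.tokenize(" "))  # list containing the id of the byte-level space token "Ġ"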

@itazap (Collaborator) commented Apr 21, 2025

Great! One last thing: it would be great to add a test for this, making sure that we get some output instead of an error. Let me know if you'd like to add it, or I can help 🤗 @desaxce
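For reference, a minimal sketch of what such a test could look like; the test name, the choice of the full-size Qwen checkpoint, and the placement in the suite are assumptions rather than part of this PR:

from transformers import AutoModelForCausalLM, AutoTokenizer

def test_token_healing_without_bos_or_space_token():
    # Qwen tokenizers have neither a bos_token_id nor a bare " " token; in CI a
    # smaller checkpoint would likely be preferable to this 7B model.
    model_id = "Qwen/Qwen2.5-Coder-7B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)

    inputs = tokenizer("import ", return_tensors="pt")
    # The point of the test: generation with token healing should return output
    # instead of raising.
    out = model.generate(inputs.input_ids, tokenizer=tokenizer, max_new_tokens=1, token_healing=True)
    assert out.shape[-1] >= inputs.input_ids.shape[-1]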

@desaxce (Author) commented Apr 21, 2025

> Great! One last thing: it would be great to add a test for this, making sure that we get some output instead of an error. Let me know if you'd like to add it, or I can help 🤗 @desaxce

Will take care of it before end of week.

@itazap (Collaborator) commented Apr 22, 2025

Appreciate it! Thank you! 😊

@desaxce (Author) commented Apr 27, 2025

@itazap Sorry for the delay, I had a difficult end of the week and was kept busy over the weekend.
Let's plan for a Wednesday finish line so we start May ready to merge :)

@itazap (Collaborator) commented Apr 28, 2025

@desaxce no worries, thanks for the update!
