
Conversation

keyboardAnt
Contributor

@keyboardAnt keyboardAnt commented Nov 30, 2024

📄 ICML oral (top 1%): Accelerating LLM Inference with Lossless Speculative Decoding Algorithms for Heterogeneous Vocabularies, https://arxiv.org/abs/2502.05202 - Nadav Timor, Jonathan Mamou, Daniel Korat, Moshe Berchansky, Oren Pereg, Gaurav Jain, Moshe Wasserblat, David Harel


This PR is a collaborative effort with @jmamou and @gauravjain14. It supersedes #34760 and builds upon #35009.


This PR is open for initial review, although some areas are still under development.

What does this PR do?

This PR introduces the UniversalSpeculativeDecodingGenerator class, enabling speculative decoding with assistant models whose tokenizers differ slightly from the target's. The key addition is two logits processors (LogitsProcessor) that ensure the assistant generates tokens exclusively from the target vocabulary, maintaining alignment and preserving the target distribution without altering the verification method. In theory, the approach is agnostic to the do_sample choice. This avoids issues like #32867 and #33534 and sets the stage for advanced universal speculative decoding techniques that we are currently researching and have not yet published.
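For intuition, here is a minimal, simplified sketch of the kind of logits processing involved: a processor that masks out assistant tokens whose string form does not exist in the target vocabulary. The class name and the mapping construction are illustrative only, not the PR's actual implementation:

```python
import torch
from transformers import LogitsProcessor


class SuppressTokensNotInTargetVocab(LogitsProcessor):
    """Illustrative sketch: force the assistant to only propose tokens
    whose string form also exists in the target tokenizer's vocabulary."""

    def __init__(self, assistant_tokenizer, target_tokenizer):
        target_vocab = set(target_tokenizer.get_vocab().keys())
        # Assistant token ids whose string is NOT in the target vocab get masked out.
        self.suppress_ids = torch.tensor(
            [
                tok_id
                for tok, tok_id in assistant_tokenizer.get_vocab().items()
                if tok not in target_vocab
            ],
            dtype=torch.long,
        )

    def __call__(self, input_ids, scores):
        scores = scores.clone()
        scores[:, self.suppress_ids.to(scores.device)] = float("-inf")
        return scores
```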


Motivation and Context

This update resolves prior inconsistencies in speculative decoding caused by misaligned vocabularies. Key benefits include:

  • Ensuring the assistant generates only tokens present in the target vocabulary.
  • Lossless preservation of the target distribution.
  • Compatibility with future speculative decoding advancements.

This PR is a step toward advancements in Universal Assisted Generation, in collaboration with @danielkorat, @orenpereg, @mosheber, @jmamou, @gante, @lewtun, and @MosheWasserb.


Related

Issues:

PRs:


Dependencies


Before Submitting Checklist


Who can review?

@gante / @ArthurZucker / @zucchini-nlp

@gauravjain14
Contributor

Hi @keyboardAnt - I checked out your changes and ran the following evaluations to measure the speed-up from universal assisted generation.

Here is a summary:

Dataset used: tau/scrolls (qasper)
Number of samples evaluated per run: 20
Avg. speedup observed = mean([baseline_time[0]/assisted_time[0], baseline_time[1]/assisted_time[1], ..., baseline_time[N-1]/assisted_time[N-1]])
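For reference, a minimal sketch of how this metric could be computed; generate_baseline and generate_assisted stand in for the benchmark script's own generation functions and are not part of this PR:

```python
import time
from statistics import mean


def timed(generate_fn, prompt):
    """Return (generated_text, elapsed_seconds) for a single generation call."""
    start = time.perf_counter()
    text = generate_fn(prompt)
    return text, time.perf_counter() - start


def average_speedup(prompts, generate_baseline, generate_assisted):
    """mean(baseline_time[i] / assisted_time[i]) over the evaluated samples."""
    ratios = []
    for prompt in prompts:
        _, baseline_time = timed(generate_baseline, prompt)
        _, assisted_time = timed(generate_assisted, prompt)
        ratios.append(baseline_time / assisted_time)
    return mean(ratios)
```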

Case 1:
max_new_tokens=100
Avg Speed up observed = 1.13x

Case 2:
max_new_tokens=256
Avg Speed up observed = 1.46x

Case 3:
max_new_tokens=512
Avg Speed up observed = 1.72x

A couple of things remain to be checked:

  1. When does the speedup start saturating? I don't expect we will see speedup grow endlessly as more tokens are generated by the model(s).

  2. Does assisted generation affect accuracy relative to the baseline (no-assistant) generation mode?

@gauravjain14
Copy link
Contributor

Running more evaluations with USD, I am seeing them fail with the following errors:

../aten/src/ATen/native/cuda/Indexing.cu:1236: indexSelectSmallIndex: block: [14,0,0], thread: [0,0,0] Assertion `srcIndex < srcSelectDimSize` failed.                                                    
../aten/src/ATen/native/cuda/Indexing.cu:1236: indexSelectSmallIndex: block: [14,0,0], thread: [1,0,0] Assertion `srcIndex < srcSelectDimSize` failed.                                                    
../aten/src/ATen/native/cuda/Indexing.cu:1236: indexSelectSmallIndex: block: [14,0,0], thread: [2,0,0] Assertion `srcIndex < srcSelectDimSize` failed.                                                    
../aten/src/ATen/native/cuda/Indexing.cu:1236: indexSelectSmallIndex: block: [14,0,0], thread: [3,0,0] Assertion `srcIndex < srcSelectDimSize` failed.                                                    
../aten/src/ATen/native/cuda/Indexing.cu:1236: indexSelectSmallIndex: block: [14,0,0], thread: [4,0,0] Assertion `srcIndex < srcSelectDimSize` failed.                                                    
../aten/src/ATen/native/cuda/Indexing.cu:1236: indexSelectSmallIndex: block: [14,0,0], thread: [5,0,0] Assertion `srcIndex < srcSelectDimSize` failed.

File "/disk/universal_assisted_generation/perf_comparison_llama_qwen.py", line 70, in <module>
    assisted_text, assisted_time = generate_assisted(
                                   ^^^^^^^^^^^^^^^^^^
  File "/disk/universal_assisted_generation/perf_comparison_llama_qwen.py", line 51, in generate_assisted
    outputs = target_model.generate(
              ^^^^^^^^^^^^^^^^^^^^^^
  File "/disk/anaconda3/envs/uag/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/disk/universal_assisted_generation/transformers/src/transformers/generation/utils.py", line 2213, in generate
    result = self._assisted_decoding(
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/disk/universal_assisted_generation/transformers/src/transformers/generation/utils.py", line 4318, in _assisted_decoding
    outputs = self(**model_inputs)
              ^^^^^^^^^^^^^^^^^^^^
  File "/disk/anaconda3/envs/uag/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/disk/anaconda3/envs/uag/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/disk/anaconda3/envs/uag/lib/python3.12/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/disk/universal_assisted_generation/transformers/src/transformers/models/llama/modeling_llama.py", line 1163, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/disk/envs/uag/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/disk/anaconda3/envs/uag/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/disk/universal_assisted_generation/transformers/src/transformers/models/llama/modeling_llama.py", line 883, in forward
    causal_mask = self._update_causal_mask(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/disk/universal_assisted_generation/transformers/src/transformers/models/llama/modeling_llama.py", line 973, in _update_causal_mask
    if AttentionMaskConverter._ignore_causal_mask_sdpa(
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/disk/universal_assisted_generation/transformers/src/transformers/modeling_attn_mask_utils.py", line 284, in _ignore_causal_mask_sdpa
    elif not is_tracing and torch.all(attention_mask == 1):
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

This only occurs when the assistant model and the target model have different tokenizers, at least as far as I have consistently observed so far. Any idea what could be causing this?

I am running these evaluations on a 4xT4 system with about 64 GB of memory. The models I am using are:

target_checkpoint = "meta-llama/Llama-3.1-8B-Instruct"
assistant_checkpoint = "Qwen/Qwen2.5-0.5B-Instruct"
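For context, a minimal sketch of roughly how this setup is invoked; the actual perf_comparison_llama_qwen.py script in the traceback may differ, and the prompt and generation parameters here are placeholders:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

target_checkpoint = "meta-llama/Llama-3.1-8B-Instruct"
assistant_checkpoint = "Qwen/Qwen2.5-0.5B-Instruct"

target_tokenizer = AutoTokenizer.from_pretrained(target_checkpoint)
assistant_tokenizer = AutoTokenizer.from_pretrained(assistant_checkpoint)

target_model = AutoModelForCausalLM.from_pretrained(
    target_checkpoint, torch_dtype=torch.float16, device_map="auto"
)
assistant_model = AutoModelForCausalLM.from_pretrained(
    assistant_checkpoint, torch_dtype=torch.float16, device_map="auto"
)

inputs = target_tokenizer("Summarize the article below.", return_tensors="pt").to(target_model.device)

# Universal assisted generation: both tokenizers are passed so that draft
# tokens can be translated between the two vocabularies.
outputs = target_model.generate(
    **inputs,
    assistant_model=assistant_model,
    tokenizer=target_tokenizer,
    assistant_tokenizer=assistant_tokenizer,
    max_new_tokens=256,
)
print(target_tokenizer.decode(outputs[0], skip_special_tokens=True))
```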

@jmamou
Contributor

jmamou commented Dec 2, 2024


@gauravjain14
Thanks for sharing your experiments! I am currently running benchmarks and will hopefully share them soon. Regarding your two questions:

  1. Note that we can reach the EOS token before reaching max_new_tokens, in which case generation stops at EOS. This often happens in summarization tasks.
  2. No, accuracy is not affected; USD is lossless.

@jmamou
Contributor

jmamou commented Dec 2, 2024


@gauravjain14
Which target/draft models did you use for these benchmark runs?

@jmamou
Contributor

jmamou commented Dec 2, 2024


@gauravjain14 the CUDA device-side assert you reported above seems to be related to #22546 (comment)

@jmamou
Contributor

jmamou commented Dec 2, 2024

We have run USD with target='meta-llama/Llama-3.1-70B' and draft='Qwen/Qwen2-0.5B-Instruct' on SCROLLS on 2 A100 GPUs and got a speedup of 2.65x.
As reported in #34760 (comment), the overlap of the draft vocabulary with respect to the target vocabulary is 85%.
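For illustration, one straightforward way to compute such an overlap (the exact normalization used in #34760 may differ):

```python
from transformers import AutoTokenizer

target_tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-70B")
draft_tok = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

target_vocab = set(target_tok.get_vocab())
draft_vocab = set(draft_tok.get_vocab())

# Fraction of draft tokens whose string form also appears in the target vocabulary.
overlap = len(draft_vocab & target_vocab) / len(draft_vocab)
print(f"Draft-to-target vocab overlap: {overlap:.1%}")
```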

@keyboardAnt
Contributor Author

@jmamou - I’m replying to your question here.

Our evaluation so far has focused solely on single-threaded inference, but I don’t see a strong reason to restrict the code to single-threading. The thread-safe lock implementation helps prevent race conditions during multi-threaded execution and is considered standard.

@gante, do you happen to know if there are any multithreading use cases today? I know there was no multiprocessing support as of August 2024 (#32864). If there aren’t any actual use cases, do you think removing this thread-safe locking functionality would make sense?

@gante
Member

gante commented Feb 26, 2025

@keyboardAnt In general, we avoid threading in the core library whenever possible -- transformers is used in many places, and not all of them are thread-safe. Here's an example of a threading issue caused by transformers: gradio-app/gradio#4016

Not to say that these issues are not fixable, but since we lack the capacity to handle existing issues, we're trying to prevent code changes that are likely to cause issues in the future 🤗

@jmamou
Contributor

jmamou commented Feb 26, 2025


@gante
As you suggested, we avoided the need for threading by building the map in generate and passing it to the translator candidate generator (a sketch of the idea is below).
We have addressed all the comments.
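For illustration only, a sketch of the kind of assistant-to-target token-id map this refers to; the helper name and exact construction are hypothetical, not the PR's actual internals:

```python
def build_assistant_to_target_map(assistant_tokenizer, target_tokenizer):
    """Map each assistant token id to the id of the identical token string
    in the target vocabulary (only for tokens present in both vocabularies)."""
    target_vocab = target_tokenizer.get_vocab()
    return {
        assistant_id: target_vocab[token]
        for token, assistant_id in assistant_tokenizer.get_vocab().items()
        if token in target_vocab
    }
```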

Member

@gante gante left a comment


Looks good to me :)

Two minor nits and I'm happy to approve and merge:

  1. See comment below
  2. Missing an integration test in tests/generation/test_candidate_generator: with the temperature set to nearly 0, USD should match vanilla sampling (make sure to use small models; it can even be a pair of dummy models like hf-internal-testing/tiny-random-gpt2 and hf-internal-testing/tiny-random-MistralForCausalLM, since the actual generations don't matter). See the sketch below.
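A rough sketch of what such a test could look like, assuming universal assisted generation is invoked by passing both tokenizers to generate; this is not the test that was eventually merged:

```python
import unittest

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class UsdMatchesVanillaSamplingTest(unittest.TestCase):
    def test_usd_matches_vanilla_sampling_at_near_zero_temperature(self):
        target_ckpt = "hf-internal-testing/tiny-random-gpt2"
        assistant_ckpt = "hf-internal-testing/tiny-random-MistralForCausalLM"

        target_tokenizer = AutoTokenizer.from_pretrained(target_ckpt)
        assistant_tokenizer = AutoTokenizer.from_pretrained(assistant_ckpt)
        target_model = AutoModelForCausalLM.from_pretrained(target_ckpt)
        assistant_model = AutoModelForCausalLM.from_pretrained(assistant_ckpt)

        inputs = target_tokenizer("Hello world", return_tensors="pt")
        # Near-zero temperature makes sampling effectively greedy, so the
        # lossless USD output should match the unassisted output.
        gen_kwargs = dict(do_sample=True, temperature=1e-4, max_new_tokens=5)

        torch.manual_seed(0)
        baseline = target_model.generate(**inputs, **gen_kwargs)

        torch.manual_seed(0)
        assisted = target_model.generate(
            **inputs,
            assistant_model=assistant_model,
            tokenizer=target_tokenizer,
            assistant_tokenizer=assistant_tokenizer,
            **gen_kwargs,
        )

        self.assertTrue(torch.equal(baseline, assisted))


if __name__ == "__main__":
    unittest.main()
```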

Member

@gante gante left a comment


LGTM, let's make CI green and merge 🤗

@jmamou
Contributor

jmamou commented Feb 26, 2025


@gante
CircleCI tests are green 😄

@gante gante merged commit d18d9c3 into huggingface:main Feb 26, 2025
21 checks passed
@jmamou jmamou mentioned this pull request Mar 25, 2025