Speculative sampling does not maintain probability distribution of main model #32867

@dmelcer9

Description

System Info

  • transformers version: 4.44.0
  • Platform: macOS-10.16-x86_64-i386-64bit
  • Python version: 3.10.13
  • Huggingface_hub version: 0.24.5
  • Safetensors version: 0.4.2
  • Accelerate version: 0.27.2
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.1.2 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

In the speculative sampling procedure:

probability_ratio = p_i / q_i

The probability ratio compares the main model's probability for the drafted token (p_i) against the probability that the assistant model's output distribution assigns to it (q_i).
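
For reference, here is a minimal sketch of the acceptance step as described in the speculative sampling paper; the function and variable names are illustrative, not the actual transformers internals:

import torch

def speculative_accept(p, q, draft_token, generator=None):
    # p, q: 1-D tensors holding the main-model and assistant-model
    # probabilities over the vocabulary at this position;
    # draft_token: the token the assistant proposed there.
    ratio = (p[draft_token] / q[draft_token]).item()
    if torch.rand(1, generator=generator).item() < ratio:
        return int(draft_token)  # accept the drafted token
    # On rejection, resample from the normalized residual max(p - q, 0);
    # this is what makes the overall output distribution equal to p.
    residual = torch.clamp(p - q, min=0.0)
    residual = residual / residual.sum()
    return int(torch.multinomial(residual, 1, generator=generator).item())

This guarantee only holds if q is the distribution the draft token was actually drawn from.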

However, the assistant model is always run greedily:

self.generation_config.do_sample = False

Greedy decoding is equivalent to sampling at temperature zero, so the distribution from which the draft token is actually drawn is one-hot: the assistant's output probability q_i for the selected token should be 1, not its softmax probability.

As a more concrete example, if the assistant model's softmax output is [0.51, 0.49], then whenever the main model outputs [x >= 0.51, y <= 0.49], the ratio is at least 1 and the first token is always accepted, even though a correct sampler would still emit the second token with probability y.
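
Working those numbers through the acceptance check (all values hypothetical, just to make the bias concrete):

p = [0.60, 0.40]  # hypothetical main-model probabilities
q = [0.51, 0.49]  # assistant softmax; under greedy decoding the draft is argmax(q) = token 0

# Ratio the check currently uses: p[0] / q[0] ~= 1.18, so min(1, ratio) = 1
# and token 0 is accepted every time; token 1 can never be produced here.
accept_with_softmax_q = min(1.0, p[0] / q[0])  # 1.0

# If q reflected the distribution the draft was actually drawn from
# (one-hot under greedy decoding), the ratio would be p[0] / 1.0 = 0.6;
# rejections would then resample from the residual, recovering p overall.
accept_with_onehot_q = min(1.0, p[0] / 1.0)  # 0.6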

This is evident when you use a model as its own assistant, at least for the first 5 tokens drafted by the assistant (there is still some randomness from the extra token that the main model generates itself).

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "openai-community/gpt2-medium"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

inputs = tokenizer("public int", return_tensors="pt")

# Greedy
# Always outputs `public int get_current_time()` (and then some)
print(tokenizer.decode(model.generate(**inputs, do_sample=False, max_new_tokens=25)[0]))

# Sampling
# Gives different method names each time
print(tokenizer.decode(model.generate(**inputs, do_sample=True, max_new_tokens=25)[0]))

# Should theoretically be sampling but is not
# Always outputs `public int get_current_time()`
print(tokenizer.decode(model.generate(**inputs, assistant_model=model, do_sample=True, max_new_tokens=25)[0]))
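
As a quick sanity check, this sketch reuses model, tokenizer and inputs from the script above and compares only the first 5 new tokens across runs (the window corresponding to the tokens drafted by the assistant, per the description above):

# With do_sample=True the drafted tokens should vary between runs, but the
# first 5 new tokens come back identical every time when the model is its
# own assistant.
prompt_len = inputs["input_ids"].shape[1]
first_drafted = set()
for _ in range(10):
    ids = model.generate(**inputs, assistant_model=model, do_sample=True, max_new_tokens=25)[0]
    first_drafted.add(tuple(ids[prompt_len:prompt_len + 5].tolist()))

print(len(first_drafted))  # > 1 expected under correct sampling; 1 per the behaviour reported above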

Expected behavior

With do_sample=True, assisted decoding should sample correctly: its outputs should follow the same distribution as ordinary sampling from the main model.
