
Support for caching prompt hidden states through multiple calls of generate() #24841

@offendo


Feature request

Hi there,

I'd like to be able to re-use the hidden states for a common (potentially long) prompt across multiple calls to model.generate(), in order to reduce redundant computation. Here is how I envision the final API, though I'm sure there are multiple ways to do it.

# Load model and tokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained('huggyllama/llama-7b')
tokenizer = AutoTokenizer.from_pretrained('huggyllama/llama-7b')

# Common prompt that we'd prepend to every example
prompt = "This is a common prompt in every example."
prompt_ids = tokenizer(prompt, return_tensors='pt').input_ids

# Examples to pass to generate
examples = ["Ackbar went to", "Billaba enjoys", "Cody is eating some"]

# Generation loop
outputs = []
prompt_hidden_state = None
for ex in examples:
    # Current way of doing things: re-encode the full prompt + example every time
    out = model.generate(
        **tokenizer(prompt + ex, return_tensors='pt'),
    )

    # Proposed method to re-use prompt_hidden_state across calls
    out = model.generate(
        **tokenizer(ex, return_tensors='pt'),
        common_prompt_ids=prompt_ids,
        prompt_hidden_state=prompt_hidden_state,
        return_dict_in_generate=True,
    )
    prompt_hidden_state = out.prompt_hidden_state
    outputs.append(out.sequences)

Thanks in advance.

Motivation

A very common pattern in LLM usage is a shared prompt (e.g., instructions plus input/output pairs) followed by a sample input, where the model is asked to generate the sample output. For example:

You are a programmer's assistant which converts English descriptions to Python functions.

English: <example 1 description>
Python: <example 1 function>

English: <example 2 description>
Python: <example 2 function>

English: <example 3 description>
Python: <example 3 function>

English: <input description>
Python: 

I'd like to be able to cache the common part of the prompt (that is, everything before <input description>, which is identical in every example) to avoid potentially expensive re-computation across inputs.
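
To make this concrete, below is a rough, untested sketch of the kind of prefix re-use I have in mind, written with plain forward passes and manual greedy decoding rather than generate(). It assumes batch size 1 (so the padding problem discussed further down doesn't come up yet) and the legacy tuple cache format, which is not modified in place when re-used; the prompt strings and model name are just the examples from above.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b")
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
model.eval()

prompt = "This is a common prompt in every example."
prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids

# Run the shared prefix once and keep its key/value cache.
with torch.no_grad():
    prefix_cache = model(prompt_ids, use_cache=True).past_key_values

for ex in ["Ackbar went to", "Billaba enjoys", "Cody is eating some"]:
    # Tokenize only the continuation; the cached prefix already contains the BOS token.
    ex_ids = tokenizer(ex, return_tensors="pt", add_special_tokens=False).input_ids
    with torch.no_grad():
        # Feed the example tokens against the cached prefix, then greedy-decode.
        # NOTE: this assumes prefix_cache is not mutated in place, which holds for
        # the legacy tuple format but not necessarily for newer cache objects.
        out = model(ex_ids, past_key_values=prefix_cache, use_cache=True)
        next_token = out.logits[:, -1:].argmax(dim=-1)
        generated = torch.cat([ex_ids, next_token], dim=-1)
        past = out.past_key_values
        for _ in range(20):
            out = model(next_token, past_key_values=past, use_cache=True)
            past = out.past_key_values
            next_token = out.logits[:, -1:].argmax(dim=-1)
            generated = torch.cat([generated, next_token], dim=-1)
    print(tokenizer.decode(generated[0], skip_special_tokens=True))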

Your contribution

The only existing info I could find is the short discussion here. I tried messing around a bit to get this working but had little luck; I'm not familiar with the inner workings of transformers and ran into numerous errors. One problem is padding: with left padding, the batched inputs can become misaligned with the precomputed prompt hidden states, e.g.:

<p> <p> <p> common prompt x_1 x_2 x_3
<p> <p> common prompt x_1 x_2 x_3 x_4
<p> <p> <p> <p> common prompt x_1 x_2

I don't know the best way to solve this. Do we dynamically pad every tensor in past_key_values? That seems like it would be slow, but I haven't measured it.
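
To make the question concrete, here is a minimal sketch of what padding every tensor in past_key_values could look like mechanically, assuming the legacy tuple cache layout of (key, value) pairs shaped (batch, num_heads, seq_len, head_dim). The helper name is just for illustration, and it only covers the tensor manipulation; handling per-row offsets and position_ids correctly is exactly the part I'm unsure about.

import torch
import torch.nn.functional as F

def left_pad_cache(past_key_values, num_pad, attention_mask):
    # Hypothetical helper: insert num_pad masked-out slots at the left of every
    # cached key/value tensor so the cache lines up with a left-padded batch.
    padded_layers = []
    for key, value in past_key_values:
        # F.pad pads dimensions from the last one backwards: (0, 0) leaves
        # head_dim alone, (num_pad, 0) adds zeros on the left of seq_len.
        padded_layers.append((
            F.pad(key, (0, 0, num_pad, 0)),
            F.pad(value, (0, 0, num_pad, 0)),
        ))
    # Extend the attention mask with zeros so the padded slots are never attended to.
    pad_mask = attention_mask.new_zeros(attention_mask.shape[0], num_pad)
    return tuple(padded_layers), torch.cat([pad_mask, attention_mask], dim=-1)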

If someone can suggest a better or easier way, or give some more pointers on how to solve the padding issue, I'd be happy to try again myself.

Thanks in advance.
