Feature request
Hi there,
I'd like to be able to re-use the hidden states for a common (potentially long) prompt across multiple calls to model.generate()
in order to reduce redundant computation. Here is how I envision a final API, though I'm sure there are multiple ways to do it.
# Load the model and tokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained('huggyllama/llama-7b')
tokenizer = AutoTokenizer.from_pretrained('huggyllama/llama-7b')

# Common prompt that we'd prepend to every example
prompt = "This is a common prompt in every example."
prompt_ids = tokenizer(prompt, return_tensors='pt').input_ids

# Examples to pass to generate
examples = ["Ackbar went to", "Billaba enjoys", "Cody is eating some"]

# Generation loop
outputs = []
prompt_hidden_state = None
for ex in examples:
    # Current way of doing things: re-encode and re-process the full
    # prompt + example on every call
    out = model.generate(
        **tokenizer(prompt + ex, return_tensors='pt'),
    )

    # Proposed method: re-use prompt_hidden_state from the previous call
    out = model.generate(
        **tokenizer(ex, return_tensors='pt'),
        common_prompt_ids=prompt_ids,
        prompt_hidden_state=prompt_hidden_state,
    )
    prompt_hidden_state = out.prompt_hidden_state
    outputs.append(out.sequences)
Thanks in advance.
Motivation
A very common pattern for LLM usage is having a common prompt (e.g., instructions and input/output pairs) followed by a sample input, and asking the model to generate the corresponding output. For example:
You are a programmer's assistant which converts English descriptions to Python functions.
English: <example 1 description>
Python: <example 1 function>
English: <example 2 description>
Python: <example 2 function>
English: <example 3 description>
Python: <example 3 function>
English: <input description>
Python:
I'd like to be able to cache the common part of the prompt, i.e. everything before <input description>, across inputs, since it appears verbatim in every example and re-computing it is potentially expensive.
Your contribution
The only existing info I could find is the short discussion here. I tried messing around a bit to get this to work but had little luck; I'm not familiar with the inner workings of transformers and ran into numerous errors (a rough sketch of what I was attempting is at the end of this section). One problem is padding: if we're using left padding, the batch can end up misaligned with the cached prompt hidden states, e.g.:
<p> <p> <p> common prompt x_1 x_2 x_3
<p> <p> common prompt x_1 x_2 x_3 x_4
<p> <p> <p> <p> common prompt x_1 x_2
I don't know the best way to solve this. Do we dynamically pad every tensor in past_key_values? That seems slow, but I don't know if it actually is.
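For what it's worth, here is a minimal sketch of what I mean by dynamically padding the cache. It assumes the legacy tuple cache layout used by llama-style models (one (key, value) pair per layer, each of shape (batch, num_heads, seq_len, head_dim)); the helper name and how it would be wired into generate() are just my guesses.

import torch.nn.functional as F

def left_pad_past_key_values(past_key_values, pad_len):
    # Hypothetical helper: left-pad every cached key/value tensor along the
    # sequence dimension (dim=-2) so the prompt cache lines up with a
    # left-padded batch. The padded positions would also need matching 0s
    # in the attention mask so they are ignored.
    padded = []
    for key, value in past_key_values:
        padded.append((
            F.pad(key, (0, 0, pad_len, 0)),    # (head_dim left/right, seq_len left/right)
            F.pad(value, (0, 0, pad_len, 0)),
        ))
    return tuple(padded)

Each example would presumably need a different pad_len to reach the batch max, which is why I worry this could be slow in practice.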
If someone can suggest a better/easier way, or give some more pointers on how to handle the padding, I'd be happy to try again myself.
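For concreteness, the sketch below is roughly the kind of workaround I was attempting in the single-example (no padding) case. Running the common prompt through a plain forward pass with use_cache=True to get its past_key_values works; whether that cache can then be handed to generate() like this, and how the attention mask / position ids need to be arranged, is exactly where I got stuck, so please treat it as a sketch of the idea rather than working code.

import torch

# Build the KV cache for the common prompt once with a plain forward pass
with torch.no_grad():
    prompt_out = model(**tokenizer(prompt, return_tensors='pt'), use_cache=True)
prompt_past = prompt_out.past_key_values

outputs = []
for ex in examples:
    inputs = tokenizer(prompt + ex, return_tensors='pt')
    # Attempted re-use: pass the precomputed prompt cache so that, ideally,
    # only the example tokens need a forward pass. This is the call I could
    # not get to line up correctly (each iteration would probably also need
    # its own copy of prompt_past).
    out = model.generate(
        **inputs,
        past_key_values=prompt_past,
    )
    outputs.append(out)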
Thanks in advance.