Conversation

@gante (Member) commented Jul 25, 2023

What does this PR do?

Enables returning past_key_values from generate when return_dict_in_generate=True (otherwise only the generated input_ids are returned) and use_cache=True (otherwise there is no cache to return ;) ).

In more abstract terms, this enables features like:

  1. continuing a given generation without re-running the more expensive prefill step -- like in multi-turn conversations
  2. exploring the KV values without having to place a breakpoint in generate 👀 🐛

The added code for the feature is minimal, so most of the PR is docs and tests 🤗

Fixes #24841
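As a rough illustration of what the returned cache looks like, the legacy past_key_values format is a tuple with one entry per layer, each entry a (key, value) pair of tensors shaped, by default, (batch_size, num_heads, sequence_length, embed_size_per_head). The sketch below mocks that layout with numpy (no transformers involved; all sizes are illustrative, not tied to any real model):

```python
import numpy as np

# Mock of the legacy past_key_values layout returned by generate when
# return_dict_in_generate=True and use_cache=True: a tuple with one entry
# per layer, each entry a (key, value) pair of arrays shaped
# (batch_size, num_heads, sequence_length, embed_size_per_head).
# All sizes here are illustrative.
batch, heads, seq_len, head_dim, n_layers = 1, 4, 10, 8, 2

past_key_values = tuple(
    (np.zeros((batch, heads, seq_len, head_dim)),   # key for this layer
     np.zeros((batch, heads, seq_len, head_dim)))   # value for this layer
    for _ in range(n_layers)
)

# With the cache in hand, KV values can be inspected directly instead of
# placing a breakpoint inside generate.
key, value = past_key_values[0]
print(len(past_key_values), key.shape)
```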

@HuggingFaceDocBuilderDev commented Jul 25, 2023

The documentation is not available anymore as the PR was closed or merged.

@huggingface huggingface deleted a comment from github-actions bot Aug 25, 2023
@huggingface huggingface deleted a comment from github-actions bot Sep 20, 2023
@ArthurZucker ArthurZucker changed the title Generate: return past_key_values [WIP] Generate: return past_key_values Sep 20, 2023
@freckletonj

This is a killer feature 👍

@kazzand commented Oct 19, 2023

@gante Hi! Thanks for the PR!
Did you test feeding the output past_key_values back into the .generate() method? For example: take a 250-token input, run .generate(), grab the output past_key_values, then take another 50-token input and run .generate() with the previous 250 tokens' past_key_values? With beam search this seems to be kinda tricky; I'm trying to resolve multiple dimension-mismatch problems.

@gante gante marked this pull request as ready for review October 31, 2023 12:16
@gante gante requested a review from amyeroberts October 31, 2023 12:16
@gante gante changed the title [WIP] Generate: return past_key_values Generate: return past_key_values Oct 31, 2023
@amyeroberts (Collaborator) left a comment

Thanks for adding this!

V. nice and clean PR and tests ❤️

for layer in model_kwargs["past_key_values"]:
    layer_past_key_values = []
    for item in layer:
        layer_past_key_values.append(item[..., :-1, :])
Collaborator:

Are we guaranteed to have a consistent item shape here?

@gante (Member, Author):

Guaranteed consistent item shape? No. Guaranteed that the second to last dimension is the sequence length? Yes :)

Some models squash together or permute the first two dimensions (by default, the cache is (batch_size, num_heads, sequence_length, embed_size_per_head)), with Bloom and GPTBigCode (aka StarCoder) being the biggest offenders.

One of the tests I added is in the mixin and touches contrastive search, so the fact that this output can be correctly used for continuations is tested :D In fact, I only noticed that this was a problem precisely because of the test!
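The invariant described above (leading dimensions may vary, but the sequence length is always second-to-last) is what makes item[..., :-1, :] a safe way to drop the last cached position. A small sketch with numpy standing in for torch, comparing the default layout with a squashed layout where batch and head dims are fused (both layouts are illustrative):

```python
import numpy as np

# However the leading cache dimensions are laid out, the second-to-last
# dimension is the sequence length, so item[..., :-1, :] always drops the
# last cached position. Shapes below are illustrative.
batch, heads, seq_len, head_dim = 2, 4, 10, 8

standard = np.zeros((batch, heads, seq_len, head_dim))   # default layout
squashed = np.zeros((batch * heads, seq_len, head_dim))  # fused batch/head dims

for item in (standard, squashed):
    cropped = item[..., :-1, :]
    # The sequence axis shrinks by one in both layouts.
    print(cropped.shape[-2])
```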

Comment on lines +1843 to +1845
# 2. generate to max length (which can be achieved by setting the eos token to an invalid value), which
# would make the test flaky (e.g. EOS is generated on iteration 1 on both generations, but the
# continuation would force it to generate beyond an EOS token)
Collaborator:

Sorry, I don't understand this comment 😅 I parse it as generating to the max length to make the test flaky.

@gante (Member, Author) commented Nov 2, 2023:

Yes, correct -- it would make the test failure a false positive, i.e. failing on expected behavior.

Comment on lines +1858 to +1861
# If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format)
outputs = model(**inputs)
if "past_key_values" not in outputs:
return
Collaborator:

I think that, rather than having model-specific cases where we skip by early returning (which pytest reports as a pass), it would be clearer for the test to assume past_key_values exists and to have each model handle it individually within its own test file. That way, each model can either use an explicit skip reason with unittest.skip or implement its own equivalent test. This is just an opinion though -- would be good to have thoughts from @ydshieh

@gante (Member, Author):

I 100% agree!

Can I make it part of a follow-up PR? The same pattern is used in other places, and I am not sure how many models will break (there should be a few)

Collaborator:

Yes, of course :D

Collaborator:

We can use self.skipTest (though that is generally not good usage either), especially since here we are in a loop over model_class.

Also, it would be better to use continue (rather than return) here.

Potentially, using

with self.subTest(...)

is a better approach, I guess.
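The suggestion above can be sketched as follows. The model names and the SUPPORTS_PKV flag are invented stand-ins for the real per-model loop in the mixin; the point is that subTest reports each model separately, and skipTest inside it records an explicit skip instead of a silent pass:

```python
import unittest

# Hypothetical stand-ins for the mixin's loop over model_class; one "model"
# does not expose past_key_values (like RWKV's custom cache format).
all_model_classes = ["gpt2_like", "rwkv_like"]
SUPPORTS_PKV = {"gpt2_like": True, "rwkv_like": False}

class PastKeyValuesTest(unittest.TestCase):
    def test_past_key_values_format(self):
        for model_class in all_model_classes:
            # subTest keeps iterating over the remaining models and reports
            # each one separately, instead of a bare `return` that ends the
            # whole test and shows up as a pass.
            with self.subTest(model_class=model_class):
                if not SUPPORTS_PKV[model_class]:
                    self.skipTest(f"{model_class} uses a different cache format")
                # ... the real test body would check cache shapes here ...
                self.assertTrue(SUPPORTS_PKV[model_class])

result = unittest.TextTestRunner(verbosity=0).run(
    unittest.defaultTestLoader.loadTestsFromTestCase(PastKeyValuesTest)
)
print(result.testsRun, len(result.skipped), len(result.failures))
```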

@amyeroberts (Collaborator) left a comment

Sorry - I meant to approve before!

@gante (Member, Author) commented Nov 2, 2023

(merging and leaving the conversion to a skip as a TODO)

@nevakrien

I don't see a version number. When will this be out?

@gante (Member, Author) commented Nov 30, 2023

Next release :) (v4.36)

@amyeroberts (Collaborator)
@nevakrien v4.36 is now out :)

Successfully merging this pull request may close these issues.

Support for caching prompt hidden states through multiple calls of generate()