Generate: return past_key_values
#25086
Conversation
The documentation is not available anymore as the PR was closed or merged.
This is a killer feature 👍
@gante Hi! Thanks for the PR!
Thanks for adding this!
V. nice and clean PR and tests ❤️
```python
for layer in model_kwargs["past_key_values"]:
    layer_past_key_values = []
    for item in layer:
        layer_past_key_values.append(item[..., :-1, :])
```
Are we guaranteed to have a consistent `item` shape here?
Guaranteed consistent `item` shape? No. Guaranteed that the second-to-last dimension is the sequence length? Yes :)
Some models squash together or permute the first two dimensions (by default, the cache is `(batch_size, num_heads, sequence_length, embed_size_per_head)`), Bloom and GPTBigCode (aka StarCoder) being the biggest offenders.
One of the tests I added is in the mixin and touches contrastive search, so the fact that this output can be correctly used for continuations is tested :D In fact, I only noticed that this was a problem precisely because of the test!
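A minimal sketch of the invariant discussed above, using NumPy arrays as stand-ins for the cache tensors (shapes and layer count are made up): `item[..., :-1, :]` drops the last position regardless of how the leading dimensions are laid out, which is why it works even for models that squash batch and heads into one dimension.

```python
import numpy as np

# Made-up cache: 3 layers, each a (key, value) pair with the default layout
# (batch_size, num_heads, sequence_length, embed_size_per_head)
batch_size, num_heads, seq_len, head_dim = 2, 4, 10, 8
past_key_values = [
    (
        np.random.rand(batch_size, num_heads, seq_len, head_dim),
        np.random.rand(batch_size, num_heads, seq_len, head_dim),
    )
    for _ in range(3)
]

# A layout that squashes batch and heads into one dimension (as some models
# do) still keeps the sequence length as the second-to-last dimension
squashed_key = np.random.rand(batch_size * num_heads, seq_len, head_dim)

# `[..., :-1, :]` drops the last position in both layouts
trimmed = [tuple(item[..., :-1, :] for item in layer) for layer in past_key_values]

print(trimmed[0][0].shape)               # (2, 4, 9, 8)
print(squashed_key[..., :-1, :].shape)   # (8, 9, 8)
```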
```python
# 2. generate to max length (which can be achieved by setting the eos token to an invalid value), which
#    would make the test flaky (e.g. EOS is generated on iteration 1 on both generations, but the
#    continuation would force it to generate beyond an EOS token)
```
Sorry, I don't understand this comment 😅 I parse it as saying that generating to the max length makes the test flaky.
Yes, correct -- it would make the test failure a false positive, i.e. failing on expected behavior.
```python
# If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format)
outputs = model(**inputs)
if "past_key_values" not in outputs:
    return
```
I think that rather than having model-specific cases where we skip by early returning, which looks like a "pass" to pytest, it would be clearer for the test to assume `past_key_values` exists and then have each model handle it individually within its test file. That way each model can either use an explicit skip reason with `unittest.skip` or implement its model-equivalent test. This is just an opinion though - would be good to have thoughts from @ydshieh
I 100% agree!
Can I make it part of a follow-up PR? The same pattern is used in other places, and I am not sure how many models will break (there should be a few)
Yes, of course :D
We can use `self.skipTest` (though that is generally not good usage either), especially since here we are in a loop over `model_class`. Also, it is better to use `continue` (compared to `return`) here. Potentially, using `with self.subTest(...)` is a better approach, I guess.
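A hedged sketch of the two patterns being discussed, `continue` vs `self.subTest`, inside a loop over model classes. The fake model classes and the `returns_cache` flag are made-up stand-ins, not the real mixin code (which loops over the actual generative model classes):

```python
import unittest


class FakeModelWithCache:
    returns_cache = True  # stand-in for a model that returns past_key_values


class FakeModelWithoutCache:
    returns_cache = False  # stand-in for e.g. an RWKV-style cache format


class CacheReturnTest(unittest.TestCase):
    all_model_classes = (FakeModelWithCache, FakeModelWithoutCache)

    def test_past_key_values_returned(self):
        for model_class in self.all_model_classes:
            # subTest reports each model's outcome separately in the test
            # output, instead of everything collapsing into one silent "pass"
            with self.subTest(model_class=model_class.__name__):
                if not model_class.returns_cache:
                    # `continue` moves on to the next model; `return` would
                    # silently end the whole test after the first skip
                    continue
                self.assertTrue(model_class.returns_cache)


# Run the suite programmatically and check that it passed
suite = unittest.defaultTestLoader.loadTestsFromTestCase(CacheReturnTest)
result = unittest.TextTestRunner(verbosity=0).run(suite)
print(result.wasSuccessful())  # True
```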
Sorry - I meant to approve before!
(merging and leaving the conversion to a skip as a TODO)
I don't see a version number. When will this be out?
Next release :) (v4.36)
@nevakrien v4.36 is now out :)
What does this PR do?
Enables returning `past_key_values` from `generate`, if `return_dict_in_generate=True` (otherwise only the generated `input_ids` are returned) and `use_cache=True` (otherwise there is no cache to return ;) ).
In more abstract terms, this enables features like continuing generation from a previous `generate` call 👀 🐛
The added code for the feature is minimal, so most of the PR is docs and tests 🤗
Fixes #24841
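A sketch of how the feature could be used, assuming a small causal LM checkpoint such as `gpt2` is available locally or downloadable; the checkpoint and prompt are illustrative, and the output fields (`sequences`, `past_key_values`) follow the PR description:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; any causal LM that supports use_cache would do
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Hello, my dog is", return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=5,
    return_dict_in_generate=True,  # required to get a dict-like output
    use_cache=True,                # required so there is a cache to return
)

print(outputs.sequences.shape)       # the generated input_ids
print(len(outputs.past_key_values))  # one entry per decoder layer
```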