[RFC] Supporting KV-cache toggling #1675

@SalmanMohammadi

Problem

Currently, when we use model.setup_caches(), KV-caches are always updated on every subsequent forward pass through the model. We have valid use cases for calling model.setup_caches() but then not updating the KV-cache on some forward passes. The most immediate use case we have for this is in the eval recipe (see #1621). When specifying multiple tasks in the recipe:

  • If one task is a generation task and another a log-likelihood task, and
  • The generation task is evaluated before the log-likelihood task, and
  • KV-caching is enabled,

then KV-caching will still be enabled for the log-likelihood task, which is incorrect behaviour for a number of reasons. In this example we have two kinds of forward passes: one that should run with KV-caching enabled, and one that should run with it disabled.
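
For illustration, roughly what happens today (a sketch only, not the actual eval recipe code; the setup_caches arguments, context, and ll_tokens are placeholders):

# the generation task sets up caches and generates with them...
model.setup_caches(batch_size=batch_size, dtype=dtype)
toks, _ = generation.generate(model, context, max_generated_tokens=max_gen_toks)

# ...but a later log-likelihood task just calls forward, and the KV-cache
# is still updated on this pass too, which is not what we want
logits = model(ll_tokens)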

Another use case is for my work on improving our RLHF offerings.

  1. For the current PPO recipe, the overall structure looks something like:
batch = batch.to(device)
with torch.no_grad():
    ...
    # generation: forward passes with KV-caching, no grads
    completions = policy_model.generate(...)

# PPO step: forward passes with grads, no KV-caching
loss = self._ppo_step(completions, policy_model, ...)
loss.backward()
optimizer.step()

This isn't 1:1 but you get the gist. Here we have two kinds of forward passes: one under torch.no_grad() with KV-caching, and one with requires_grad=True and without KV-caching (in fact, we should never really have the case of requires_grad=True + KV-caching, but that point is mostly relevant for compile).

  2. For a LoRA PPO recipe I've been working on, the structure is a bit different:
batch = batch.to(device)
with torch.no_grad():
    ...
    # generation: forward passes with KV-caching, no grads
    completions = policy_model.generate(...)
    # reference logits: forward pass without KV-caching, no grads
    with torchtune.modules.peft.disable_adapter(policy_model):
        ref_logits = policy_model(completions)

# PPO step: forward passes with grads, no KV-caching
loss = self._ppo_step(completions, policy_model, ...)
loss.backward()
optimizer.step()

So now there are three kinds of forward passes: two under torch.no_grad(), one with KV-caching and one without, and one with requires_grad=True and without KV-caching.

We want to support these use cases in a compile-friendly manner.

Solution

1) Context-managing KV-caching

To me this feels like the most user-friendly solution. We define a context manager like:

import contextlib

@contextlib.contextmanager
def enable_kv_cache(model):
    if not model.caches_are_enabled():
        raise ValueError(
            "KV-caches are not set up - call model.setup_caches() first"
        )
    # turn cache updates on for the duration of the block
    for layer in model.layers:
        layer.attn.cache_enabled = True
    try:
        yield
    finally:
        # always turn them back off, even if the block raised
        for layer in model.layers:
            layer.attn.cache_enabled = False

This relies on a small modification to MultiHeadAttention:

class MultiHeadAttention(...):
    def __init__(...):
        ...
        # cache updates become opt-in; setup_caches() alone no longer implies them
        self.cache_enabled = False

    def forward(...):
        ...
        # only update the cache when a cache exists *and* updates are enabled
        if self.kv_cache is not None and self.cache_enabled:
            k, v = self.kv_cache.update(k, v)
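
Usage would then look something like this (a sketch, with placeholder setup_caches arguments and surrounding code):

model.setup_caches(batch_size=batch_size, dtype=dtype)

# KV-cache updates only happen inside the context manager
with enable_kv_cache(model):
    toks, _ = generation.generate(model, prompt, max_generated_tokens=max_gen_toks)

# outside it, forward passes leave the KV-cache untouched
logits = model(tokens)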

The UX for this change would rely on users always wrapping their calls in with enable_kv_cache(model): whenever they want to use KV-caching. This is nice because the behaviour is quite explicit. However, specifically for the eval recipe, KV-caching is configurable, so we may not want to error out when caches haven't been set up. Two options:

  1. Inspired by torch.inference_mode(mode: bool):
@contextlib.contextmanager
def kv_cache_mode(model, mode=True):
    if not model.caches_are_enabled():
        raise ValueError(
            "KV-caches are not set up - call model.setup_caches() first"
        )
    # remember the current per-layer state so we can restore it on exit
    prev_states = [layer.attn.cache_enabled for layer in model.layers]
    for layer in model.layers:
        layer.attn.cache_enabled = mode
    try:
        yield
    finally:
        for layer, prev in zip(model.layers, prev_states):
            layer.attn.cache_enabled = prev


with kv_cache_mode(self._model, mode=self._enable_kv_cache):
    toks, _ = generation.generate(
        self._model,
        maybe_padded_context,
        max_generated_tokens=self.max_gen_toks,
        temperature=temperature,
        top_k=None,  # do_sample is not supported currently
        stop_tokens=self._tokenizer.stop_tokens,
    )

I'm not 100% sure about this one because I can't imagine many use cases where model.caches_are_enabled() is True but mode=False. Maybe if your default is inference/KV-caching and you want to disable it for a single pass?

  2. Don't error out on model.caches_are_enabled(). If you try using this without calling setup_caches() first, it's a no-op and nothing happens:
@contextlib.contextmanager
def kv_cache_mode(model):
    if not model.caches_are_enabled():
        # maybe warn here?
        yield
        return  # bail out: the no-op path must not yield a second time
    for layer in model.layers:
        layer.attn.cache_enabled = True
    try:
        yield
    finally:
        for layer in model.layers:
            layer.attn.cache_enabled = False
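
For the eval recipe, this lets the generation path wrap its call unconditionally (a sketch; batch_size, dtype, and context are placeholders, and the generate call mirrors the snippet above):

# caches are only set up when the recipe config asks for them
if self._enable_kv_cache:
    self._model.setup_caches(batch_size=batch_size, dtype=dtype)

# with caches set up this enables cache updates for the duration of the
# block; without them it's a no-op, so no error and no special-casing
with kv_cache_mode(self._model):
    toks, _ = generation.generate(self._model, context, ...)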

I can confirm the above solutions work with compile (I reserve the right to retract this claim in light of any future knowledge).

2) Rely on inference_mode

We could differentiate forward passes which require KV-caching by using torch.inference_mode() (as we currently do by decorating _generation.generate with it). Then the change would simply be:

class MultiHeadAttention(...):
    def __init__(...):
        ...

    def forward(self, x, ...):
        ...
        # only update the cache when the activations are inference tensors,
        # i.e. the forward pass is running under torch.inference_mode()
        if self.kv_cache is not None and x.is_inference():
            k, v = self.kv_cache.update(k, v)
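
Call sites would then differentiate themselves purely by whether they run under inference mode (a sketch; generate is already decorated with torch.inference_mode(), and the training-side names are placeholders):

# inside generate, activations are inference tensors,
# so the KV-cache gets updated on those forward passes...
toks, _ = generation.generate(model, prompt, max_generated_tokens=max_gen_toks)

# ...while a grad-enabled forward pass skips the cache update,
# since its activations are not inference tensors
logits = model(tokens)
loss = loss_fn(logits, labels)
loss.backward()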

However, this will only work with compile on nightlies until the next release (see pytorch/pytorch#136450).

This is a very minimal change but very non-obvious, so I don't really like it. Open to thoughts though.
