Error when using PPO with Gemma #1663

Description

@mostafamdy

System Info

Hi,
I tried using PPO with the Gemma model, but I get the error below. I think the issue is related to is_encoder_decoder; a sketch of a possible fix follows the traceback.

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[26], line 68
     66 print(response_tensors)
     67 #### Run PPO step
---> 68 stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
     69 ppo_trainer.log_stats(stats, batch, rewards)
     70 break

File /opt/conda/lib/python3.10/contextlib.py:79, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
     76 @wraps(func)
     77 def inner(*args, **kwds):
     78     with self._recreate_cm():
---> 79         return func(*args, **kwds)

File /opt/conda/lib/python3.10/site-packages/trl/trainer/ppo_trainer.py:721, in PPOTrainer.step(self, queries, responses, scores, response_masks)
    718 full_kl_penalty = self.config.kl_penalty == "full"
    720 with torch.no_grad():
--> 721     all_logprobs, logits_or_none, values, masks = self.batched_forward_pass(
    722         self.model,
    723         queries,
    724         responses,
    725         model_inputs,
    726         response_masks=response_masks,
    727         return_logits=full_kl_penalty,
    728     )
    729     with self.optional_peft_ctx():
    730         ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass(
    731             self.model if self.is_peft_model else self.ref_model,
    732             queries,
   (...)
    735             return_logits=full_kl_penalty,
    736         )

File /opt/conda/lib/python3.10/contextlib.py:79, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
     76 @wraps(func)
     77 def inner(*args, **kwds):
     78     with self._recreate_cm():
---> 79         return func(*args, **kwds)

File /opt/conda/lib/python3.10/site-packages/trl/trainer/ppo_trainer.py:994, in PPOTrainer.batched_forward_pass(self, model, queries, responses, model_inputs, return_logits, response_masks)
    992 if response_masks is not None:
    993     response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]
--> 994 logits, _, values = model(**input_kwargs)
    996 if self.is_encoder_decoder:
    997     input_ids = input_kwargs["decoder_input_ids"]

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1568, in Module._call_impl(self, *args, **kwargs)
   1565     bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)
   1566     args = bw_hook.setup_input_hook(args)
-> 1568 result = forward_call(*args, **kwargs)
   1569 if _global_forward_hooks or self._forward_hooks:
   1570     for hook_id, hook in (
   1571         *_global_forward_hooks.items(),
   1572         *self._forward_hooks.items(),
   1573     ):
   1574         # mark that always called hook is run

File /opt/conda/lib/python3.10/site-packages/trl/models/modeling_value_head.py:171, in AutoModelForCausalLMWithValueHead.forward(self, input_ids, past_key_values, attention_mask, **kwargs)
    168 if self.is_peft_model and self.pretrained_model.active_peft_config.peft_type == "PREFIX_TUNING":
    169     kwargs.pop("past_key_values")
--> 171 base_model_output = self.pretrained_model(
    172     input_ids=input_ids,
    173     attention_mask=attention_mask,
    174     **kwargs,
    175 )
    177 last_hidden_state = base_model_output.hidden_states[-1]
    178 lm_logits = base_model_output.logits

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/conda/lib/python3.10/site-packages/peft/peft_model.py:1326, in PeftModelForSeq2SeqLM.forward(self, input_ids, attention_mask, inputs_embeds, decoder_input_ids, decoder_attention_mask, decoder_inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, **kwargs)
   1324     with self._enable_peft_forward_hooks(**kwargs):
   1325         kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
-> 1326         return self.base_model(
   1327             input_ids=input_ids,
   1328             attention_mask=attention_mask,
   1329             inputs_embeds=inputs_embeds,
   1330             decoder_input_ids=decoder_input_ids,
   1331             decoder_attention_mask=decoder_attention_mask,
   1332             decoder_inputs_embeds=decoder_inputs_embeds,
   1333             labels=labels,
   1334             output_attentions=output_attentions,
   1335             output_hidden_states=output_hidden_states,
   1336             return_dict=return_dict,
   1337             **kwargs,
   1338         )
   1340 batch_size = _get_batch_size(input_ids, inputs_embeds)
   1341 if decoder_attention_mask is not None:
   1342     # concat prompt attention mask

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/conda/lib/python3.10/site-packages/peft/tuners/tuners_utils.py:161, in BaseTuner.forward(self, *args, **kwargs)
    160 def forward(self, *args: Any, **kwargs: Any):
--> 161     return self.model.forward(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    164         output = module._old_forward(*args, **kwargs)
    165 else:
--> 166     output = module._old_forward(*args, **kwargs)
    167 return module._hf_hook.post_forward(module, output)

TypeError: GemmaForCausalLM.forward() got an unexpected keyword argument 'decoder_input_ids'
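
The PeftModelForSeq2SeqLM frame in the traceback suggests the LoRA adapter was created with peft's seq2seq task type. That fits the guess about is_encoder_decoder: TRL then appears to treat the wrapped model as an encoder-decoder and builds decoder_input_ids (the "if self.is_encoder_decoder" branch visible above), which GemmaForCausalLM, a decoder-only model, rejects. Below is a minimal sketch of a setup that should avoid the error, assuming the task type is indeed the culprit; the checkpoint name and LoRA hyperparameters are illustrative, not taken from the failing script.

from peft import LoraConfig, TaskType
from trl import AutoModelForCausalLMWithValueHead

# Gemma is decoder-only, so the adapter must use the causal-LM task type.
# TaskType.SEQ_2_SEQ_LM makes peft wrap the model in PeftModelForSeq2SeqLM,
# which is the class seen in the traceback above.
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,  # not TaskType.SEQ_2_SEQ_LM
    r=16,                          # illustrative LoRA hyperparameters
    lora_alpha=32,
    lora_dropout=0.05,
)

# Attaching the value head via trl keeps the wrapper consistent with a
# causal LM, so PPOTrainer should not construct decoder_input_ids.
model = AutoModelForCausalLMWithValueHead.from_pretrained(
    "google/gemma-2b",  # illustrative checkpoint name
    peft_config=lora_config,
)

With the causal-LM task type, peft wraps the model in PeftModelForCausalLM instead, and the forward call should no longer receive decoder_input_ids.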

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

.

Expected behavior

.
