generated from fastai/nbdev_template
-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Closed
Description
System Info
Hi,
I tried using PPO with a Gemma model, but I get the error below.
I think the issue is that `is_encoder_decoder` is being detected incorrectly, so the trainer passes `decoder_input_ids` to a decoder-only model:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[26], line 68
66 print(response_tensors)
67 #### Run PPO step
---> 68 stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
69 ppo_trainer.log_stats(stats, batch, rewards)
70 break
File /opt/conda/lib/python3.10/contextlib.py:79, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
76 @wraps(func)
77 def inner(*args, **kwds):
78 with self._recreate_cm():
---> 79 return func(*args, **kwds)
File /opt/conda/lib/python3.10/site-packages/trl/trainer/ppo_trainer.py:721, in PPOTrainer.step(self, queries, responses, scores, response_masks)
718 full_kl_penalty = self.config.kl_penalty == "full"
720 with torch.no_grad():
--> 721 all_logprobs, logits_or_none, values, masks = self.batched_forward_pass(
722 self.model,
723 queries,
724 responses,
725 model_inputs,
726 response_masks=response_masks,
727 return_logits=full_kl_penalty,
728 )
729 with self.optional_peft_ctx():
730 ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass(
731 self.model if self.is_peft_model else self.ref_model,
732 queries,
(...)
735 return_logits=full_kl_penalty,
736 )
File /opt/conda/lib/python3.10/contextlib.py:79, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
76 @wraps(func)
77 def inner(*args, **kwds):
78 with self._recreate_cm():
---> 79 return func(*args, **kwds)
File /opt/conda/lib/python3.10/site-packages/trl/trainer/ppo_trainer.py:994, in PPOTrainer.batched_forward_pass(self, model, queries, responses, model_inputs, return_logits, response_masks)
992 if response_masks is not None:
993 response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]
--> 994 logits, _, values = model(**input_kwargs)
996 if self.is_encoder_decoder:
997 input_ids = input_kwargs["decoder_input_ids"]
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1568, in Module._call_impl(self, *args, **kwargs)
1565 bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks)
1566 args = bw_hook.setup_input_hook(args)
-> 1568 result = forward_call(*args, **kwargs)
1569 if _global_forward_hooks or self._forward_hooks:
1570 for hook_id, hook in (
1571 *_global_forward_hooks.items(),
1572 *self._forward_hooks.items(),
1573 ):
1574 # mark that always called hook is run
File /opt/conda/lib/python3.10/site-packages/trl/models/modeling_value_head.py:171, in AutoModelForCausalLMWithValueHead.forward(self, input_ids, past_key_values, attention_mask, **kwargs)
168 if self.is_peft_model and self.pretrained_model.active_peft_config.peft_type == "PREFIX_TUNING":
169 kwargs.pop("past_key_values")
--> 171 base_model_output = self.pretrained_model(
172 input_ids=input_ids,
173 attention_mask=attention_mask,
174 **kwargs,
175 )
177 last_hidden_state = base_model_output.hidden_states[-1]
178 lm_logits = base_model_output.logits
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File /opt/conda/lib/python3.10/site-packages/peft/peft_model.py:1326, in PeftModelForSeq2SeqLM.forward(self, input_ids, attention_mask, inputs_embeds, decoder_input_ids, decoder_attention_mask, decoder_inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, **kwargs)
1324 with self._enable_peft_forward_hooks(**kwargs):
1325 kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
-> 1326 return self.base_model(
1327 input_ids=input_ids,
1328 attention_mask=attention_mask,
1329 inputs_embeds=inputs_embeds,
1330 decoder_input_ids=decoder_input_ids,
1331 decoder_attention_mask=decoder_attention_mask,
1332 decoder_inputs_embeds=decoder_inputs_embeds,
1333 labels=labels,
1334 output_attentions=output_attentions,
1335 output_hidden_states=output_hidden_states,
1336 return_dict=return_dict,
1337 **kwargs,
1338 )
1340 batch_size = _get_batch_size(input_ids, inputs_embeds)
1341 if decoder_attention_mask is not None:
1342 # concat prompt attention mask
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File /opt/conda/lib/python3.10/site-packages/peft/tuners/tuners_utils.py:161, in BaseTuner.forward(self, *args, **kwargs)
160 def forward(self, *args: Any, **kwargs: Any):
--> 161 return self.model.forward(*args, **kwargs)
File /opt/conda/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
164 output = module._old_forward(*args, **kwargs)
165 else:
--> 166 output = module._old_forward(*args, **kwargs)
167 return module._hf_hook.post_forward(module, output)
TypeError: GemmaForCausalLM.forward() got an unexpected keyword argument 'decoder_input_ids'
Who can help?
No response
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the
examples
folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
.
Expected behavior
.
Metadata
Metadata
Assignees
Labels
No labels