
kashif
Collaborator

@kashif kashif commented Dec 19, 2024

What does this PR do?

Calculate the ORPO chosen NLL loss over the chosen completion tokens only, rather than over the whole prompt + completion.

Also return the shifted logits when the model is decoder-only.
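A minimal, dependency-free sketch of both changes, for illustration only. It assumes label_pad_token_id is -100 (the usual PyTorch ignore index) and that prompt and padding positions in the labels have already been set to that value. The helpers log_softmax and completion_nll are hypothetical names, not the TRL API; the trainer does the same thing with batched tensors.

```python
import math
from typing import List

IGNORE_INDEX = -100  # assumed value of label_pad_token_id


def log_softmax(row: List[float]) -> List[float]:
    """Numerically stable log-softmax over one vocabulary row."""
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]


def completion_nll(logits: List[List[float]], labels: List[int],
                   decoder_only: bool = True) -> float:
    """Mean NLL over positions whose label is not IGNORE_INDEX.

    For a decoder-only model the logits at position t predict token t + 1,
    so the last logit row and the first label are dropped (the "shift").
    """
    if decoder_only:
        logits, labels = logits[:-1], labels[1:]
    total, count = 0.0, 0
    for row, label in zip(logits, labels):
        if label == IGNORE_INDEX:
            continue  # prompt and padding tokens do not contribute
        total -= log_softmax(row)[label]
        count += 1
    return total / count if count else 0.0


# Uniform logits over a 4-token vocabulary: every scored token costs log(4).
print(completion_nll([[0.0] * 4] * 5, [-100, -100, 2, 3, 1]))  # ~1.386
```

Only the three completion labels (2, 3, 1) are scored after the shift; the two prompt positions are skipped because their label is the ignore index.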

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

attention_mask = concatenated_batch["concatenated_attention_mask"]
labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)

labels = concatenated_batch["concatenated_labels"].clone()
Member


Yes, we checked this together: if you do

labels = concatenated_batch["concatenated_input_ids"].clone()
attention_mask = concatenated_batch["concatenated_attention_mask"]
labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)

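you don't ignore the prompt: only padding positions are set to label_pad_token_id, so the prompt tokens still contribute to the NLL loss.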
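A small plain-Python sketch of the difference discussed above, for illustration. The helper names and the prompt_len parameter are hypothetical; in the trainer the prompt masking is presumably already present in concatenated_labels, which is why the fix clones those instead of concatenated_input_ids. The ignore value of -100 is an assumption.

```python
from typing import List

IGNORE_INDEX = -100  # assumed value of label_pad_token_id


def labels_from_attention_mask(input_ids: List[int],
                               attention_mask: List[int]) -> List[int]:
    """Old approach: mask padding only, so prompt tokens remain targets."""
    return [tok if m == 1 else IGNORE_INDEX
            for tok, m in zip(input_ids, attention_mask)]


def labels_completion_only(input_ids: List[int], attention_mask: List[int],
                           prompt_len: int) -> List[int]:
    """Also mask the first prompt_len tokens so only the completion is scored."""
    labels = labels_from_attention_mask(input_ids, attention_mask)
    return [IGNORE_INDEX if i < prompt_len else tok
            for i, tok in enumerate(labels)]


ids = [11, 12, 13, 21, 22, 0]  # 3 prompt tokens, 2 completion tokens, 1 pad
mask = [1, 1, 1, 1, 1, 0]
print(labels_from_attention_mask(ids, mask))  # [11, 12, 13, 21, 22, -100]
print(labels_completion_only(ids, mask, 3))   # [-100, -100, -100, 21, 22, -100]
```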

@kashif kashif merged commit 88ad1a0 into main Dec 19, 2024
14 checks passed
@kashif kashif deleted the orpo-nll-fix branch December 19, 2024 10:33
yxliu-TAMU pushed a commit to mincheolseong/ECEN743-GRPO-Project-Proposal that referenced this pull request Apr 20, 2025

3 participants