Bug in example DPO script in dataloading #1541

@sohrabi1

Description

Since the example DPO script uses the hh-rlhf dataset in OpenAI messages format, the loading code in the script here seems incorrect:

    def process(row):
        row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
        row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
        return row

since it applies the chat template to all messages for both chosen and rejected, and it also ignores the chat template for the prompt.
If my understanding is correct, the right process function would be:

    def process(row):
        # we should extract the final turn of messages to define chosen/rejected responses and keep the rest as prompt
        prompt_messages = row["chosen"][:-1]
        chosen_messages = row["chosen"][-1:]
        rejected_messages = row["rejected"][-1:]

        row["prompt"] = tokenizer.apply_chat_template(prompt_messages, tokenize=False)
        row["chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
        row["rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
        return row

As far as I can see, the DPO trainer expects only the final answer in the chosen / rejected fields.
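To illustrate the intended split, here is a self-contained toy example. The `apply_chat_template` function below is a hypothetical plain-text stand-in for `tokenizer.apply_chat_template` (the real one depends on the model's chat template); only the slicing logic matches the proposed fix:

```python
# Hypothetical stand-in for tokenizer.apply_chat_template, for illustration only.
def apply_chat_template(messages, tokenize=False):
    return "".join(f"<|{m['role']}|>{m['content']}\n" for m in messages)


def process(row):
    # Split off the final turn as the chosen/rejected response; keep the rest as prompt.
    prompt_messages = row["chosen"][:-1]
    chosen_messages = row["chosen"][-1:]
    rejected_messages = row["rejected"][-1:]

    row["prompt"] = apply_chat_template(prompt_messages, tokenize=False)
    row["chosen"] = apply_chat_template(chosen_messages, tokenize=False)
    row["rejected"] = apply_chat_template(rejected_messages, tokenize=False)
    return row


# Toy hh-rlhf-style row: chosen and rejected share the same conversation prefix.
row = {
    "chosen": [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello!"},
    ],
    "rejected": [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Go away."},
    ],
}
out = process(row)
print(out["prompt"])    # shared prefix only
print(out["chosen"])    # final assistant turn only
print(out["rejected"])  # final assistant turn only
```

With this split, the shared conversation prefix lands in `prompt` exactly once, and `chosen`/`rejected` contain only the competing final responses, rather than each duplicating the whole conversation.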
