
[New Feature Enhancement] Generalize DPO with different f-divergences #1259

@1485840691

Description


Why

There is a paper discussing how to generalize DPO with different f-divergences (the present implementation using log() is one member of this family: reverse KL, i.e. α = 0) to help models better balance alignment performance and generation diversity.

According to the paper:

"Empirically, adopting these f-divergences ensures a balance between alignment performance and generation diversity. Importantly, f-DPO outperforms PPO-based methods in divergence efficiency, and divergence constraints directly influence expected calibration error (ECE)."

How

I would like to work out a PR to add these f-divergences alongside the currently supported one (reverse KL).


Implementation should be straightforward.
A simple code update for illustration purposes:

```python
# Log-ratio "rewards" for the rejected and chosen completions.
rejected_rewards = policy_rejected_logps.to(self.accelerator.device) - reference_rejected_logps.to(self.accelerator.device)
rejected_rewards_exp = torch.exp(rejected_rewards)

chosen_rewards = policy_chosen_logps.to(self.accelerator.device) - reference_chosen_logps.to(self.accelerator.device)
chosen_rewards_exp = torch.exp(chosen_rewards)

# Forward KL
logits = -1 / chosen_rewards_exp - (-1 / rejected_rewards_exp)

# JS-divergence
logits = (chosen_rewards - torch.log(1 + chosen_rewards_exp)) - (rejected_rewards - torch.log(1 + rejected_rewards_exp))

# alpha-divergence
logits = (1 - chosen_rewards_exp ** (-self.alpha_div)) / self.alpha_div - (1 - rejected_rewards_exp ** (-self.alpha_div)) / self.alpha_div
```
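
Whichever divergence is selected, the resulting logits would then feed into the same sigmoid loss the trainer already computes. A minimal sketch, assuming the existing `self.beta` temperature (details may differ from the current `dpo_loss` code):

```python
import torch.nn.functional as F

# Sketch: the divergence-specific logits plug into the usual Bradley-Terry style DPO loss.
losses = -F.logsigmoid(self.beta * logits)
```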

Possible updates to the existing class (the new argument is `f_divergence_kwargs`):

```python
class DPOTrainer(Trainer):
    def __init__(
        self,
        ...,
        f_divergence_kwargs: Optional[Dict] = None,  # new argument
        ...,
    ):
        ...
```
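
To make this concrete, here is one possible way a `dpo_loss`-style method could consume the new argument. This is only a sketch: the helper name `f_divergence_logits`, the `type` and `alpha` keys, and their defaults are invented for illustration and are not existing TRL API.

```python
from typing import Dict, Optional

import torch


def f_divergence_logits(
    chosen_rewards: torch.Tensor,
    rejected_rewards: torch.Tensor,
    f_divergence_kwargs: Optional[Dict] = None,
) -> torch.Tensor:
    """Map log-ratio "rewards" to pairwise logits for the configured f-divergence."""
    kwargs = f_divergence_kwargs or {}
    div_type = kwargs.get("type", "reverse_kl")  # default keeps current DPO behaviour
    alpha = kwargs.get("alpha", 0.5)

    chosen_exp = torch.exp(chosen_rewards)
    rejected_exp = torch.exp(rejected_rewards)

    if div_type == "reverse_kl":
        # Current DPO: plain difference of log-ratios.
        return chosen_rewards - rejected_rewards
    if div_type == "forward_kl":
        return -1 / chosen_exp + 1 / rejected_exp
    if div_type == "js":
        return (chosen_rewards - torch.log1p(chosen_exp)) - (
            rejected_rewards - torch.log1p(rejected_exp)
        )
    if div_type == "alpha":
        return ((1 - chosen_exp ** (-alpha)) - (1 - rejected_exp ** (-alpha))) / alpha
    raise ValueError(f"Unsupported f-divergence type: {div_type}")
```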

If there are any concerns, please let me know.
