Description
Why
There is a paper that discusses generalizing DPO with different f-divergences (the present implementation, which uses log(), is one member of this family: reverse KL, i.e. α = 0 in the α-divergence parameterization) to help the model better balance alignment performance and generation diversity.
According to the paper:

> Empirically, adopting these f-divergences ensures a balance between alignment performance and generation diversity. Importantly, f-DPO outperforms PPO-based methods in divergence efficiency, and divergence constraints directly influence expected calibration error (ECE).
How
I would like to work out a PR adding these f-divergences alongside the currently supported one (reverse KL). The implementation should be straightforward. A simple code update for illustration purposes:
```python
rejected_rewards = policy_rejected_logps.to(self.accelerator.device) - reference_rejected_logps.to(self.accelerator.device)
rejected_rewards_exp = torch.exp(rejected_rewards)
chosen_rewards = policy_chosen_logps.to(self.accelerator.device) - reference_chosen_logps.to(self.accelerator.device)
chosen_rewards_exp = torch.exp(chosen_rewards)

# Each assignment below is an alternative to the current reverse-KL logits
# (chosen_rewards - rejected_rewards); only one would be selected at a time.

# Forward KL
logits = -1 / chosen_rewards_exp - (-1 / rejected_rewards_exp)
# JS-divergence
logits = (chosen_rewards - torch.log(1 + chosen_rewards_exp)) - (rejected_rewards - torch.log(1 + rejected_rewards_exp))
# alpha-divergence
logits = (1 - chosen_rewards_exp ** (-self.alpha_div)) / self.alpha_div - (1 - rejected_rewards_exp ** (-self.alpha_div)) / self.alpha_div
```
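As a quick, self-contained sanity check (dummy tensors, hypothetical alpha value), the α-divergence logits should reduce to the current reverse-KL logits, i.e. the plain log-ratio difference, as α → 0:

```python
import torch

# Dummy (policy - reference) per-sequence log-prob gaps; values are purely illustrative.
chosen_rewards = torch.tensor([0.5, -0.2, 1.3], dtype=torch.float64)
rejected_rewards = torch.tensor([-0.1, -0.8, 0.4], dtype=torch.float64)
chosen_rewards_exp = torch.exp(chosen_rewards)
rejected_rewards_exp = torch.exp(rejected_rewards)

# Current behaviour (reverse KL): plain log-ratio difference.
reverse_kl_logits = chosen_rewards - rejected_rewards

# alpha-divergence with a small alpha: (1 - u**(-alpha)) / alpha -> log(u) as alpha -> 0,
# so the logits should approach the reverse-KL ones.
alpha_div = 1e-4  # hypothetical hyperparameter value
alpha_logits = (
    (1 - chosen_rewards_exp ** (-alpha_div)) / alpha_div
    - (1 - rejected_rewards_exp ** (-alpha_div)) / alpha_div
)

print(torch.allclose(alpha_logits, reverse_kl_logits, atol=1e-3))  # True
```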
Possible update to the existing class:

```python
class DPOTrainer(Trainer):
    def __init__(
        ...,
        f_divergence_kwargs: Optional[Dict] = None,  # new argument
        ...,
    )
```
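As a rough sketch of how the new argument could be consumed inside the loss computation (the `f_divergence_type` / `alpha_div` keys below are hypothetical names, not an existing TRL API):

```python
from typing import Dict, Optional

import torch


def f_divergence_logits(
    chosen_rewards: torch.Tensor,
    rejected_rewards: torch.Tensor,
    f_divergence_kwargs: Optional[Dict] = None,
) -> torch.Tensor:
    """Dispatch on a hypothetical `f_divergence_type` key; defaults to the current reverse-KL behaviour."""
    kwargs = f_divergence_kwargs or {}
    divergence = kwargs.get("f_divergence_type", "reverse_kl")

    chosen_exp = torch.exp(chosen_rewards)
    rejected_exp = torch.exp(rejected_rewards)

    if divergence == "reverse_kl":
        # Current DPO logits: plain log-ratio difference.
        return chosen_rewards - rejected_rewards
    if divergence == "forward_kl":
        return -1 / chosen_exp + 1 / rejected_exp
    if divergence == "js":
        return (chosen_rewards - torch.log1p(chosen_exp)) - (rejected_rewards - torch.log1p(rejected_exp))
    if divergence == "alpha":
        alpha = kwargs.get("alpha_div", 0.5)
        return (1 - chosen_exp ** (-alpha)) / alpha - (1 - rejected_exp ** (-alpha)) / alpha
    raise ValueError(f"Unknown f_divergence_type: {divergence}")
```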
If there are any concerns, please let me know.