Hello, I'm thrilled to see that linear and NTK interpolation have been elegantly combined into a much stronger interpolation strategy, YaRN. However, while going through the code in modeling_llama.py, I find myself a bit confused by the calculation of inv_freq, particularly at line 398.
According to the YaRN paper, in equation 23, it is stated as follows:
Consequently, we can derive:
However, in the paper, the calculation of
Hence, I think there might be a problem with equation 25 and also with line 398. Perhaps we can revise the yarn function as follows; I've empirically found that this fix further enhances performance:
```python
def revised_yarn(self, device):
    # Base RoPE inverse frequencies: theta_d = base^(-2d/|D|)
    inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
    low, high = _yarn_find_correction_range(
        self.beta_fast, self.beta_slow, self.dim, self.base,
        self.original_max_position_embeddings,
    )
    # Extrapolation weight per dimension: 1 for high-frequency dims
    # (kept as-is), 0 for low-frequency dims (fully interpolated)
    inv_freq_mask = (1 - _yarn_linear_ramp_mask(low, high, self.dim // 2).float().to(device)) * self.extrapolation_factor
    # Interpolate the wavelengths linearly, rather than the frequencies
    inv_freq = inv_freq / ((1 - inv_freq_mask) * self.scale + inv_freq_mask)
    self.register_buffer("inv_freq", inv_freq, persistent=False)
    self.mscale = float(_yarn_get_mscale(self.scale) * self.attn_factor)
```
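For intuition, here is a minimal NumPy sketch (my own simplified stand-ins, not the repo's helpers) contrasting the line-398 style blend, which mixes the interpolated and extrapolated frequencies θ_d linearly, with the revision above, which mixes the wavelengths 2π/θ_d linearly. The two agree at the pure-interpolation (mask = 0) and pure-extrapolation (mask = 1) endpoints and differ only inside the ramp region:

```python
import numpy as np

def original_blend(inv_freq, mask, scale):
    # line-398 style: linear mix of interpolated and extrapolated *frequencies*;
    # mask = 1 -> pure extrapolation (theta), mask = 0 -> pure interpolation (theta / s)
    return (inv_freq / scale) * (1 - mask) + inv_freq * mask

def revised_blend(inv_freq, mask, scale):
    # proposed fix: linear mix of the *wavelengths* 2*pi/theta,
    # i.e. lambda_new = (1 - mask) * s * lambda + mask * lambda
    return inv_freq / ((1 - mask) * scale + mask)

theta = np.array([1.0, 0.1, 0.01])  # toy inverse frequencies
mask = np.array([1.0, 0.5, 0.0])    # extrapolation weight per dimension
s = 4.0                             # scale factor

print(original_blend(theta, mask, s))  # endpoints match revised_blend,
print(revised_blend(theta, mask, s))   # but the mask = 0.5 entry differs
```

With these toy numbers, the middle dimension gets (0.1/4)·0.5 + 0.1·0.5 = 0.0625 under the original blend but 0.1/2.5 = 0.04 under the revision, which is the discrepancy this issue is about.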