Line 129 in 619a8b3:

    queries, keys, vals = self.pos_embed(queries, keys, vals)
It seems to me that the rotary position embedding is being applied along the head dimension (dim -2) of the q and k tensors instead of the sequence dimension (dim 1).
I think the head and sequence dimensions should be swapped before calling the positional embedding.
(see https://github.com/facebookresearch/xformers/blob/748c159096d4f9fcfe3eaf22801e5aed4777210b/xformers/components/positional_embedding/rotary.py#L85)
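As a quick illustration (a minimal sketch; the tensor sizes are made up, and the (batch, seq_len, n_heads, head_dim) layout is inferred from the permute in the fix below), dim -2 of q is the heads axis, while the rotation should run over the sequence axis:

    import torch

    batch, seq_len, n_heads, head_dim = 2, 16, 4, 8
    q = torch.randn(batch, seq_len, n_heads, head_dim)

    print(q.shape[-2])                      # 4  -> n_heads: the axis the rotation currently sees
    print(q.shape[1])                       # 16 -> seq_len: the axis it should see
    print(q.permute(0, 2, 1, 3).shape[-2])  # 16 -> after the permute, dim -2 is the sequence axis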
What I'm proposing is simply to rewrite RotaryWithCast as follows:
class RotaryWithCast(RotaryEmbedding):
    def forward(self, q, k, v):
        q, k = super().forward(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3))
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        return q.to(v.dtype), k.to(v.dtype), v
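For reference, here is a hedged, self-contained sketch of how the proposed class could be exercised. It assumes RotaryEmbedding is the xformers implementation linked above (constructor RotaryEmbedding(dim_model)) and that q/k/v are laid out as (batch, seq_len, n_heads, head_dim); the tensor sizes are illustrative only:

    import torch
    from xformers.components.positional_embedding.rotary import RotaryEmbedding

    class RotaryWithCast(RotaryEmbedding):  # as proposed above
        def forward(self, q, k, v):
            # Put the sequence axis on dim -2, rotate, then restore the original layout.
            q, k = super().forward(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3))
            q = q.permute(0, 2, 1, 3)
            k = k.permute(0, 2, 1, 3)
            return q.to(v.dtype), k.to(v.dtype), v

    batch, seq_len, n_heads, head_dim = 2, 16, 4, 8
    pos_embed = RotaryWithCast(head_dim)

    q = torch.randn(batch, seq_len, n_heads, head_dim)
    k = torch.randn(batch, seq_len, n_heads, head_dim)
    v = torch.randn(batch, seq_len, n_heads, head_dim, dtype=torch.float16)

    q_out, k_out, v_out = pos_embed(q, k, v)
    assert q_out.shape == (batch, seq_len, n_heads, head_dim)
    assert k_out.shape == (batch, seq_len, n_heads, head_dim)
    assert q_out.dtype == v.dtype and k_out.dtype == v.dtype  # cast to v's dtype is preserved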