clarification on flashmla_backend #5154

@chesterout

Hi, in flashmla_backend, when I print self.num_q_heads and self.num_kv_heads, both values are 16.
However, in the original FlashMLA repo, num_kv_heads is expected to be 1.
Could you clarify why the two differ?

Thanks.


mla_metadata, num_splits = get_mla_metadata(
    forward_batch.seq_lens.to(torch.int32),
    Q_LEN * self.num_q_heads // self.num_kv_heads,
    self.num_kv_heads,
)
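
To make the difference concrete, here is a minimal sketch of the two integer arguments that the call above would compute in each case. The helper name `mla_metadata_args` is hypothetical and `Q_LEN = 1` is an assumption for illustration; the sketch only reproduces the arithmetic and does not call `get_mla_metadata`.

```python
def mla_metadata_args(q_len: int, num_q_heads: int, num_kv_heads: int) -> tuple[int, int]:
    """Hypothetical helper: reproduce the two integer arguments from the call above."""
    # Second argument of the call: query length times query heads per KV head.
    heads_per_kv_head = q_len * num_q_heads // num_kv_heads
    # Third argument of the call: the number of KV heads itself.
    return heads_per_kv_head, num_kv_heads


Q_LEN = 1  # assumption: one query token per sequence

# Values printed in flashmla_backend: num_q_heads = 16, num_kv_heads = 16
print(mla_metadata_args(Q_LEN, 16, 16))  # (1, 16)

# Value expected from the original FlashMLA repo: num_kv_heads = 1
print(mla_metadata_args(Q_LEN, 16, 1))   # (16, 1)
```

With `num_kv_heads = 16` the call receives `(1, 16)`, while with `num_kv_heads = 1` it receives `(16, 1)`, so the two settings pass different values to the metadata call.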
