About the speed test #8

@MzeroMiko

Thank you for sharing your fantastic work.
We noticed from the figure that Mamba's runtime does not increase as d_state grows.
However, the code contains the following (selective_scan_fwd_kernel.cuh#163):

for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
    ...
    if constexpr (kIsVariableB) {
        load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
            smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
    }
}

This is a for loop over state_idx that reads from HBM into shared memory on every iteration.
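
To make the scaling argument concrete, here is a minimal pure-Python sketch of the sequential selective-scan recurrence for a single channel. It is not the CUDA kernel: it leaves out the D skip term, the delta bias and softplus, and chunking, and the names `selective_scan_ref` and `make_inputs` are made up for illustration. It only shows that each time step contains a loop over the state dimension, so the iteration count grows in proportion to d_state.

```python
import math


def selective_scan_ref(x, dt, A, B, C):
    """Sequential selective scan for one channel (simplified sketch).

    x, dt: lists of length L. A: list of length N.
    B, C: N lists of length L (input-dependent B and C).
    Returns (y, steps); steps counts iterations over the state dimension.
    """
    L, N = len(x), len(A)
    h = [0.0] * N  # hidden state, one entry per state index
    y = []
    steps = 0
    for t in range(L):
        acc = 0.0
        for n in range(N):  # plays the role of the state_idx loop
            h[n] = math.exp(dt[t] * A[n]) * h[n] + dt[t] * B[n][t] * x[t]
            acc += C[n][t] * h[n]
            steps += 1
        y.append(acc)
    return y, steps


def make_inputs(L, N):
    """Constant toy inputs; the values are arbitrary."""
    x = [1.0] * L
    dt = [0.1] * L
    A = [-1.0] * N
    B = [[0.5] * L for _ in range(N)]
    C = [[0.5] * L for _ in range(N)]
    return x, dt, A, B, C


_, steps_16 = selective_scan_ref(*make_inputs(L=8, N=16))
_, steps_64 = selective_scan_ref(*make_inputs(L=8, N=64))
print(steps_64 / steps_16)  # 4.0
```

Whether a real kernel hides this extra work behind memory latency is a separate question; the sketch only counts loop iterations.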

I then re-ran the speed test and found that as d_state increases, Mamba's runtime increases roughly linearly, which is consistent with the code.

    import time

    import torch

    # selective_scan_fn is the scan function under test. It is assumed to be
    # imported already and is called exactly as in the original measurement.
    device = torch.device("cuda")
    dtype = torch.float32
    # B: batch, L: sequence length, G: groups, D: channels per group,
    # N: d_state. R is not used below.
    B, L, G, D, N, R = 3, 4096, 4, 192, 16, 192 // 16
    xi = torch.randn((B, G * D, L), device=device, dtype=dtype)
    Ai = torch.randn((G * D, N), device=device, dtype=dtype)
    Di = torch.randn((G * D), device=device, dtype=dtype)
    dti = torch.randn((B, G * D, L), device=device, dtype=dtype)
    Bi = torch.randn((B, G, N, L), device=device, dtype=dtype)
    Ci = torch.randn((B, G, N, L), device=device, dtype=dtype)
    tpb = torch.randn((G * D), device=device, dtype=dtype)

    # Same shapes with d_state increased from N to 4*N
    Ai2 = torch.randn((G * D, 4*N), device=device, dtype=dtype)
    Bi2 = torch.randn((B, G, 4*N, L), device=device, dtype=dtype)
    Ci2 = torch.randn((B, G, 4*N, L), device=device, dtype=dtype)

    # Run 1: d_state = N
    tim0 = time.time()
    for _ in range(1000):
        y = selective_scan_fn(xi, dti, Ai, Bi, Ci, Di, tpb, True)
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    tim1 = time.time()
    # Run 2: d_state = 4*N
    for _ in range(1000):
        y = selective_scan_fn(xi, dti, Ai2, Bi2, Ci2, Di, tpb, True)
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    tim2 = time.time()
    # Prints: seconds for run 1, seconds for run 2, peak allocated bytes
    print(tim1-tim0, tim2-tim1, torch.cuda.max_memory_allocated()) # 0.7172577381134033 2.400775194168091 185063424
    time.sleep(100000)  # keep the process alive after printing
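
Because CUDA kernels are launched asynchronously, a timing loop is easier to trust when it runs a few warm-up iterations first and synchronizes both before starting and before stopping the clock. The helper below is a sketch of that pattern, not the measurement used above; `benchmark` and its parameters are hypothetical names, and `sync` is meant to be `torch.cuda.synchronize` when timing on a GPU.

```python
import time


def benchmark(fn, iters=1000, warmup=10, sync=lambda: None):
    """Return the average wall-clock seconds per call of fn.

    sync is called before the clock starts and before it stops, so that
    asynchronously launched work is finished at both points.
    """
    for _ in range(warmup):  # warm-up calls are not timed
        fn()
    sync()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    sync()
    return (time.perf_counter() - start) / iters
```

With the tensors above, it could be called as `benchmark(lambda: selective_scan_fn(xi, dti, Ai, Bi, Ci, Di, tpb, True), sync=torch.cuda.synchronize)`, and again with `Ai2`, `Bi2`, `Ci2` for the larger d_state.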

So what did I miss?
