Thank you for sharing your fantastic work.
We noticed from the figure that, as the dimension d_state increases, Mamba's runtime does not increase.
However, we found this in the code (selective_scan_fwd_kernel.cuh#163):
for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
    ...
    if constexpr (kIsVariableB) {
        load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
            smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
    }
}
which shows a loop over state_idx that reads B from HBM into shared memory once per state index, so the global-memory traffic grows linearly with d_state.
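To make that scaling concrete, here is a rough back-of-the-envelope sketch (my own, not from the repo; the chunk size and float32 dtype are assumptions, and only the ratio matters):

# Back-of-the-envelope HBM traffic for the B loads in the loop above.
# chunk size and 4-byte elements are assumptions; only the ratio matters.
def b_traffic_bytes(d_state, seqlen=4096, chunk=2048, bytes_per_elem=4):
    n_chunks = seqlen // chunk
    # each chunk loads one chunk-long slice of B per state index
    return n_chunks * d_state * chunk * bytes_per_elem

small, large = b_traffic_bytes(16), b_traffic_bytes(64)
print(small, large, large / small)  # ratio is 4.0: traffic scales with d_state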
I then benchmarked again and found that as d_state rises, Mamba's runtime rises roughly linearly, which is consistent with the code:
import time

import torch

# assuming selective_scan_fn (the repo's CUDA selective-scan wrapper) is already imported

device = torch.device("cuda")
dtype = torch.float32
B, L, G, D, N, R = 3, 4096, 4, 192, 16, 192 // 16
xi = torch.randn((B, G * D, L), device=device, dtype=dtype)
Ai = torch.randn((G * D, N), device=device, dtype=dtype)
Di = torch.randn((G * D), device=device, dtype=dtype)
dti = torch.randn((B, G * D, L), device=device, dtype=dtype)
Bi = torch.randn((B, G, N, L), device=device, dtype=dtype)
Ci = torch.randn((B, G, N, L), device=device, dtype=dtype)
tpb = torch.randn((G * D), device=device, dtype=dtype)
# same inputs with d_state enlarged 4x (N -> 4N)
Ai2 = torch.randn((G * D, 4 * N), device=device, dtype=dtype)
Bi2 = torch.randn((B, G, 4 * N, L), device=device, dtype=dtype)
Ci2 = torch.randn((B, G, 4 * N, L), device=device, dtype=dtype)

tim0 = time.time()
for _ in range(1000):
    y = selective_scan_fn(xi, dti, Ai, Bi, Ci, Di, tpb, True)
torch.cuda.synchronize()
torch.cuda.empty_cache()
tim1 = time.time()
for _ in range(1000):
    y = selective_scan_fn(xi, dti, Ai2, Bi2, Ci2, Di, tpb, True)
torch.cuda.synchronize()
torch.cuda.empty_cache()
tim2 = time.time()
print(tim1 - tim0, tim2 - tim1, torch.cuda.max_memory_allocated())
# 0.7172577381134033 2.400775194168091 185063424
time.sleep(100000)  # keep the process alive to inspect GPU memory, e.g. with nvidia-smi
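As a side note, time.time() also counts Python launch overhead; below is a minimal sketch of the same comparison using CUDA events with a warmup (the bench helper and warmup count are my own, not from the repo):

def bench(fn, iters=1000, warmup=10):
    # warmup iterations so one-time compilation/caching is excluded from timing
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / 1000.0  # elapsed_time() reports milliseconds

t_small = bench(lambda: selective_scan_fn(xi, dti, Ai, Bi, Ci, Di, tpb, True))
t_large = bench(lambda: selective_scan_fn(xi, dti, Ai2, Bi2, Ci2, Di, tpb, True))
print(t_small, t_large, t_large / t_small)  # a ratio near 4 would match linear scaling in d_state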
So what did I miss?