Much faster CPU prompt processing (part 1) #531
Conversation
With that we hit 360 t/s for LlaMA-3.1-8B on a Ryzen-7950X. q8_k_r8 is 386 t/s, so for a batch size of 512 repacking costs ~7% of the time taken by the actual GEMM.
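As a quick sanity check of that ~7% figure, using only the two throughputs quoted above (time per token is the inverse of throughput):

```python
# Repacking overhead implied by the quoted PP numbers: the extra time relative
# to the pure q8_k_r8 GEMM is 386/360 - 1.
pp_with_repacking = 360.0  # t/s, LlaMA-3.1-8B, repacking done on the fly (this PR)
pp_q8_k_r8 = 386.0         # t/s, model already stored as q8_k_r8

extra = pp_q8_k_r8 / pp_with_repacking - 1.0
print(f"repacking cost ≈ {extra:.1%} of the GEMM time")  # ≈ 7.2%, i.e. ~7%
```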
Does this also improve the behavior at higher contexts? For me, running DeepSeek at higher contexts, both PP and TG approach ~1 t/s.
This indicates that, for a long enough context, your computer spends the entire time computing self-attention. If so, this PR will have zero impact on your long-context performance.
I'm trying to understand, but that explanation (at least to me) doesn't account for why at low context PP uses a lot more power than TG (as PP is compute bound), while at higher context the power usage looks a lot closer to TG's (which is memory/QPI-bandwidth bound). I don't have actual numbers (I don't think the exact numbers matter), but the difference is stark enough for me to notice from the CPU temperatures.
Or is it rather the other way around (TG looks a lot closer to PP)? If you buy my explanation that for a large context all the time is spent in the self-attention calculation, then there isn't that much of a difference between TG and PP: for DeepSeek, each row in the KV cache multiplies 128 rows of activations (one per attention head), so even TG behaves like a small-batch GEMM as far as self-attention is concerned.
I do, I was just trying to understand it.
That makes sense. I did look at the PCM data I had from earlier, plus some I just generated, checking CPU power usage and IPC, but I'm not sure the numbers are actually useful: I found that TG was causing paging (there really isn't much spare RAM on my system during inference).
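To make the argument in this exchange concrete, here is a back-of-the-envelope sketch of why self-attention comes to dominate both PP and TG at long context; the constants are illustrative placeholders, not measurements of any particular machine or model:

```python
# Toy cost model (illustrative constants, not measurements): per token, FFN/expert
# work is constant while self-attention work grows with the number of cached
# positions, so at long context the attention share approaches 100% for PP and TG alike.
ffn_work_per_token = 1.0              # arbitrary units, independent of context length
attn_work_per_token_per_pos = 1e-3    # assumed attention work per cached position

for n_ctx in (512, 4096, 32768, 131072):
    attn = attn_work_per_token_per_pos * n_ctx
    share = attn / (attn + ffn_work_per_token)
    print(f"n_ctx = {n_ctx:6d}  attention share ≈ {share:6.1%}")
```

Since this PR speeds up the quantized weight GEMMs rather than the self-attention over the KV cache, its benefit shrinks as that share grows, which matches the "zero impact at long context" remark above.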
Not a comprehensive test, but this PR helps here as well, even if the gains are not as dramatic. Note that, to keep it simple, I compared DeepSeek-R1-0528-IQ1_S against the pre-repacked DeepSeek-R1-0528-IQ1_S_R4, using the same sweep-bench command for both (data below).
👈 sweep-bench data

model=/mnt/raid/models/ubergarm/DeepSeek-R1-0528-GGUF/IQ1_S_R4/DeepSeek-R1-0528-IQ1_S_R4-00001-of-00003.gguf
#model=/mnt/raid/models/ubergarm/DeepSeek-R1-0528-GGUF/IQ1_S/DeepSeek-R1-0528-IQ1_S-00001-of-00003.gguf

numactl -N 0 -m 0 \
  ./build/bin/llama-sweep-bench \
    --model "$model" \
    -c 8704 \
    -ctk q8_0 \
    -mla 3 -fa \
    -fmoe \
    --no-mmap \
    --threads 80 \
    --threads-batch 128 \
    --numa numactl \
    --warmup-batch

Results were collected for:
DeepSeek-R1-0528-IQ1_S_R4
PR531@72fd9faa DeepSeek-R1-0528-IQ1_S
main@6fc5bbb6 DeepSeek-R1-0528-IQ1_S
This is because of the extremely high total_experts/active_experts = 256/8 = 32 ratio in DeepSeek-V3. For a u-batch size of 512 we are still far away from the regime where this new repacking scheme pays large dividends. Perhaps the gains will be bigger for larger u-batch sizes. But yes, I see that this PR may not have the huge impact that it should, because people have somehow decided that the pre-repacked _R4 quants are the way to go.
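To put a number on "still far away from that regime", here is the rough per-expert row count for DeepSeek-R1/V3 at different u-batch sizes, using the ~32-row threshold at which the repacking path kicks in (simple arithmetic on figures from this thread, not a measurement):

```python
# Average rows seen per routed expert: u_batch * active_experts / total_experts.
total_experts, active_experts = 256, 8   # DeepSeek-R1/V3
repack_threshold = 32                    # rows at which repacking starts to pay off

for u_batch in (512, 1024, 4096):
    rows = u_batch * active_experts / total_experts
    status = "above" if rows >= repack_threshold else "below"
    print(f"u_batch = {u_batch:4d} -> ~{rows:4.0f} rows/expert ({status} the threshold)")
# 512 -> 16 rows, 1024 -> 32, 4096 -> 128.
```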
Yes, looks like even with the high expert ratio of the DeepSeek MoE, this new repacking scheme begins to outstrip the pre-repacked _R4 quant once the batch size is large enough:

PR531@72fd9faa DeepSeek-R1-0528-IQ1_S_R4 -b 4096 -ub 4096
PR531@72fd9faa DeepSeek-R1-0528-IQ1_S -b 4096 -ub 4096
I might try quanting this qwen2.5-72b finetune, moonshotai/Kimi-Dev-72B, today, given your recent improvements (and from reading the commit logs). Honestly, the biggest hurdle to general adoption of this fork, imo, is the lack of a pre-compiled distributable binary, e.g. in AppImage format etc... My guess is that the majority of potential end-users don't know how to compile it themselves.
I would be curious to know where the cutoff point is. With something like …
It is model and quantization type dependent. But I'm not removing the …
I had been so used to V3, where I never enabled high batch sizes with amb because I rarely requested more than the default batch size of 512. But with R1 that is not the case (due to thought-token removal, which results in reprocessing the context). I ran an experiment at high context, processing 4096 tokens (positions 33640 to 37736), and this went from 2950 to 1619 seconds, along with a reduction in the compute buffer (…).
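For scale, the timings in the comment above translate into the following prompt-processing rates (just arithmetic on the quoted figures):

```python
# PP rates implied by the reported experiment: 4096 tokens processed
# (context positions 33640 to 37736) in the two runs quoted above.
tokens = 37736 - 33640                # = 4096
before_s, after_s = 2950.0, 1619.0    # seconds for the two runs

print(f"before: {tokens / before_s:.2f} t/s")   # ~1.39 t/s
print(f"after:  {tokens / after_s:.2f} t/s")    # ~2.53 t/s
print(f"speedup: {before_s / after_s:.2f}x")    # ~1.82x
```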
This PR is a continuation of #515, #516, #517, #518, with the following differences:
- Repacking is to Q8_K_R8 instead of Q8_0_R8. Q8_K_R8 is the fastest quant known to human kind (see Q8_K_R8: Fastest quantized matrix multiplications #141), and that helps achieve significant performance gains when batch size is greater than 32 tokens or so
- IQ1_M, IQ2_XS, IQ2_S, Q3_K are added, in addition to IQ1_S, IQ2_XXS, IQ3_XXS, IQ3_S, which were already improved in the quoted PRs
- Q6_K is added, but in this case repacking is to Q8_0_R8, as Q6_K cannot be losslessly repacked to Q8_K, and I was worried that there could be a non-negligible accuracy loss due to that.

The following table shows a PP-512 performance comparison between the main branch and this PR. Model is LLaMA-3.1-8B-Instruct. Quantization is always "pure" (i.e., all tensors except the output tensor and the token embedding tensor are quantized with the selected quantization type). CPU is Ryzen-7950X.
A few notes:
- For the quants that already had repacking to Q8_0_R8 (IQ1_S, IQ2_XXS, IQ3_XXS, IQ3_S), the gains are in the range of 15-20%
- IQ1_M stands out because it did not have a fast iqk GEMM implementation at all, so we gain a factor of 12X!
- … (the IQX_K quants).

I have the impression that most people use ik_llama.cpp for MoE models. MoE models are quite different from dense models such as LLaMA-3.1-8B because each routed expert "sees" only a small fraction of the tokens in a batch, so the effective batch size is much smaller than for a dense model. Hence, PP performance gains for MoE models will be more modest. It is instructive to look at PP performance as a function of batch size. The following graph shows the result for Q3_K, which has a reasonably efficient iqk GEMM implementation. The repacking strategy kicks in at 32 tokens, so up to that point performance is the same. The relative performance gain from this PR then slowly grows to about 1.9X at 256 tokens, and remains (nearly) the same from there on.

Based on this we can expect lower performance gains for a MoE model. For instance, DeepSeek-R1/V3 have 256 total experts but only 8 active experts, so effectively this strategy will not become active (or will have a very small impact) up to u-batch sizes of 1024 tokens. I cannot run DeepSeek-R1/V3, but I can run Qwen3-30B-A3B, and the next graph shows performance for this model quantized with Q3_K. As expected, performance gains are smaller, about 1.4X at the peak, and the performance improvement is not significant before 64 tokens.
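To restate the mechanism in code form: the scheme pays a one-time cost to repack a quantized weight tensor into the very fast Q8_K_R8 layout (Q8_0_R8 for Q6_K) once enough rows of activations reach it, and otherwise uses the existing per-type path. The sketch below only illustrates that decision logic under the thresholds described above; it is not the actual ik_llama.cpp implementation, and the function and type names are made up.

```python
# Illustrative sketch (not the real ik_llama.cpp code) of the batch-size-gated
# repacking strategy described in this PR. Quant types are plain strings here.
REPACK_THRESHOLD = 32  # rows of activations; below this, repacking costs more than it saves


def gemm_path(quant_type: str, n_rows: int) -> str:
    """Which matrix-multiplication path a weight of `quant_type` takes for `n_rows` rows."""
    if n_rows < REPACK_THRESHOLD:
        # Small batches (TG, or MoE experts that see only a few rows per u-batch):
        # keep using the existing type-specific GEMM/GEMV.
        return f"native {quant_type} GEMM"
    if quant_type == "Q6_K":
        # Q6_K cannot be losslessly repacked to Q8_K, so it goes to Q8_0_R8 instead.
        return "repack to Q8_0_R8, then Q8_0_R8 GEMM"
    return "repack to Q8_K_R8, then Q8_K_R8 GEMM"


# For a MoE model the row count that matters is per routed expert, roughly
# n_rows = u_batch * active_experts / total_experts, which is why much larger
# u-batches are needed before the fast path becomes active.
for quant, rows in [("Q3_K", 16), ("Q3_K", 512), ("Q6_K", 512), ("IQ1_M", 256)]:
    print(f"{quant:5s} x {rows:3d} rows -> {gemm_path(quant, rows)}")
```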