[Fix] use torch.inference_mode() instead of torch.no_grad() #4372
Conversation
@Alcanderian Some unit tests failed, could you help fix that? Thanks!

Refer to the error in

It seems that I am going to fix it by reintroducing

It looks like most of the issues have been resolved, but there are still some accuracy issues. How should we approach this situation? @zhyncs

Further work: Create a

gentle ping @Alcanderian three gold bro, let's go :)
```diff
@@ -127,6 +128,63 @@ def is_cuda_available():
     return is_cuda()


+_ENABLE_TORCH_INFERENCE_MODE = os.getenv(
+    "SGLANG_ENABLE_TORCH_INFERENCE_MODE", "false"
+).lower() in ("true", "1")
```
Use this function instead (`sglang/python/sglang/srt/utils.py`, line 1330 in 45212ce):

`def get_bool_env_var(name: str, default: str = "false") -> bool:`
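The reviewer's suggestion can be sketched as follows. The body of `get_bool_env_var` here is reconstructed from its signature and the inline parsing in the diff, so treat it as an approximation of the actual helper in `sglang/srt/utils.py`:

```python
import os


def get_bool_env_var(name: str, default: str = "false") -> bool:
    # Parse a boolean-style environment variable; "true" and "1"
    # (case-insensitive) are truthy, matching the inline parsing in the diff.
    return os.getenv(name, default).lower() in ("true", "1")


# The diff's inline parsing would then collapse to:
_ENABLE_TORCH_INFERENCE_MODE = get_bool_env_var("SGLANG_ENABLE_TORCH_INFERENCE_MODE")
```

Centralizing the parsing keeps all boolean flags consistent: any new flag automatically accepts the same truthy spellings.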
Motivation
Fix issue #4366
Modifications
Replace all `torch.no_grad()` with `torch.inference_mode()`.
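A minimal illustration of the replacement, assuming a recent PyTorch (1.9+). Both contexts disable gradient recording, but `inference_mode()` additionally skips view tracking and version-counter bookkeeping, which is where any small speedup would come from; the trade-off is that tensors created inside it can never be used in autograd later:

```python
import torch

x = torch.ones(2, 2, requires_grad=True)

# Before: no_grad() disables gradient recording only.
with torch.no_grad():
    y_no_grad = x * 2

# After: inference_mode() also skips view/version-counter bookkeeping.
with torch.inference_mode():
    y_inf = x * 2

print(y_no_grad.requires_grad)    # False
print(torch.is_inference(y_inf))  # True: y_inf is an inference tensor
```

Attempting to use `y_inf` in a later autograd computation raises a RuntimeError, which is why this PR gates the change behind an environment variable rather than switching unconditionally.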
Benchmark on H100
Conclusion: there is basically no difference in performance. The slight apparent improvement could be due to run-to-run fluctuation.
With CUDA Graph

```shell
python3 -m sglang.bench_one_batch --model-path Qwen/Qwen2.5-7B-Instruct --batch 32 --input-len 256 --output-len 32
```
Without CUDA Graph

```shell
python3 -m sglang.bench_one_batch --model-path Qwen/Qwen2.5-7B-Instruct --batch 32 --input-len 256 --output-len 32 --disable-cuda-graph
```
Checklist