Description
Hi team,
When I try to run GRPO on a server equipped with a Tesla V100-SXM2 (32 GB, compute capability 7.0), I hit the following exception from vLLM:

```
ValueError: Bfloat16 is only supported on GPUs with compute capability of at least 8.0.
Your Tesla V100-SXM2-32GB GPU has compute capability 7.0. You can use float16 instead
by explicitly setting the `dtype` flag in CLI, for example: --dtype=half.
```
Root cause
nemo_reinforcer/models/generation/vllm.py (around line 184) currently instantiates the model with `dtype="auto"`, which vLLM resolves to the checkpoint's dtype (bfloat16 for most recent models). Because V100 GPUs (compute capability 7.0) do not support bfloat16, the run aborts before any computation starts.
Work-around that solves the issue
Manually patching the constructor as follows allows the model to run on a V100:
```python
self.llm = vllm.LLM(
    model=self.model_name,
    load_format=self.cfg["vllm_cfg"]["load_format"],
    skip_tokenizer_init=self.cfg["vllm_cfg"]["skip_tokenizer_init"],
    tensor_parallel_size=self.cfg["vllm_cfg"]["tensor_parallel_size"],
    gpu_memory_utilization=self.cfg["vllm_cfg"]["gpu_memory_utilization"],
    # --- changes below ---
    dtype="float16",               # instead of the default "auto"
    enable_chunked_prefill=False,  # added
    disable_sliding_window=True,   # added
    enable_prefix_caching=False,   # added
    # ----------------------
    seed=seed,
    enforce_eager=True,  # CUDA graphs cause convergence issues (see #186)
    max_model_len=self.cfg["vllm_cfg"]["max_model_len"],
    trust_remote_code=True,
    worker_extension_cls=(
        "nemo_reinforcer.models.generation.vllm_backend.VllmInternalWorkerExtension"
    ),
    enable_sleep_mode=True,
    disable_log_stats=True,
    **vllm_kwargs,
)
```
Proposal
It would be great if these parameters were exposed through vllm_cfg so users can switch precision (and the related flags) from their YAML without editing source code. For example:
```yaml
vllm_cfg:
  load_format: auto
  tensor_parallel_size: 1
  gpu_memory_utilization: 0.90
  max_model_len: 8192
  # new additions
  dtype: half  # or float16
  enable_chunked_prefill: false
  disable_sliding_window: true
  enable_prefix_caching: false
```
The constructor would then read these values from the config:

```python
self.llm = vllm.LLM(
    model=self.model_name,
    load_format=self.cfg["vllm_cfg"]["load_format"],
    skip_tokenizer_init=self.cfg["vllm_cfg"]["skip_tokenizer_init"],
    tensor_parallel_size=self.cfg["vllm_cfg"]["tensor_parallel_size"],
    gpu_memory_utilization=self.cfg["vllm_cfg"]["gpu_memory_utilization"],
    # --- changes below ---
    dtype=self.cfg["vllm_cfg"]["dtype"],
    enable_chunked_prefill=self.cfg["vllm_cfg"]["enable_chunked_prefill"],
    disable_sliding_window=self.cfg["vllm_cfg"]["disable_sliding_window"],
    enable_prefix_caching=self.cfg["vllm_cfg"]["enable_prefix_caching"],
    # ----------------------
    ...
)
```
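To keep existing YAML files working, the new keys could be read with `dict.get` and defaults that preserve today's behavior. A minimal sketch; `read_generation_flags` is a hypothetical helper, and the default values here are illustrative rather than vLLM's actual defaults:

```python
# Sketch: pull the proposed keys out of vllm_cfg with backward-compatible
# defaults, so older YAML files that lack them behave exactly as before.
def read_generation_flags(vllm_cfg: dict) -> dict:
    return {
        "dtype": vllm_cfg.get("dtype", "auto"),
        "enable_chunked_prefill": vllm_cfg.get("enable_chunked_prefill", True),
        "disable_sliding_window": vllm_cfg.get("disable_sliding_window", False),
        "enable_prefix_caching": vllm_cfg.get("enable_prefix_caching", True),
    }

# A V100 config only overrides what it needs:
flags = read_generation_flags({"dtype": "half", "enable_prefix_caching": False})
# flags["dtype"] == "half"; unspecified keys keep their defaults
```

The returned dict could then be splatted into the `vllm.LLM(...)` call alongside the existing arguments.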
Thanks for the great work on Reinforcer! Let me know if I can provide any additional logs or help test a patch.
Even better, the code could automatically fall back to float16 when it detects a GPU with compute capability < 8.0 to provide a smoother out-of-the-box experience.
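Such a fallback could be a small helper built on `torch.cuda.get_device_capability`. A hypothetical sketch; `resolve_dtype` and `detect_dtype` are illustrative names, not existing Reinforcer functions:

```python
def resolve_dtype(requested: str, cc_major: int) -> str:
    """Pick a vLLM dtype given the GPU's compute-capability major version.

    bfloat16 requires compute capability >= 8.0 (Ampere or newer), so on
    older GPUs an "auto" request falls back to float16.
    """
    if requested != "auto":
        return requested  # respect an explicit user choice
    return "auto" if cc_major >= 8 else "float16"


def detect_dtype(requested: str = "auto") -> str:
    """Query the current GPU and resolve the dtype; requires CUDA-enabled torch."""
    import torch  # local import keeps resolve_dtype dependency-free

    major, _minor = torch.cuda.get_device_capability()
    return resolve_dtype(requested, major)
```

On a V100 (capability 7.0), `detect_dtype()` would return `"float16"`, while an explicit `dtype: half` in the YAML would pass through unchanged.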