2 changes: 1 addition & 1 deletion recipe/prime/prime_ray_trainer.py
@@ -169,7 +169,7 @@ def _validate_config(self):
super()._validate_config()
# TODO: Additional config checks can be added here

def _create_dataloader(self):
def _create_dataloader(self, *args, **kwargs):
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

# TODO: we have to make sure the batch size is divisible by the dp size
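The widened signature exists because the base `RayPPOTrainer.__init__` (see `ray_trainer.py` below) now calls `self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)`; accepting `*args, **kwargs` keeps the PRIME recipe's override call-compatible without forcing it to consume the injected objects. A minimal sketch of that pattern, with illustrative class names rather than the actual verl classes:

```python
class BaseTrainer:
    def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler):
        print("base trainer wires up the injected dataset/sampler/collate_fn")


class RecipeTrainer(BaseTrainer):
    def _create_dataloader(self, *args, **kwargs):
        # Accept (and here simply ignore) whatever the new base signature passes,
        # so the call in the base __init__ keeps dispatching to this override.
        print("recipe builds its own dataloaders")


RecipeTrainer()._create_dataloader(None, None, None, None)  # no TypeError
```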
12 changes: 11 additions & 1 deletion tests/e2e/sft/test_sp_loss_match.py
@@ -103,7 +103,17 @@ def create_trainer(config):
dp_size = world_size // config.ulysses_sequence_parallel_size
ulysses_device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(dp_size, config.ulysses_sequence_parallel_size), mesh_dim_names=("dp", "sp"))

return FSDPSFTTrainer(config=config, device_mesh=device_mesh, ulysses_device_mesh=ulysses_device_mesh)
# build tokenizer and datasets first
from verl.trainer.fsdp_sft_trainer import create_sft_dataset
from verl.utils import hf_tokenizer
from verl.utils.fs import copy_to_local

local_model_path = copy_to_local(src=config.model.partial_pretrain, verbose=True)
tokenizer = hf_tokenizer(local_model_path, trust_remote_code=config.model.trust_remote_code)
train_dataset = create_sft_dataset(config.data.train_files, config.data, tokenizer)
val_dataset = create_sft_dataset(config.data.val_files, config.data, tokenizer)

return FSDPSFTTrainer(config=config, device_mesh=device_mesh, ulysses_device_mesh=ulysses_device_mesh, tokenizer=tokenizer, train_dataset=train_dataset, val_dataset=val_dataset)


def main(config):
63 changes: 37 additions & 26 deletions verl/trainer/fsdp_sft_trainer.py
@@ -37,7 +37,7 @@
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.fsdp import CPUOffload, MixedPrecision, ShardingStrategy
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel

@@ -82,16 +82,12 @@ def convert_to_regular_types(obj):


class FSDPSFTTrainer:
def __init__(self, config, device_mesh: DeviceMesh, ulysses_device_mesh: DeviceMesh):
def __init__(self, config, device_mesh: DeviceMesh, ulysses_device_mesh: DeviceMesh, tokenizer, train_dataset: Dataset, val_dataset: Dataset):
self.config = config
self.device_mesh = device_mesh
self.ulysses_device_mesh = ulysses_device_mesh
self.sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)
# build tokenizer first
local_model_path = copy_to_local(src=self.config.model.partial_pretrain, verbose=True)
from verl.utils import hf_tokenizer

self.tokenizer = hf_tokenizer(local_model_path, trust_remote_code=self.config.model.trust_remote_code)
self.tokenizer = tokenizer
if self.config.data.chat_template is not None:
raise ValueError("Apply Chat template from config is not supported yet.")

@@ -105,7 +101,7 @@ def __init__(self, config, device_mesh: DeviceMesh, ulysses_device_mesh: DeviceMesh):
print(f"Using sequence parallel size: {self.config.ulysses_sequence_parallel_size}")
print(f"Using remove padding: {self.use_remove_padding}")

self._build_dataloader()
self._build_dataloader(train_dataset, val_dataset)
# build model
self._build_model_optimizer()

@@ -124,24 +120,10 @@ def _normalize_config_bsz(self):

assert self.config.data.train_batch_size % self.config.data.micro_batch_size_per_gpu == 0

def _build_dataloader(self):
config = self.config
def _build_dataloader(self, train_dataset, val_dataset):
# build dataset
from verl.utils.import_utils import load_extern_type

# First check if a custom dataset class is specified
if config.data.custom_cls.get("path", None):
dataset_cls = load_extern_type(config.data.custom_cls.path, config.data.custom_cls.name)
# Then check if multi-turn dataset should be used
elif config.data.get("multiturn", {}).get("enable", False):
dataset_cls = MultiTurnSFTDataset
# Default to single-turn dataset
else:
dataset_cls = SFTDataset

# Create datasets based on the selected class
self.train_dataset = dataset_cls(parquet_files=config.data.train_files, tokenizer=self.tokenizer, config=config.data)
self.val_dataset = dataset_cls(parquet_files=config.data.val_files, tokenizer=self.tokenizer, config=config.data)
config = self.config
self.train_dataset, self.val_dataset = train_dataset, val_dataset

# build dataloader
# Use data parallel rank and size instead of global rank and world size
@@ -525,9 +507,38 @@ def main(config):
device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,), mesh_dim_names=("fsdp",))
dp_size = world_size // config.ulysses_sequence_parallel_size
ulysses_device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(dp_size, config.ulysses_sequence_parallel_size), mesh_dim_names=("dp", "sp"))
trainer = FSDPSFTTrainer(config=config, device_mesh=device_mesh, ulysses_device_mesh=ulysses_device_mesh)
# build tokenizer and datasets first
from verl.utils import hf_tokenizer

local_model_path = copy_to_local(src=config.model.partial_pretrain, verbose=True)
tokenizer = hf_tokenizer(local_model_path, trust_remote_code=config.model.trust_remote_code)
train_dataset = create_sft_dataset(config.data.train_files, config.data, tokenizer)
val_dataset = create_sft_dataset(config.data.val_files, config.data, tokenizer)

trainer = FSDPSFTTrainer(config=config, device_mesh=device_mesh, ulysses_device_mesh=ulysses_device_mesh, tokenizer=tokenizer, train_dataset=train_dataset, val_dataset=val_dataset)

trainer.fit()


def create_sft_dataset(data_paths, data_config, tokenizer):
"""Create a dataset."""
# build dataset
# First check if a custom dataset class is specified
if data_config.custom_cls.get("path", None):
from verl.utils.import_utils import load_extern_type

dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name)
# Then check if multi-turn dataset should be used
elif data_config.get("multiturn", {}).get("enable", False):
dataset_cls = MultiTurnSFTDataset
# Default to single-turn dataset
else:
dataset_cls = SFTDataset

# Create datasets based on the selected class
dataset = dataset_cls(parquet_files=data_paths, tokenizer=tokenizer, config=data_config)
return dataset


if __name__ == "__main__":
main()
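Since `create_sft_dataset` instantiates whichever class it selects as `dataset_cls(parquet_files=..., tokenizer=..., config=...)`, a custom class referenced through `data.custom_cls.path` / `data.custom_cls.name` only has to honor that constructor. Below is a hedged sketch of such a class; the module path, class name, and the config field read in `__init__` are hypothetical. Note that, unlike the RL factory added in `main_ppo.py`, this SFT path does not verify that the loaded class subclasses `torch.utils.data.Dataset`.

```python
# my_project/sft_data.py -- hypothetical module referenced by data.custom_cls
from torch.utils.data import Dataset


class MyParquetSFTDataset(Dataset):
    """Custom SFT dataset matching the keyword arguments create_sft_dataset passes."""

    def __init__(self, parquet_files, tokenizer, config):
        self.parquet_files = parquet_files
        self.tokenizer = tokenizer
        self.max_length = config.get("max_length", 1024)  # illustrative config field
        self.samples = []  # load and tokenize the parquet files here

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]
```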
68 changes: 68 additions & 0 deletions verl/trainer/main_ppo.py
@@ -163,6 +163,11 @@ def run(self, config):
val_reward_fn = load_reward_manager(config, tokenizer, num_examine=1)
resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)

from verl.utils.dataset.rl_dataset import collate_fn

train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor)
val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor)
train_sampler = create_rl_sampler(config.data, train_dataset)
trainer = RayPPOTrainer(
config=config,
tokenizer=tokenizer,
@@ -172,10 +177,73 @@
ray_worker_group_cls=ray_worker_group_cls,
reward_fn=reward_fn,
val_reward_fn=val_reward_fn,
train_dataset=train_dataset,
val_dataset=val_dataset,
collate_fn=collate_fn,
train_sampler=train_sampler,
)
trainer.init_workers()
trainer.fit()


def create_rl_dataset(data_paths, data_config, tokenizer, processor):
"""Create a dataset.

Arguments:
data_config: The data config.
tokenizer (Tokenizer): The tokenizer.
processor (Processor): The processor.

Returns:
dataset (Dataset): The dataset.
"""
from torch.utils.data import Dataset

from verl.utils.dataset.rl_dataset import RLHFDataset

if "custom_cls" in data_config and data_config.custom_cls.get("path", None) is not None:
from verl.utils.import_utils import load_extern_type

dataset_cls = load_extern_type(data_config.custom_cls.path, data_config.custom_cls.name)
if not issubclass(dataset_cls, Dataset):
raise TypeError(f"The custom dataset class '{data_config.custom_cls.name}' from '{data_config.custom_cls.path}' must inherit from torch.utils.data.Dataset")
else:
dataset_cls = RLHFDataset
print(f"Using dataset class: {dataset_cls.__name__}")

dataset = dataset_cls(
data_files=data_paths,
tokenizer=tokenizer,
processor=processor,
config=data_config,
)

return dataset
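
`create_rl_dataset` imposes a slightly different contract on custom classes than the SFT factory: the class must subclass `torch.utils.data.Dataset` (otherwise a `TypeError` is raised) and it is constructed with `data_files=` and `processor=` keywords instead of `parquet_files=`. A hedged sketch of the minimal shape such a class needs; the module path, class name, and loading logic are hypothetical:

```python
# my_project/rl_data.py -- hypothetical module referenced by data.custom_cls
from torch.utils.data import Dataset


class MyRLDataset(Dataset):
    """Custom RL dataset matching the keyword arguments create_rl_dataset passes."""

    def __init__(self, data_files, tokenizer, processor, config):
        self.data_files = data_files
        self.tokenizer = tokenizer
        self.processor = processor  # e.g. a multimodal processor, may be None
        self.rows = []  # parse data_files into prompt records here

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        return self.rows[idx]
```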


def create_rl_sampler(data_config, dataset):
"""Create a sampler for the dataset.

Arguments:
data_config: The data config.
dataset (Dataset): The dataset.

Returns:
sampler (Sampler): The sampler.
"""
import torch
from torch.utils.data import RandomSampler, SequentialSampler

# use sampler for better ckpt resume
if data_config.shuffle:
train_dataloader_generator = torch.Generator()
train_dataloader_generator.manual_seed(data_config.get("seed", 1))
sampler = RandomSampler(data_source=dataset, generator=train_dataloader_generator)
else:
sampler = SequentialSampler(data_source=dataset)

return sampler


if __name__ == "__main__":
main()
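The `# use sampler for better ckpt resume` comment is about determinism: `create_rl_sampler` seeds a dedicated `torch.Generator`, so a restarted run reproduces the same shuffle order that the checkpointed dataloader state expects. A small self-contained check of that property (assuming a working verl install; the toy dataset and config values are placeholders):

```python
import torch
from omegaconf import OmegaConf
from torch.utils.data import TensorDataset

from verl.trainer.main_ppo import create_rl_sampler

toy_dataset = TensorDataset(torch.arange(8))
data_cfg = OmegaConf.create({"shuffle": True, "seed": 1})

# Two samplers built from the same config walk the dataset in the same order,
# which is what makes resuming a StatefulDataLoader from a checkpoint coherent.
order_a = list(create_rl_sampler(data_cfg, toy_dataset))
order_b = list(create_rl_sampler(data_cfg, toy_dataset))
assert order_a == order_b
```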
65 changes: 24 additions & 41 deletions verl/trainer/ppo/ray_trainer.py
@@ -27,14 +27,14 @@
from dataclasses import dataclass, field
from enum import Enum
from pprint import pprint
from typing import Dict, Type
from typing import Dict, Optional, Type

import numpy as np
import ray
import torch
from codetiming import Timer
from omegaconf import OmegaConf, open_dict
from torch.utils.data import Dataset, RandomSampler, SequentialSampler
from torch.utils.data import Dataset, Sampler
from torchdata.stateful_dataloader import StatefulDataLoader
from tqdm import tqdm

@@ -54,7 +54,6 @@
)
from verl.trainer.ppo.reward import compute_reward, compute_reward_async
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance
from verl.utils.torch_functional import masked_mean
from verl.utils.tracking import ValidationGenerationsLogger
@@ -279,6 +278,10 @@ def __init__(
processor=None,
reward_fn=None,
val_reward_fn=None,
train_dataset: Optional[Dataset] = None,
val_dataset: Optional[Dataset] = None,
collate_fn=None,
train_sampler: Optional[Sampler] = None,
):
# assert torch.cuda.is_available(), 'cuda must be available on driver'

@@ -320,7 +323,7 @@ def __init__(
raise NotImplementedError

self._validate_config()
self._create_dataloader()
self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)

def _validate_config(self):
config = self.config
@@ -435,53 +438,33 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str):

print("[validate_config] All configuration checks passed successfully!")

def _create_dataloader(self):
def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler):
"""
Creates the train and validation dataloaders.
"""
# make sure the batch size is divisible by the dp size
from verl.utils.import_utils import load_extern_type

if "custom_cls" in self.config.data and self.config.data.custom_cls.get("path", None) is not None:
# Dynamically load the custom dataset class specified in config
try:
dataset_cls = load_extern_type(self.config.data.custom_cls.path, self.config.data.custom_cls.name)
if not issubclass(dataset_cls, Dataset):
raise TypeError(f"The custom dataset class '{self.config.data.custom_cls.name}' from '{self.config.data.custom_cls.path}' must inherit from torch.utils.data.Dataset")
print(f"Using custom dataset class: {dataset_cls.__name__}")
except Exception as e:
print(f"Error loading custom dataset class: {e}")
raise e
else:
dataset_cls = RLHFDataset
print(f"Using default dataset class: {dataset_cls.__name__}")
self.train_dataset = dataset_cls(
data_files=self.config.data.train_files,
tokenizer=self.tokenizer,
processor=self.processor,
config=self.config.data,
)
if self.config.data.shuffle:
train_dataloader_generator = torch.Generator()
train_dataloader_generator.manual_seed(self.config.data.get("seed", 1))
sampler = RandomSampler(data_source=self.train_dataset, generator=train_dataloader_generator)
else:
sampler = SequentialSampler(data_source=self.train_dataset)
# TODO: we have to make sure the batch size is divisible by the dp size
from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler

if train_dataset is None:
train_dataset = create_rl_dataset(self.config.data.train_files, self.config.data, self.tokenizer, self.processor)
if val_dataset is None:
val_dataset = create_rl_dataset(self.config.data.val_files, self.config.data, self.tokenizer, self.processor)
self.train_dataset, self.val_dataset = train_dataset, val_dataset

if train_sampler is None:
train_sampler = create_rl_sampler(self.config.data, self.train_dataset)
if collate_fn is None:
from verl.utils.dataset.rl_dataset import collate_fn as default_collate_fn

collate_fn = default_collate_fn

self.train_dataloader = StatefulDataLoader(
dataset=self.train_dataset,
batch_size=self.config.data.get("gen_batch_size", self.config.data.train_batch_size),
num_workers=self.config.data.get("dataloader_num_workers", 8),
drop_last=True,
collate_fn=collate_fn,
sampler=sampler,
)

self.val_dataset = dataset_cls(
data_files=self.config.data.val_files,
tokenizer=self.tokenizer,
processor=self.processor,
config=self.config.data,
sampler=train_sampler,
)

val_batch_size = self.config.data.val_batch_size # Prefer config value if set
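The net effect on `RayPPOTrainer` is constructor-level dependency injection with backward-compatible defaults: any of `train_dataset`, `val_dataset`, `collate_fn`, `train_sampler` left as `None` is built inside `_create_dataloader` via the factories in `main_ppo.py`. A toy stand-in (not verl's actual class) showing the same None-fallback wiring:

```python
import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset, default_collate


class ToyTrainer:
    """Illustrative stand-in for RayPPOTrainer's new data wiring, not the real class."""

    def __init__(self, train_dataset=None, collate_fn=None, train_sampler=None):
        # Mirror _create_dataloader: every dependency that is not injected is
        # produced by a default factory, so existing call sites keep working.
        if train_dataset is None:
            train_dataset = TensorDataset(torch.arange(8))  # default-factory stand-in
        if train_sampler is None:
            train_sampler = SequentialSampler(train_dataset)
        if collate_fn is None:
            collate_fn = default_collate
        self.train_dataloader = DataLoader(
            train_dataset, batch_size=4, sampler=train_sampler, collate_fn=collate_fn
        )


ToyTrainer()                                               # all defaults
ToyTrainer(train_dataset=TensorDataset(torch.arange(16)))  # inject only the dataset
```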
1 change: 1 addition & 0 deletions verl/utils/dataset/rl_dataset.py
@@ -35,6 +35,7 @@


def collate_fn(data_list: list[dict]) -> dict:
"""Collate a batch of data."""
tensors = defaultdict(list)
non_tensors = defaultdict(list)

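The diff shows only the opening of `collate_fn` (the `tensors` / `non_tensors` split). For reference, here is a sketch consistent with that split and with how the trainers consume batches; the remainder of the body is not visible in this diff, so treat the details as an assumption: tensor fields are stacked along a new batch dimension, and everything else is gathered into numpy object arrays so ragged values survive batching.

```python
from collections import defaultdict

import numpy as np
import torch


def collate_fn(data_list: list[dict]) -> dict:
    """Collate a batch of per-sample dicts into a single batch dict."""
    tensors = defaultdict(list)
    non_tensors = defaultdict(list)

    # Split each sample's fields by type.
    for sample in data_list:
        for key, value in sample.items():
            if isinstance(value, torch.Tensor):
                tensors[key].append(value)
            else:
                non_tensors[key].append(value)

    # Tensors are stacked; non-tensor fields become object arrays so strings
    # and variable-length lists can be batched without padding.
    batch = {key: torch.stack(value, dim=0) for key, value in tensors.items()}
    batch.update({key: np.array(value, dtype=object) for key, value in non_tensors.items()})
    return batch
```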