Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions recipe/char_count/READMD.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Char Count
## Introduction
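Char count is a simple NLP task, intended to help beginners grasp the idea of RLVR (reinforcement learning with verifiable rewards). The task can be trained with a tiny model (e.g., https://huggingface.co/HuggingFaceTB/SmolLM2-135M; the provided scripts use SmolLM2-135M-Instruct, loaded from `~/models/SmolLM2-135M-Instruct`) on a consumer GPU with only 8GB of memory.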

## Problem formulation
The prompt is: "How many {char} are there in word {word}?", where {word} is written with a dash between consecutive characters. To help the LLM answer this question, we create an SFT dataset with intermediate steps. For example,

```text
Question: How many n are there in word n-i-n-e?
Answer:
n = n.
i != n.
n = n.
e != n.
\boxed{2} n in n-i-n-e.
```

Note that
- We insert a dash between consecutive characters to make the task easier: with the separators, most tokenizers map each character to its own token regardless of its neighbors.
- In the SFT dataset, the CoT lists every individual char and whether it equals the target. At the end, it gives the final answer inside `\boxed{}`.
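- The answer can be verified automatically: the reward function compares the number in the last `\boxed{}` of the response with the ground truth.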
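- The word is usually not meaningful: each char is sampled uniformly from a to z. The word length is sampled uniformly from [min_length, max_length], and the answer (the number of occurrences of the target char) is sampled uniformly from 1 to the word length, so the answer is never 0.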

## Scripts
To create the dataset, run
```bash
python3 create_dataset.py
```
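We create a train set and a validation set (a 90/10 split, saved as `train.parquet` and `test.parquet` under `~/data/char_count/sft` and `~/data/char_count/rl`). Both are used for SFT and RL. You can specify the total number of examples (`--total_number`), the min/max word length (`--min_length`, `--max_length`) and the output path (`--data_path`).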

To run the SFT
```bash
bash train_sft.sh
```
We train SFT for 3 epochs. After 3 epochs, the validation score is around 0.12.

To run GRPO
```bash
bash train_grpo.sh
```
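We train GRPO for 2 epochs, starting from the SFT checkpoint `./models/sft/global_step_105` (a path relative to the working directory, so run both scripts from the same directory). After 2 epochs, the validation score is around 0.36.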
187 changes: 187 additions & 0 deletions recipe/char_count/create_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Task description:
Given a random word and a random target char, count the number of occurrences of the char in the word.

Create a CoT dataset that splits the word into separate chars, then lists every char together
with whether it equals the target, and finally gives the count inside \\boxed{}.

The words are not taken from a vocabulary: every char is sampled uniformly from a-z, so most
words are meaningless letter sequences. Chars are joined with dashes in the prompt.
"""

import os.path
import random

prompt_template = "How many {} are there in word {}?"


def generate_random_char():
    """Return a lowercase ASCII letter (``a``-``z``) chosen uniformly at random."""
    return chr(97 + random.randint(0, 25))


def create_prompt_response(min_length=3, max_length=5):
    """Create one ``(prompt, response)`` pair for the char-count task.

    The word length is drawn uniformly from ``[min_length, max_length]`` and the
    number of occurrences of the target char is drawn uniformly from
    ``[1, word_length]``. The remaining positions are filled with letters that
    differ from the target, and the chars are shuffled.

    Args:
        min_length: Minimum word length (inclusive). Must be at least 1.
        max_length: Maximum word length (inclusive). Must be >= ``min_length``.

    Returns:
        A ``(prompt, response)`` tuple. ``prompt`` asks how many times the target
        char occurs in the dash-separated word. ``response`` has one line per char
        (``"<char> = <target>."`` or ``"<char> != <target>."``) followed by
        ``"\\boxed{<count>} <target> in <word>."``.

    Raises:
        ValueError: If ``min_length < 1`` or ``min_length > max_length``.
    """
    # A zero-length word would make randint(1, 0) fail only on some draws; reject it up front.
    if min_length < 1:
        raise ValueError(f"min_length must be >= 1, got {min_length}")
    if min_length > max_length:
        raise ValueError(f"min_length ({min_length}) must not exceed max_length ({max_length})")

    word_length = random.randint(min_length, max_length)
    # Sample the answer directly (rather than letting it follow from random letters)
    # so that every count in [1, word_length] is equally likely.
    target_count = random.randint(1, word_length)
    target_char = generate_random_char()

    # Step 1: place the target char the sampled number of times.
    chars = [target_char] * target_count
    # Step 2: fill the remaining positions with letters different from the target
    # (rejection sampling keeps them uniform over the other 25 letters).
    for _ in range(word_length - target_count):
        char = generate_random_char()
        while char == target_char:
            char = generate_random_char()
        chars.append(char)
    # Step 3: shuffle so the target occurrences land at random positions.
    random.shuffle(chars)

    word = "-".join(chars)
    prompt = prompt_template.format(target_char, word)

    # Chain of thought: one comparison line per char, then the boxed answer.
    lines = [f"{char} {'=' if char == target_char else '!='} {target_char}." for char in chars]
    lines.append(f"\\boxed{{{target_count}}} {target_char} in {word}.")
    return prompt, "\n".join(lines)


if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--total_number", type=int, default=10000)
parser.add_argument("--min_length", type=int, default=5)
parser.add_argument("--max_length", type=int, default=20)
parser.add_argument("--data_path", type=str, default="~/data/char_count")

args = vars(parser.parse_args())

total_number = args["total_number"]
min_length = args["min_length"]
max_length = args["max_length"]
data_path = args["data_path"]
data_path = os.path.expanduser(data_path)

full_output = []
for _ in range(total_number):
output = create_prompt_response(min_length=min_length, max_length=max_length)
full_output.append(output)

# random reorder
random.shuffle(full_output)

# split for train and test
train_split_len = int(0.9 * len(full_output))
train_outputs = full_output[:train_split_len]
test_output = full_output[train_split_len:]

sft_train_dataset = {"prompt": [], "response": []}

for o in train_outputs:
sft_train_dataset["prompt"].append(o[0])
sft_train_dataset["response"].append(o[1])

sft_test_dataset = {"prompt": [], "response": []}

for o in test_output:
sft_test_dataset["prompt"].append(o[0])
sft_test_dataset["response"].append(o[1])

import pandas as pd

sft_train_dataset = pd.DataFrame(data=sft_train_dataset)
sft_test_dataset = pd.DataFrame(data=sft_test_dataset)

folder = os.path.join(data_path, "sft")

os.makedirs(folder, exist_ok=True)

sft_train_dataset.to_parquet(os.path.join(folder, "train.parquet"))
sft_test_dataset.to_parquet(os.path.join(folder, "test.parquet"))

# build RL dataset
rl_train_dataset = {"prompt": [], "data_source": [], "ability": [], "reward_model": [], "extra_info": []}

rl_test_dataset = {"prompt": [], "data_source": [], "ability": [], "reward_model": [], "extra_info": []}

from verl.utils.reward_score.math import last_boxed_only_string, remove_boxed

for o in train_outputs:
prompt = o[0]
response = o[1]
prompt_with_template = [
{
"role": "user",
"content": prompt,
}
]

rl_train_dataset["prompt"].append(prompt_with_template)
rl_train_dataset["data_source"].append("char_count")
rl_train_dataset["ability"].append("other")
rl_train_dataset["reward_model"].append({"style": "rule", "ground_truth": remove_boxed(last_boxed_only_string(response))})
rl_train_dataset["extra_info"].append({"response": response})

for o in test_output:
prompt = o[0]
response = o[1]
prompt_with_template = [
{
"role": "user",
"content": prompt,
}
]

rl_test_dataset["prompt"].append(prompt_with_template)
rl_test_dataset["data_source"].append("char_count")
rl_test_dataset["ability"].append("other")
rl_test_dataset["reward_model"].append({"style": "rule", "ground_truth": remove_boxed(last_boxed_only_string(response))})
rl_test_dataset["extra_info"].append({"response": response})

rl_train_dataset = pd.DataFrame(data=rl_train_dataset)
rl_test_dataset = pd.DataFrame(data=rl_test_dataset)

folder = os.path.join(data_path, "rl")

os.makedirs(folder, exist_ok=True)

rl_train_dataset.to_parquet(os.path.join(folder, "train.parquet"))
rl_test_dataset.to_parquet(os.path.join(folder, "test.parquet"))
34 changes: 34 additions & 0 deletions recipe/char_count/reward_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Reward function
"""

from verl.utils.reward_score import math


def char_count_reward_function(data_source, solution_str, ground_truth, extra_info=None):
    """Score a char-count rollout: 1 if its last ``\\boxed{...}`` matches ``ground_truth``, else 0.

    Args:
        data_source: Dataset tag of the example (unused here).
        solution_str: Decoded model response.
        ground_truth: Expected count; compared as a string with surrounding whitespace removed.
        extra_info: Unused here.

    Returns:
        ``1`` for a correct answer, ``0`` otherwise, including when the response has no
        boxed answer or the boxed expression cannot be parsed.
    """
    try:
        last_boxed_string = math.last_boxed_only_string(solution_str)
        if last_boxed_string is None:
            return 0
        solution = math.remove_boxed(last_boxed_string)
    except Exception:
        # Best effort: a malformed response scores 0 instead of propagating the exception.
        print(ground_truth, solution_str)
        return 0
    # Ignore incidental whitespace such as "\boxed{ 2 }"; str() also tolerates a numeric ground truth.
    return 1 if solution.strip() == str(ground_truth).strip() else 0
44 changes: 44 additions & 0 deletions recipe/char_count/train_grpo.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#!/usr/bin/env bash
set -x

# Directory containing this script, so the reward function is found regardless of where the repo is checked out.
SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
# Starting checkpoint; the default is relative to the working directory (train_sft.sh saves under ./models/sft).
SFT_MODEL_PATH=${SFT_MODEL_PATH:-./models/sft/global_step_105}

#export VLLM_ATTENTION_BACKEND=XFORMERS

# Any extra command-line arguments are forwarded as additional config overrides.
python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=grpo \
    data.train_files="$HOME/data/char_count/rl/train.parquet" \
    data.val_files="$HOME/data/char_count/rl/test.parquet" \
    data.train_batch_size=128 \
    data.max_prompt_length=128 \
    data.max_response_length=128 \
    data.filter_overlong_prompts=False \
    data.truncation='error' \
    actor_rollout_ref.model.path="$SFT_MODEL_PATH" \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=16 \
    actor_rollout_ref.actor.use_dynamic_bsz=True \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=5000 \
    actor_rollout_ref.actor.use_kl_loss=False \
    actor_rollout_ref.actor.kl_loss_coef=0.0 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=True \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \
    actor_rollout_ref.rollout.n=8 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger=['console','tensorboard'] \
    trainer.project_name='verl_example' \
    trainer.experiment_name='smol135m_grpo' \
    trainer.val_before_train=True \
    trainer.n_gpus_per_node=1 \
    trainer.nnodes=1 \
    trainer.save_freq=-1 \
    trainer.test_freq=5 \
    trainer.total_epochs=2 \
    custom_reward_function.path="$SCRIPT_DIR/reward_function.py" \
    custom_reward_function.name=char_count_reward_function \
    "$@"
22 changes: 22 additions & 0 deletions recipe/char_count/train_sft.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#!/usr/bin/env bash
set -x

# GPUs to use on this node and the checkpoint directory; both can be overridden from the environment.
# train_grpo.sh expects its starting checkpoint under ./models/sft by default.
nproc_per_node=${NPROC_PER_NODE:-1}
save_path=${SAVE_PATH:-./models/sft}

# Any extra command-line arguments are forwarded as additional config overrides.
torchrun --standalone --nnodes=1 --nproc_per_node="$nproc_per_node" \
    -m verl.trainer.fsdp_sft_trainer \
    data.train_files="$HOME/data/char_count/sft/train.parquet" \
    data.val_files="$HOME/data/char_count/sft/test.parquet" \
    data.prompt_key=prompt \
    data.response_key=response \
    data.micro_batch_size_per_gpu=8 \
    data.max_length=256 \
    data.train_batch_size=256 \
    use_remove_padding=True \
    model.partial_pretrain="$HOME/models/SmolLM2-135M-Instruct" \
    trainer.default_local_dir="$save_path" \
    trainer.project_name=char_count-sft \
    trainer.experiment_name=char_count-sft-SmolLM2-135M-Instruct \
    trainer.total_epochs=3 \
    trainer.logger=['console'] \
    trainer.default_hdfs_dir=null \
    "$@"
4 changes: 2 additions & 2 deletions verl/trainer/config/sft_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ data:
# Single-turn settings
prompt_key: question
response_key: answer
prompt_dict_keys: ['question']
response_dict_keys: ['answer']
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this intentional? this would be a breaking change

prompt_dict_keys: null
response_dict_keys: null
# Multi-turn settings
multiturn:
enable: false # Set to true to use multi-turn dataset
Expand Down
2 changes: 2 additions & 0 deletions verl/utils/flops_counter.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def unit_convert(number, level):
flops = 148e12
elif "910B" in device_name:
flops = 354e12
elif "RTX 3070 Ti" in device_name:
flops = 21.75e12
flops_unit = unit_convert(flops, unit)
return flops_unit

Expand Down
Loading