[misc] fix reward model issue with TokenClassification model and support running particular steps instead of epochs (#99)

* support user-specified training steps

* fix typo

* update ci

* add ci

* fix reward model and write more CI scripts

* update ci

* lint

* align

* delete post-training val

* fix script
PeterSH6 authored Jan 13, 2025
1 parent 53c3ff4 commit a0e8ed2
Showing 12 changed files with 267 additions and 26 deletions.
.github/workflows/e2e_digit_completion.yml
@@ -1,4 +1,4 @@
-name: e2e_gpu
+name: e2e_digit_completion

on:
# Trigger the workflow on push or pull request,
@@ -8,16 +8,16 @@ on:
      - main
    paths:
      - "**/*.py"
-      - .github/workflows/e2e_gpu.yml
+      - .github/workflows/e2e_digit_completion.yml
  pull_request:
    branches:
      - main
    paths:
      - "**/*.py"
-      - .github/workflows/e2e_gpu.yml
+      - .github/workflows/e2e_digit_completion.yml

jobs:
-  e2e_gpu:
+  e2e_digit_completion:
    runs-on: [self-hosted, l20-1]
    env:
      HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
@@ -38,7 +38,3 @@ jobs:
      - name: Running digit completion e2e training tests on 8 L20 GPUs
        run: |
          bash tests/e2e/run_ray_trainer.sh
-      - name: Running digit completion e2e training tests on 8 L20 GPUs (with rmpad)
-        run: |
-          pip3 install --upgrade transformers
-          bash tests/e2e/run_ray_trainer_rmpad.sh
52 changes: 52 additions & 0 deletions .github/workflows/e2e_gsm8k.yml
@@ -0,0 +1,52 @@
name: e2e_gsm8k

on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch
  push:
    branches:
      - main
    paths:
      - "**/*.py"
      - .github/workflows/e2e_gsm8k.yml
  pull_request:
    branches:
      - main
    paths:
      - "**/*.py"
      - .github/workflows/e2e_gsm8k.yml

jobs:
  e2e_gsm8k:
    runs-on: [self-hosted, l20-1]
    env:
      HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
      HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
      NO_PROXY: "localhost,127.0.0.1"
      HF_HUB_ENABLE_HF_TRANSFER: 1
    container:
      image: verlai/verl:vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te1.7-v0.0.3
      options: --gpus all --shm-size=10g
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
        with:
          fetch-depth: 0
      - name: Install the current repository
        run: |
          pip3 install hf_transfer
          pip3 install -e .[test]
      - name: Prepare gsm8k dataset
        run: |
          python3 examples/data_preprocess/gsm8k.py
      - name: Running gsm8k e2e training tests on 8 L20 GPUs with rmpad using function rm
        run: |
          bash tests/e2e/run_qwen_gsm8k_function_rm.sh
      - name: Running gsm8k e2e without rmpad using function rm
        run: |
          bash tests/e2e/run_qwen_gsm8k_function_rm_no_rmpad.sh
      - name: Running gsm8k e2e with rmpad using model rm
        run: |
          bash tests/e2e/run_qwen_gsm8k_model_rm.sh
      - name: Running gsm8k e2e without rmpad using model rm
        run: |
          bash tests/e2e/run_qwen_gsm8k_model_rm_no_rmpad.sh
4 changes: 3 additions & 1 deletion examples/ppo_trainer/run_qwen2-7b_rm.sh
@@ -8,6 +8,8 @@ math_test_path=$HOME/data/math/test.parquet
train_files="['$gsm8k_train_path', '$math_train_path']"
test_files="['$gsm8k_test_path', '$math_test_path']"

+export VLLM_ATTENTION_BACKEND=XFORMERS # vllm + qwen2-7b with flash_attn has some issues

python3 -m verl.trainer.main_ppo \
data.train_files="$train_files" \
data.val_files="$test_files" \
@@ -41,7 +43,7 @@ python3 -m verl.trainer.main_ppo \
critic.model.fsdp_config.grad_offload=False \
critic.model.fsdp_config.optimizer_offload=False \
reward_model.enable=True \
-reward_model.model.path=sfairXC/FsfairX-Gemma2-RM-v0.1 \
+reward_model.model.path=sfairXC/FsfairX-LLaMA3-RM-v0.1 \
reward_model.model.use_remove_padding=True \
reward_model.model.fsdp_config.param_offload=True \
reward_model.micro_batch_size=16 \
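The reward-model fix named in the commit title concerns checkpoints served through a token-classification head. As a rough illustration of the scoring pattern involved (a minimal standalone sketch, not verl's actual RewardModelWorker; the model name and gathering logic here are assumptions), per-token logits from a 1-label head are reduced to one scalar per sequence by indexing the last non-padding token:

import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer

# Assumed setup: a base checkpoint with a randomly initialized 1-label
# token-classification head; a trained RM would ship fine-tuned head weights.
model_path = "Qwen/Qwen2.5-0.5B"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForTokenClassification.from_pretrained(model_path, num_labels=1)

batch = tokenizer(["1 + 1 = 2", "1 + 1 = 3"], return_tensors="pt", padding=True)
with torch.no_grad():
    logits = model(**batch).logits.squeeze(-1)  # (batch, seq_len): one logit per token

# The reward is the logit at each sequence's last non-padding token.
last_token_idx = batch["attention_mask"].sum(dim=-1) - 1
scores = logits[torch.arange(logits.size(0)), last_token_idx]
print(scores)  # one scalar reward per sequence

The CI scripts added below exercise this model-RM path, with and without padding removal, against a small Qwen checkpoint.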
1 change: 1 addition & 0 deletions tests/e2e/arithmetic_sequence/rl/config/ray_trainer.yaml
@@ -124,6 +124,7 @@ algorithm:

trainer:
total_epochs: 200
+total_training_steps: null
project_name: verl_examples
experiment_name: arithmetic_sequences
logger: ['console']
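A `null` value in these YAML configs surfaces as Python `None` under OmegaConf, which is the sentinel the trainer checks before falling back to the epoch-derived step count. A minimal sketch of that precedence rule (hypothetical helper, not verl's exact code):

from omegaconf import OmegaConf

cfg = OmegaConf.create("""
trainer:
  total_epochs: 200
  total_training_steps: null
""")

def resolve_total_steps(cfg, steps_per_epoch: int) -> int:
    # Default: derive the budget from epochs and dataloader length.
    total = steps_per_epoch * cfg.trainer.total_epochs
    # An explicit step budget, when set, takes precedence.
    if cfg.trainer.total_training_steps is not None:
        total = cfg.trainer.total_training_steps
    return total

print(resolve_total_steps(cfg, steps_per_epoch=10))  # 2000: no budget set
cfg.trainer.total_training_steps = 1
print(resolve_total_steps(cfg, steps_per_epoch=10))  # 1: budget overrides epochs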
40 changes: 40 additions & 0 deletions tests/e2e/run_qwen_gsm8k_function_rm.sh
@@ -0,0 +1,40 @@
set -x

python3 -m verl.trainer.main_ppo \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=1024 \
data.val_batch_size=1312 \
data.max_prompt_length=512 \
data.max_response_length=512 \
actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.ppo_micro_batch_size=32 \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.grad_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size=128 \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
actor_rollout_ref.ref.log_prob_micro_batch_size=128 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
critic.optim.lr=1e-5 \
critic.model.use_remove_padding=True \
critic.model.path=Qwen/Qwen2.5-0.5B \
critic.model.enable_gradient_checkpointing=False \
critic.ppo_micro_batch_size=32 \
critic.model.fsdp_config.param_offload=False \
critic.model.fsdp_config.grad_offload=False \
critic.model.fsdp_config.optimizer_offload=False \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.critic_warmup=0 \
trainer.logger=['console'] \
trainer.project_name='verl_example_gsm8k' \
trainer.experiment_name='qwen_e2e_ci_function_rm' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.total_training_steps=1 $@
40 changes: 40 additions & 0 deletions tests/e2e/run_qwen_gsm8k_function_rm_no_rmpad.sh
@@ -0,0 +1,40 @@
set -x

python3 -m verl.trainer.main_ppo \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=1024 \
data.val_batch_size=1312 \
data.max_prompt_length=512 \
data.max_response_length=512 \
actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=False \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.ppo_micro_batch_size=32 \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.grad_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size=128 \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
actor_rollout_ref.ref.log_prob_micro_batch_size=128 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
critic.optim.lr=1e-5 \
critic.model.use_remove_padding=False \
critic.model.path=Qwen/Qwen2.5-0.5B \
critic.model.enable_gradient_checkpointing=False \
critic.ppo_micro_batch_size=32 \
critic.model.fsdp_config.param_offload=False \
critic.model.fsdp_config.grad_offload=False \
critic.model.fsdp_config.optimizer_offload=False \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.critic_warmup=0 \
trainer.logger=['console'] \
trainer.project_name='verl_example_gsm8k' \
trainer.experiment_name='qwen_e2e_ci_function_rm' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.total_training_steps=1 $@
48 changes: 48 additions & 0 deletions tests/e2e/run_qwen_gsm8k_model_rm.sh
@@ -0,0 +1,48 @@
set -x

python3 -m verl.trainer.main_ppo \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=1024 \
data.val_batch_size=1312 \
data.max_prompt_length=512 \
data.max_response_length=512 \
data.return_raw_chat=True \
actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.ppo_micro_batch_size=32 \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.grad_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size=128 \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
actor_rollout_ref.ref.log_prob_micro_batch_size=128 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
critic.optim.lr=1e-5 \
critic.model.use_remove_padding=True \
critic.optim.lr_warmup_steps_ratio=0.05 \
critic.model.path=Qwen/Qwen2.5-0.5B \
critic.model.enable_gradient_checkpointing=False \
critic.ppo_micro_batch_size=32 \
critic.model.fsdp_config.param_offload=False \
critic.model.fsdp_config.grad_offload=False \
critic.model.fsdp_config.optimizer_offload=False \
reward_model.enable=True \
reward_model.model.path=Qwen/Qwen2.5-0.5B \
reward_model.model.use_remove_padding=True \
reward_model.model.fsdp_config.param_offload=True \
reward_model.micro_batch_size=16 \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.critic_warmup=0 \
trainer.logger=['console'] \
trainer.project_name='verl_example' \
trainer.experiment_name='Qwen2.5-0.5B-ci_hybrid_rm' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.total_training_steps=1 $@
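Unlike the function-RM scripts, the model-RM runs set `data.return_raw_chat=True`. A plausible reason (sketched below with assumed message content; this is not verl's exact data path) is that a separate reward model needs the raw chat messages so it can re-render them with its own chat template before tokenizing, rather than reusing the actor's token ids:

from transformers import AutoTokenizer

# Assumed checkpoint with a chat template; the CI above uses a base model.
rm_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

chat = [
    {"role": "user", "content": "Natalia sold clips to 48 of her friends..."},
    {"role": "assistant", "content": "She sold 72 clips in total. #### 72"},
]

# Re-render from raw messages, since the RM's template and vocabulary
# may differ from the actor's.
rm_text = rm_tokenizer.apply_chat_template(chat, tokenize=False)
rm_inputs = rm_tokenizer(rm_text, return_tensors="pt")
print(rm_inputs["input_ids"].shape)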
48 changes: 48 additions & 0 deletions tests/e2e/run_qwen_gsm8k_model_rm_no_rmpad.sh
@@ -0,0 +1,48 @@
set -x

python3 -m verl.trainer.main_ppo \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=1024 \
data.val_batch_size=1312 \
data.max_prompt_length=512 \
data.max_response_length=512 \
data.return_raw_chat=True \
actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=False \
actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.ppo_micro_batch_size=32 \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.grad_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size=128 \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
actor_rollout_ref.ref.log_prob_micro_batch_size=128 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
critic.optim.lr=1e-5 \
critic.model.use_remove_padding=False \
critic.optim.lr_warmup_steps_ratio=0.05 \
critic.model.path=Qwen/Qwen2.5-0.5B \
critic.model.enable_gradient_checkpointing=False \
critic.ppo_micro_batch_size=32 \
critic.model.fsdp_config.param_offload=False \
critic.model.fsdp_config.grad_offload=False \
critic.model.fsdp_config.optimizer_offload=False \
reward_model.enable=True \
reward_model.model.path=Qwen/Qwen2.5-0.5B \
reward_model.model.use_remove_padding=False \
reward_model.model.fsdp_config.param_offload=True \
reward_model.micro_batch_size=16 \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.critic_warmup=0 \
trainer.logger=['console'] \
trainer.project_name='verl_example' \
trainer.experiment_name='Qwen2.5-0.5B-ci_hybrid_rm' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.total_training_steps=1 $@
1 change: 1 addition & 0 deletions verl/trainer/config/ppo_megatron_trainer.yaml
@@ -135,6 +135,7 @@ algorithm:

trainer:
total_epochs: 30
+total_training_steps: null
project_name: verl_examples
experiment_name: gsm8k
logger: ['console', 'wandb']
1 change: 1 addition & 0 deletions verl/trainer/config/ppo_trainer.yaml
@@ -125,6 +125,7 @@ algorithm:

trainer:
total_epochs: 30
+total_training_steps: null
project_name: verl_examples
experiment_name: gsm8k
logger: [ 'console', 'wandb' ]
37 changes: 23 additions & 14 deletions verl/trainer/ppo/ray_trainer.py
@@ -323,6 +323,12 @@ def _create_dataloader(self):
# inject total_training_steps to actor/critic optim_config. This is hacky.
total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs

+if self.config.trainer.total_training_steps is not None:
+total_training_steps = self.config.trainer.total_training_steps
+
+self.total_training_steps = total_training_steps
+print(f'Total training steps: {self.total_training_steps}')

OmegaConf.set_struct(self.config, True)
with open_dict(self.config):
self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps
@@ -465,14 +471,14 @@ def fit(self):
default_backend=self.config.trainer.logger,
config=OmegaConf.to_container(self.config, resolve=True))

-global_steps = 0
+self.global_steps = 0

# perform validation before training
# currently, we only support validation using the reward_function.
if self.val_reward_fn is not None and self.config.trainer.get('val_before_train', True):
val_metrics = self._validate()
pprint(f'Initial validation metrics: {val_metrics}')
-logger.log(data=val_metrics, step=global_steps)
+logger.log(data=val_metrics, step=self.global_steps)
if self.config.trainer.get('val_only', False):
return

@@ -539,15 +545,15 @@ def fit(self):
metrics.update(critic_output_metrics)

# implement critic warmup
-if self.config.trainer.critic_warmup <= global_steps:
+if self.config.trainer.critic_warmup <= self.global_steps:
# update actor
with _timer('update_actor', timing_raw):
actor_output = self.actor_rollout_wg.update_actor(batch)
actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics'])
metrics.update(actor_output_metrics)

# validate
-if self.val_reward_fn is not None and (global_steps + 1) % self.config.trainer.test_freq == 0:
+if self.val_reward_fn is not None and (self.global_steps + 1) % self.config.trainer.test_freq == 0:
with _timer('testing', timing_raw):
val_metrics: dict = self._validate()
val_metrics = {f'val/{key}': val for key, val in val_metrics.items()}
@@ -558,26 +564,29 @@ def fit(self):
metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))

# TODO: make a canonical logger that supports various backend
-logger.log(data=metrics, step=global_steps)
+logger.log(data=metrics, step=self.global_steps)

-if self.config.trainer.save_freq > 0 and (global_steps + 1) % self.config.trainer.save_freq == 0:
+if self.config.trainer.save_freq > 0 and (self.global_steps + 1) % self.config.trainer.save_freq == 0:
actor_local_path = os.path.join(self.config.trainer.default_local_dir, 'actor',
-f'global_step_{global_steps}')
+f'global_step_{self.global_steps}')
actor_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(
self.config.trainer.default_hdfs_dir, 'actor')
self.actor_rollout_wg.save_checkpoint(actor_local_path, actor_remote_path)

if self.use_critic:
critic_local_path = os.path.join(self.config.trainer.default_local_dir, 'critic',
-f'global_step_{global_steps}')
+f'global_step_{self.global_steps}')
critic_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(
self.config.trainer.default_hdfs_dir, 'critic')
self.critic_wg.save_checkpoint(critic_local_path, critic_remote_path)

-global_steps += 1
+self.global_steps += 1

-# perform validation after training
-if self.val_reward_fn is not None:
-val_metrics = self._validate()
-pprint(f'Final validation metrics: {val_metrics}')
-logger.log(data=val_metrics, step=global_steps)
+if self.global_steps >= self.total_training_steps:
+
+# perform validation after training
+if self.val_reward_fn is not None:
+val_metrics = self._validate()
+pprint(f'Final validation metrics: {val_metrics}')
+logger.log(data=val_metrics, step=self.global_steps)
+return
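Taken together, the control flow of `fit()` after this change looks roughly like the following (a simplified sketch: `train_one_step` is a hypothetical stand-in for the rollout/critic/actor updates, checkpointing is omitted, and `logger` and `_validate` come from the surrounding class as in the diff above):

def fit(self):
    self.global_steps = 0
    for epoch in range(self.config.trainer.total_epochs):
        for batch_dict in self.train_dataloader:
            metrics = self.train_one_step(batch_dict)  # hypothetical helper
            logger.log(data=metrics, step=self.global_steps)
            self.global_steps += 1

            # With a step budget, training can now stop mid-epoch instead of
            # always running total_epochs full passes over the data.
            if self.global_steps >= self.total_training_steps:
                if self.val_reward_fn is not None:
                    val_metrics = self._validate()
                    logger.log(data=val_metrics, step=self.global_steps)
                return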
