
update quickstart
eric-haibin-lin committed Dec 18, 2024
1 parent b69af03 commit 21e7354
Showing 4 changed files with 40 additions and 25 deletions.
2 changes: 2 additions & 0 deletions docs/examples/config.rst
@@ -1,3 +1,5 @@
.. _config-explain-page:

Config Explanation
===================

2 changes: 2 additions & 0 deletions docs/experiment/ppo.rst
@@ -1,3 +1,5 @@
.. _algo-baseline-page:

Algorithm Baselines
===================

59 changes: 35 additions & 24 deletions docs/start/quickstart.rst
@@ -77,58 +77,69 @@ For more details, please refer to `verl/utils/reward_score/gsm8k.py <https://git

Now let's run PPO training with the dataset and model above[2]_.

Set the ``data.train_files``, ``data.val_files``, ``actor_rollout_ref.model.path`` and ``critic.model.path`` based on your dataset and model names or paths.

.. code:: bash

    PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \
     data.train_files=$HOME/data/gsm8k/train.parquet \
     data.val_files=$HOME/data/gsm8k/test.parquet \
     data.train_batch_size=256 \
     data.val_batch_size=1312 \
     data.max_prompt_length=512 \
     data.max_response_length=256 \
     actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \
     actor_rollout_ref.actor.optim.lr=1e-6 \
     actor_rollout_ref.actor.ppo_mini_batch_size=64 \
     actor_rollout_ref.actor.ppo_micro_batch_size=4 \
     actor_rollout_ref.rollout.log_prob_micro_batch_size=4 \
     actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
     actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
     actor_rollout_ref.ref.log_prob_micro_batch_size=4 \
     actor_rollout_ref.ref.fsdp_config.param_offload=False \
     critic.optim.lr=1e-5 \
     critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \
     critic.ppo_micro_batch_size=4 \
     algorithm.kl_ctrl.kl_coef=0.001 \
     trainer.logger=['console'] \
     +trainer.val_before_train=False \
     trainer.n_gpus_per_node=1 \
     trainer.nnodes=1 \
     trainer.save_freq=10 \
     trainer.test_freq=10 \
     trainer.total_epochs=15 $@ 2>&1 | tee verl_demo.log

You are expected to see the following logs, indicating training in progress:

.. code::

    step:0 - timing/gen:21.470 - timing/ref:4.360 - timing/values:5.800 - critic/kl:0.000 - critic/kl_coeff:0.001 - timing/adv:0.109 - timing/update_critic:15.664 - critic/vf_loss:14.947 - critic/vf_clipfrac:0.000 - critic/vpred_mean:-2.056 - critic/grad_norm:1023.278 - critic/lr(1e-4):0.100 - timing/update_actor:20.314 - actor/entropy_loss:0.433 - actor/pg_loss:-0.005 - actor/pg_clipfrac:0.000 - actor/ppo_kl:0.000 - actor/grad_norm:1.992 - actor/lr(1e-4):0.010 - critic/score/mean:0.004 - critic/score/max:1.000 - critic/score/min:0.000 - critic/rewards/mean:0.004 - critic/rewards/max:1.000 - critic/rewards/min:0.000 - critic/advantages/mean:-0.000 - critic/advantages/max:2.360 - critic/advantages/min:-2.280 - critic/returns/mean:0.003 - critic/returns/max:0.000 - critic/returns/min:0.000 - critic/values/mean:-2.045 - critic/values/max:9.500 - critic/values/min:-14.000 - response_length/mean:239.133 - response_length/max:256.000 - response_length/min:77.000 - prompt_length/mean:104.883 - prompt_length/max:175.000 - prompt_length/min:68.000
    step:1 - timing/gen:23.020 - timing/ref:4.322 - timing/values:5.953 - critic/kl:0.000 - critic/kl_coeff:0.001 - timing/adv:0.118 - timing/update_critic:15.646 - critic/vf_loss:18.472 - critic/vf_clipfrac:0.384 - critic/vpred_mean:1.038 - critic/grad_norm:942.924 - critic/lr(1e-4):0.100 - timing/update_actor:20.526 - actor/entropy_loss:0.440 - actor/pg_loss:0.000 - actor/pg_clipfrac:0.002 - actor/ppo_kl:0.000 - actor/grad_norm:2.060 - actor/lr(1e-4):0.010 - critic/score/mean:0.000 - critic/score/max:0.000 - critic/score/min:0.000 - critic/rewards/mean:0.000 - critic/rewards/max:0.000 - critic/rewards/min:0.000 - critic/advantages/mean:0.000 - critic/advantages/max:2.702 - critic/advantages/min:-2.616 - critic/returns/mean:0.000 - critic/returns/max:0.000 - critic/returns/min:0.000 - critic/values/mean:-2.280 - critic/values/max:11.000 - critic/values/min:-16.000 - response_length/mean:232.242 - response_length/max:256.000 - response_length/min:91.000 - prompt_length/mean:102.398 - prompt_length/max:185.000 - prompt_length/min:70.000

Check out :ref:`algo-baseline-page` for full training and validation logs for reference.

The checkpoint is saved under the following directory by default:

- ``checkpoints/${trainer.project_name}/${trainer.experiment_name}``

To enable ``wandb`` for experiment tracking, set the following configs:

.. code:: bash

    trainer.logger=['console','wandb'] \
    trainer.project_name=$YOUR_PROJECT_NAME \
    trainer.experiment_name=$YOUR_RUN_NAME \

If you encounter out-of-memory issues, adding the following overrides to the training command can help (a short sketch of the micro-batch arithmetic follows the list):

- ``actor_rollout_ref.actor.ppo_micro_batch_size=1``
- ``critic.ppo_micro_batch_size=1``
- ``actor_rollout_ref.actor.fsdp_config.optimizer_offload=False``
- ``critic.model.fsdp_config.optimizer_offload=False``

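Smaller micro batch sizes reduce peak activation memory at the cost of more forward/backward passes per PPO mini batch. The sketch below is illustrative arithmetic only, not code from verl; it assumes each mini batch is simply split into micro batches, which is the usual reading of these two knobs.

.. code:: python

    # Illustrative arithmetic only: relation between micro batch size and accumulation steps.
    ppo_mini_batch_size = 64    # actor_rollout_ref.actor.ppo_mini_batch_size from the command above
    ppo_micro_batch_size = 4    # actor_rollout_ref.actor.ppo_micro_batch_size

    grad_accum_steps = ppo_mini_batch_size // ppo_micro_batch_size
    print(grad_accum_steps)          # 16 forward/backward passes per PPO update
    print(ppo_mini_batch_size // 1)  # dropping the micro batch size to 1 raises this to 64
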
For the full set of configs, please refer to :ref:`config-explain-page` for detailed explanations and performance tuning.


.. [1] The original paper (https://arxiv.org/pdf/2110.14168) mainly focuses on training a verifier (a reward model) to solve math problems via Best-of-N sampling. In this example, we train an RL agent using a rule-based reward model.
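
For intuition, a rule-based reward for GSM8K can be as simple as extracting the final number from the generated response and comparing it with the ground-truth answer. The sketch below is illustrative only and is not verl's implementation; the actual scorer lives in ``verl/utils/reward_score/gsm8k.py``, and the helper name and extraction pattern here are made up for the example.

.. code:: python

    import re

    def rule_based_gsm8k_reward(response: str, ground_truth: str) -> float:
        """Illustrative scorer: 1.0 if the last number in the response matches
        the ground-truth answer, else 0.0. Not verl's exact implementation."""
        numbers = re.findall(r"-?\d+(?:\.\d+)?", response.replace(",", ""))
        if not numbers:
            return 0.0
        return 1.0 if numbers[-1] == ground_truth.strip() else 0.0

    # GSM8K solutions end with '#### <answer>', so '42' is the ground truth here.
    print(rule_based_gsm8k_reward("... so she needs #### 42", "42"))  # 1.0
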
2 changes: 1 addition & 1 deletion verl/trainer/ppo/ray_trainer.py
@@ -417,7 +417,7 @@ def fit(self):

        # perform validation before training
        # currently, we only support validation using the reward_function.
        if self.val_reward_fn is not None and self.config.trainer.get('val_before_train', True):
            val_metrics = self._validate()
            pprint(f'Initial validation metrics: {val_metrics}')
            logger.log(data=val_metrics, step=global_steps)
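
The new guard reads ``trainer.val_before_train`` with a default of ``True``, so configs that never set the key keep the previous behavior (an initial validation pass), while the quickstart command above opts out with ``+trainer.val_before_train=False``. Below is a minimal sketch of the lookup semantics, assuming the OmegaConf ``DictConfig`` that Hydra-style configs produce; it is illustrative, not taken from the repo.

.. code:: python

    from omegaconf import OmegaConf

    # Missing key -> default True, so existing runs still validate before training.
    legacy_trainer_cfg = OmegaConf.create({"nnodes": 1, "n_gpus_per_node": 1})
    assert legacy_trainer_cfg.get("val_before_train", True)

    # Setting the flag (e.g. `+trainer.val_before_train=False` on the command line) skips it.
    quickstart_trainer_cfg = OmegaConf.create({"val_before_train": False})
    assert not quickstart_trainer_cfg.get("val_before_train", True)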
