From 695bdbb0307cca1a2b4a34d2ade0f1d69cadfeeb Mon Sep 17 00:00:00 2001
From: Guangming Sheng
Date: Mon, 27 Jan 2025 21:44:25 +0800
Subject: [PATCH] [misc] fix: gradient accumulation in seq balance and modify default vllm log level (#141)

- The gradient accumulation factor was previously derived from
  micro_batch_size, which is wrong when dynamic_bsz is used (a sketch of the
  corrected scaling is appended after the diff).
- Fix the CI scripts so they no longer mask this issue.
- Change the default of vLLM's `disable_log_stats` to True, i.e. disable the
  stats log by default.
- Check `self.config.actor.ppo_mini_batch_size % self.config.actor.ppo_micro_batch_size_per_gpu == 0`
  after normalization in fsdp_workers instead of in dp_actor and dp_critic.
---
 .github/workflows/dataset.yml                      |  2 ++
 .github/workflows/e2e_digit_completion.yml         |  2 ++
 .github/workflows/e2e_gsm8k.yml                    |  2 ++
 .github/workflows/e2e_lora.yml                     |  2 ++
 .github/workflows/e2e_sft.yml                      |  2 ++
 .github/workflows/model.yml                        |  2 ++
 .github/workflows/ray_test.yml                     |  2 ++
 .../arithmetic_sequence/rl/config/ray_trainer.yaml |  2 ++
 tests/e2e/run_qwen_gsm8k_model_rm_seq_balance.sh   |  4 ----
 tests/e2e/run_ray_trainer.sh                       |  8 ++++----
 verl/trainer/config/ppo_megatron_trainer.yaml      |  2 ++
 verl/trainer/config/ppo_trainer.yaml               |  2 ++
 verl/workers/actor/dp_actor.py                     |  9 ++++++---
 verl/workers/critic/dp_critic.py                   | 11 +++++++----
 verl/workers/fsdp_workers.py                       |  2 ++
 verl/workers/rollout/vllm_rollout/vllm_rollout.py  |  3 ++-
 16 files changed, 41 insertions(+), 16 deletions(-)

diff --git a/.github/workflows/dataset.yml b/.github/workflows/dataset.yml
index 45f26bef..a138337e 100644
--- a/.github/workflows/dataset.yml
+++ b/.github/workflows/dataset.yml
@@ -16,6 +16,8 @@ on:
       - "**/*.py"
       - .github/workflows/dataset.yml
 
+
+
 jobs:
   ray:
     runs-on: [self-hosted, gpu]
diff --git a/.github/workflows/e2e_digit_completion.yml b/.github/workflows/e2e_digit_completion.yml
index bed26203..7b8678e6 100644
--- a/.github/workflows/e2e_digit_completion.yml
+++ b/.github/workflows/e2e_digit_completion.yml
@@ -17,6 +17,8 @@ on:
       - .github/workflows/e2e_digit_completion.yml
       - "tests/e2e/*.sh"
 
+
+
 jobs:
   e2e_digit_completion:
     runs-on: [self-hosted, l20-0]
diff --git a/.github/workflows/e2e_gsm8k.yml b/.github/workflows/e2e_gsm8k.yml
index 3d16d771..1295f689 100644
--- a/.github/workflows/e2e_gsm8k.yml
+++ b/.github/workflows/e2e_gsm8k.yml
@@ -17,6 +17,8 @@ on:
      - .github/workflows/e2e_gsm8k.yml
      - "tests/e2e/*.sh"
 
+
+
 jobs:
   e2e_gsm8k:
     runs-on: [self-hosted, l20-1]
diff --git a/.github/workflows/e2e_lora.yml b/.github/workflows/e2e_lora.yml
index bea3b67b..b2163b5f 100644
--- a/.github/workflows/e2e_lora.yml
+++ b/.github/workflows/e2e_lora.yml
@@ -17,6 +17,8 @@ on:
      - .github/workflows/e2e_lora.yml
      - "tests/e2e/*.sh"
 
+
+
 jobs:
   e2e_lora:
     runs-on: [self-hosted, l20-1]
diff --git a/.github/workflows/e2e_sft.yml b/.github/workflows/e2e_sft.yml
index 4d42430a..4cd6fbe7 100644
--- a/.github/workflows/e2e_sft.yml
+++ b/.github/workflows/e2e_sft.yml
@@ -17,6 +17,8 @@ on:
      - .github/workflows/e2e_sft.yml
      - "tests/e2e/*.sh"
 
+
+
 jobs:
   e2e_sft:
     runs-on: [self-hosted, l20-1]
diff --git a/.github/workflows/model.yml b/.github/workflows/model.yml
index d634c241..6ff7aacb 100644
--- a/.github/workflows/model.yml
+++ b/.github/workflows/model.yml
@@ -16,6 +16,8 @@ on:
      - "**/*.py"
      - .github/workflows/model.yml
 
+
+
 jobs:
   model_rmpad:
     runs-on: [self-hosted, l20-1]
diff --git a/.github/workflows/ray_test.yml b/.github/workflows/ray_test.yml
index 83ec8711..8c63f9d2 100644
--- a/.github/workflows/ray_test.yml
+++ b/.github/workflows/ray_test.yml
@@ -16,6 +16,8 @@ on:
      - "**/*.py"
      - .github/workflows/ray_test.yml
 
+
+
 jobs:
   ray:
     runs-on: [self-hosted, l20-0]
diff --git a/tests/e2e/arithmetic_sequence/rl/config/ray_trainer.yaml b/tests/e2e/arithmetic_sequence/rl/config/ray_trainer.yaml
index d2c5e056..da1294e3 100644
--- a/tests/e2e/arithmetic_sequence/rl/config/ray_trainer.yaml
+++ b/tests/e2e/arithmetic_sequence/rl/config/ray_trainer.yaml
@@ -78,6 +78,8 @@ actor_rollout_ref:
     log_prob_micro_batch_size_per_gpu: null
     log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
     log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
+    disable_log_stats: True
+    enable_chunked_prefill: False # could get higher throughput
     # for hf rollout
     do_sample: True
     # number of responses (i.e. num sample times)
diff --git a/tests/e2e/run_qwen_gsm8k_model_rm_seq_balance.sh b/tests/e2e/run_qwen_gsm8k_model_rm_seq_balance.sh
index 53b3c8cf..c4a686c6 100644
--- a/tests/e2e/run_qwen_gsm8k_model_rm_seq_balance.sh
+++ b/tests/e2e/run_qwen_gsm8k_model_rm_seq_balance.sh
@@ -15,13 +15,11 @@ python3 -m verl.trainer.main_ppo \
     actor_rollout_ref.model.use_remove_padding=True \
     actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \
     actor_rollout_ref.actor.ppo_mini_batch_size=256 \
-    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
     actor_rollout_ref.actor.use_dynamic_bsz=True \
     actor_rollout_ref.actor.ppo_max_token_len_per_gpu=12000 \
     actor_rollout_ref.actor.fsdp_config.param_offload=False \
     actor_rollout_ref.actor.fsdp_config.grad_offload=False \
     actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
-    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \
     actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
     actor_rollout_ref.rollout.name=vllm \
     actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
@@ -33,7 +31,6 @@ python3 -m verl.trainer.main_ppo \
     critic.optim.lr_warmup_steps_ratio=0.05 \
     critic.model.path=Qwen/Qwen2.5-0.5B \
     critic.model.enable_gradient_checkpointing=False \
-    critic.ppo_micro_batch_size_per_gpu=4 \
     critic.use_dynamic_bsz=True \
     critic.ppo_max_token_len_per_gpu=98304 \
     critic.model.fsdp_config.param_offload=False \
@@ -43,7 +40,6 @@ python3 -m verl.trainer.main_ppo \
     reward_model.model.path=Qwen/Qwen2.5-0.5B\
     reward_model.model.use_remove_padding=True \
     reward_model.model.fsdp_config.param_offload=True \
-    reward_model.micro_batch_size_per_gpu=16 \
     reward_model.use_dynamic_bsz=True \
     reward_model.forward_max_token_len_per_gpu=98304 \
     algorithm.kl_ctrl.kl_coef=0.001 \
diff --git a/tests/e2e/run_ray_trainer.sh b/tests/e2e/run_ray_trainer.sh
index f06597cc..30457e64 100644
--- a/tests/e2e/run_ray_trainer.sh
+++ b/tests/e2e/run_ray_trainer.sh
@@ -11,10 +11,10 @@ python3 tests/e2e/arithmetic_sequence/rl/main_trainer.py \
     data.train_files=tests/e2e/arithmetic_sequence/data/train.parquet \
     data.val_files=tests/e2e/arithmetic_sequence/data/test.parquet \
     actor_rollout_ref.model.path=tests/e2e/arithmetic_sequence/model \
-    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
-    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \
-    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \
-    critic.ppo_micro_batch_size_per_gpu=1 \
+    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=200 \
+    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=200 \
+    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=200 \
+    critic.ppo_micro_batch_size_per_gpu=200 \
     critic.model.path=tests/e2e/arithmetic_sequence/model | tee $OUTPUT_FILE;
 
 python3 tests/e2e/check_results.py --output_file=$OUTPUT_FILE
diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml
index 368d6512..fce8a89a 100644
--- a/verl/trainer/config/ppo_megatron_trainer.yaml
+++ b/verl/trainer/config/ppo_megatron_trainer.yaml
@@ -71,6 +71,8 @@ actor_rollout_ref:
     max_num_seqs: 1024
     log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
     log_prob_micro_batch_size_per_gpu: null
+    disable_log_stats: True
+    enable_chunked_prefill: False # could get higher throughput
     # for hf rollout
     do_sample: True
     layer_name_map:
diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml
index 6dd688a0..04561cbe 100644
--- a/verl/trainer/config/ppo_trainer.yaml
+++ b/verl/trainer/config/ppo_trainer.yaml
@@ -80,6 +80,8 @@ actor_rollout_ref:
     log_prob_micro_batch_size_per_gpu: null
     log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
     log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
+    disable_log_stats: True
+    enable_chunked_prefill: False # could get higher throughput
     # for hf rollout
     do_sample: True
     # number of responses (i.e. num sample times)
diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py
index 3dde8dd3..cedfd53c 100644
--- a/verl/workers/actor/dp_actor.py
+++ b/verl/workers/actor/dp_actor.py
@@ -204,8 +204,6 @@ def update_policy(self, data: DataProto):
         # make sure we are in training mode
         self.actor_module.train()
 
-        assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0
-        self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
         temperature = data.meta_info['temperature']  # temperature must be in the data.meta_info to avoid slient error
 
         select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids', 'old_log_probs', 'advantages']
@@ -225,6 +223,7 @@ def update_policy(self, data: DataProto):
                     max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
                     micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)
                 else:
+                    self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
                     # split batch into micro_batches
                     micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)
 
@@ -268,7 +267,11 @@ def update_policy(self, data: DataProto):
                         metrics['actor/kl_loss'] = kl_loss.detach().item()
                         metrics['actor/kl_coef'] = self.config.kl_loss_coef
 
-                    loss = policy_loss / self.gradient_accumulation
+                    if self.config.use_dynamic_bsz:
+                        # relative to the dynamic bsz
+                        loss = policy_loss * (len(data) / self.config.ppo_mini_batch_size)
+                    else:
+                        loss = policy_loss / self.gradient_accumulation
                     loss.backward()
 
                     data = {
diff --git a/verl/workers/critic/dp_critic.py b/verl/workers/critic/dp_critic.py
index bdad6ecf..f2eb44c2 100644
--- a/verl/workers/critic/dp_critic.py
+++ b/verl/workers/critic/dp_critic.py
@@ -45,9 +45,6 @@ def __init__(self, config, critic_module: nn.Module, critic_optimizer: optim.Opt
         self.use_remove_padding = self.config.model.get('use_remove_padding', False)
         print(f'Critic use_remove_padding={self.use_remove_padding}')
 
-        assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0
-        self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
-
         self.ulysses_sequence_parallel_size = self.config.get('ulysses_sequence_parallel_size', 1)
 
     def _forward_micro_batch(self, micro_batch):
@@ -162,6 +159,7 @@ def update_critic(self, data: DataProto):
                     micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)
                 else:
                     micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)
+                    self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
 
                 self.critic_optimizer.zero_grad()
 
@@ -186,7 +184,12 @@ def update_critic(self, data: DataProto):
                                                              returns=returns,
                                                              eos_mask=eos_mask,
                                                              cliprange_value=self.config.cliprange_value)
-                    loss = vf_loss / self.gradient_accumulation
+                    if self.config.use_dynamic_bsz:
+                        # relative to the dynamic bsz
+                        loss = vf_loss * (len(data) / self.config.ppo_mini_batch_size)
+                    else:
+                        loss = vf_loss / self.gradient_accumulation
+
                     loss.backward()
 
                     data = {
diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py
index fb3a406c..50fc04e9 100644
--- a/verl/workers/fsdp_workers.py
+++ b/verl/workers/fsdp_workers.py
@@ -125,6 +125,7 @@ def __init__(self, config: DictConfig, role: str):
                 self.config.actor.ppo_micro_batch_size //= (self.device_mesh.shape[0] //
                                                             self.ulysses_sequence_parallel_size)
                 self.config.actor.ppo_micro_batch_size_per_gpu = self.config.actor.ppo_micro_batch_size
+                assert self.config.actor.ppo_mini_batch_size % self.config.actor.ppo_micro_batch_size_per_gpu == 0
         # normalize rollout config
         if self._is_rollout and self.config.rollout.log_prob_micro_batch_size is not None:
             self.config.rollout.log_prob_micro_batch_size //= (self.device_mesh.shape[0] //
@@ -582,6 +583,7 @@ def __init__(self, config):
                                                   self.ulysses_sequence_parallel_size)
             self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size
             self.config.forward_micro_batch_size_per_gpu = self.config.forward_micro_batch_size
+            assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0
 
     def _build_critic_model_optimizer(self, config):
         # the following line is necessary
diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout.py b/verl/workers/rollout/vllm_rollout/vllm_rollout.py
index 8bd430f3..7014ff1e 100644
--- a/verl/workers/rollout/vllm_rollout/vllm_rollout.py
+++ b/verl/workers/rollout/vllm_rollout/vllm_rollout.py
@@ -100,8 +100,9 @@ def __init__(self, actor_module: nn.Module, config: DictConfig, tokenizer, model
             skip_tokenizer_init=False,
             max_model_len=config.prompt_length + config.response_length,
             load_format=config.load_format,
-            disable_log_stats=False,
+            disable_log_stats=config.disable_log_stats,
             max_num_batched_tokens=max_num_batched_tokens,
+            enable_chunked_prefill=config.enable_chunked_prefill,
         )
 
         # Offload vllm model to reduce peak memory usage
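
Note on the loss scaling above (illustrative, not part of the patch): with dynamic_bsz the micro-batches inside a mini-batch no longer have equal size, so dividing each micro-batch loss by a fixed gradient_accumulation count over-weights the smaller micro-batches. Weighting each micro-batch loss by len(data) / ppo_mini_batch_size instead recovers the gradient of the mean loss over the whole mini-batch. The sketch below checks this with a stand-in quadratic objective; the objective, the 5/3 split, and all variable names are assumptions made only for this example and do not appear in verl.

    # Sketch only: a toy objective stands in for verl's PPO loss.
    import torch

    torch.manual_seed(0)
    mini_batch, target = torch.randn(8, 4), torch.randn(8)   # mini-batch of 8 samples

    def mean_loss(w, x, y):
        # per-sample squared error, averaged over the (micro-)batch
        return ((x @ w - y) ** 2).mean()

    # Reference gradient: mean loss over the full mini-batch.
    w_ref = torch.zeros(4, requires_grad=True)
    mean_loss(w_ref, mini_batch, target).backward()

    # Accumulation over unequal micro-batches (sizes 5 and 3, as dynamic_bsz may
    # produce), scaling each micro-batch loss by len(data) / ppo_mini_batch_size.
    w_acc = torch.zeros(4, requires_grad=True)
    for lo, hi in [(0, 5), (5, 8)]:
        loss = mean_loss(w_acc, mini_batch[lo:hi], target[lo:hi])
        (loss * ((hi - lo) / len(mini_batch))).backward()

    print(torch.allclose(w_ref.grad, w_acc.grad))  # True
    # Dividing by a fixed count of 2 instead gives 0.5*mean_5 + 0.5*mean_3 != mean_8.

The point of the weighting is that each micro-batch contributes in proportion to its size, so the accumulated gradient matches the full mini-batch mean no matter how dynamic_bsz happens to partition it.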