From c9b5086de04b87efd52e581f849b7564f169d8e2 Mon Sep 17 00:00:00 2001 From: shengguangming Date: Wed, 11 Dec 2024 18:17:00 +0800 Subject: [PATCH 1/2] [example] add a split placement tutorial --- examples/split_placement/README.md | 61 ++++++ .../config/ppo_trainer_split.yaml | 131 ++++++++++++ examples/split_placement/main_ppo_split.py | 199 ++++++++++++++++++ .../split_placement/run_deepseek7b_llm.sh | 38 ++++ .../split_placement/split_monkey_patch.py | 161 ++++++++++++++ verl/trainer/ppo/ray_trainer.py | 2 +- 6 files changed, 591 insertions(+), 1 deletion(-) create mode 100644 examples/split_placement/README.md create mode 100644 examples/split_placement/config/ppo_trainer_split.yaml create mode 100644 examples/split_placement/main_ppo_split.py create mode 100644 examples/split_placement/run_deepseek7b_llm.sh create mode 100644 examples/split_placement/split_monkey_patch.py diff --git a/examples/split_placement/README.md b/examples/split_placement/README.md new file mode 100644 index 00000000..a5e4ffde --- /dev/null +++ b/examples/split_placement/README.md @@ -0,0 +1,61 @@ +# Split Placement Example +Here we introduce how to run the naive implementation of the split placement of PPO algorithm. +We will release the complete version of flexible placement in the near future. + + For quickstart, you can only follow Step 2 to modify the code and then follow Step 4 to execute the split placement example. + +### Step 1: Placing the models to different GPUs +Specify the placement and resource allocation. In the example, we place the actor and reference in the first half of the GPUs while map the critic and reward model (if any) to the second half of the GPUs. +```python +actor_rollout_ref_pool_id = 'actor_rollout_ref_pool' +critic_pool_id = 'critic_pool' +if config.trainer.nnodes // 2 == 0 and config.trainer.n_gpus_per_node // 2 > 0: + resource_pool_spec = { + actor_rollout_ref_pool_id: [config.trainer.n_gpus_per_node // 2] * config.trainer.nnodes, + critic_pool_id: [config.trainer.n_gpus_per_node // 2] * config.trainer.nnodes, + } +else: + resource_pool_spec = { + actor_rollout_ref_pool_id: [config.trainer.n_gpus_per_node] * (config.trainer.nnodes // 2), + critic_pool_id: [config.trainer.n_gpus_per_node] * (config.trainer.nnodes // 2), + } +print(f'resource_pool_spec: {resource_pool_spec}') +mapping = { + Role.ActorRollout: actor_rollout_ref_pool_id, + Role.Critic: critic_pool_id, + Role.RefPolicy: actor_rollout_ref_pool_id, +} +mapping[Role.RewardModel] = critic_pool_id +``` + +### Step 2: Make the models executed asynchronously +Based on the model placement, we need to make the models executed asynchronously. + +To do so, you need to turn off the `blocking` flag (i.e., `blocking=False`) in our decorator of some model operations. +For example, we hope the actor update and critic update can be executed in parallel, then we need to make the following modification in `fsdp_workers.py` + +``` +@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO, blocking=False) +def update_actor(self, data: DataProto): + ... + +@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO, blocking=False) +def update_critic(self, data: DataProto): + ... +``` + +We can also parallelize the computation of `ref_log_prob` and `values` and `rewards` in the split placement. For simplicity of the tutorial, we + +### Step 3: Execute these operation in parallel in the single controller process +To implement the parallel execution of the actor and critic update, the only thing we need to modify in the `ray_trainer.py` is to `get` the concurrent `futures` on the single controller process. + +```python +critic_output = critic_output.get() +actor_output = actor_output.get() +``` + +### Step 4: Run the split placement example + +``` +bash run_deepseek7b_llm.sh +``` \ No newline at end of file diff --git a/examples/split_placement/config/ppo_trainer_split.yaml b/examples/split_placement/config/ppo_trainer_split.yaml new file mode 100644 index 00000000..bd6bcf22 --- /dev/null +++ b/examples/split_placement/config/ppo_trainer_split.yaml @@ -0,0 +1,131 @@ +data: + tokenizer: null + train_files: ~/data/rlhf/gsm8k/train.parquet + val_files: ~/data/rlhf/gsm8k/test.parquet + prompt_key: prompt + max_prompt_length: 512 + max_response_length: 512 + train_batch_size: 1024 + val_batch_size: 1312 + return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs + return_raw_chat: False + +actor_rollout_ref: + hybrid_engine: True + model: + path: ~/models/deepseek-llm-7b-chat + external_lib: null + override_config: {} + enable_gradient_checkpointing: False + actor: + strategy: fsdp # This is for backward-compatibility + ppo_mini_batch_size: 256 + ppo_micro_batch_size: 64 + grad_clip: 1.0 + clip_ratio: 0.2 + entropy_coeff: 0.001 + ppo_epochs: 1 + shuffle: True + optim: + lr: 1e-6 + lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime + min_lr_ratio: null # only useful for warmup with cosine + warmup_style: constant # select from constant/cosine + total_training_steps: -1 # must be override by program + fsdp_config: + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + param_offload: False + grad_offload: False + optimizer_offload: False + ref: + fsdp_config: + param_offload: False + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + log_prob_micro_batch_size: 128 + rollout: + name: vllm + temperature: 1.0 + top_k: -1 # 0 for hf rollout, -1 for vllm rollout + top_p: 1 + prompt_length: ${data.max_prompt_length} # not use for opensource + response_length: ${data.max_response_length} + # for vllm rollout + dtype: bfloat16 # should align with FSDP + gpu_memory_utilization: 0.5 + ignore_eos: False + enforce_eager: True + free_cache_engine: True + load_format: dummy_dtensor + tensor_model_parallel_size: 2 + max_num_batched_tokens: 8192 + max_num_seqs: 1024 + log_prob_micro_batch_size: 128 + # for hf rollout + do_sample: True + +critic: + strategy: fsdp + optim: + lr: 1e-5 + lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime + min_lr_ratio: null # only useful for warmup with cosine + warmup_style: constant # select from constant/cosine + total_training_steps: -1 # must be override by program + model: + path: ~/models/deepseek-llm-7b-chat + tokenizer_path: ${actor_rollout_ref.model.path} + override_config: {} + external_lib: ${actor_rollout_ref.model.external_lib} + enable_gradient_checkpointing: False + fsdp_config: + param_offload: False + grad_offload: False + optimizer_offload: False + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} + ppo_micro_batch_size: 64 + ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs} + shuffle: ${actor_rollout_ref.actor.shuffle} + grad_clip: 1.0 + cliprange_value: 0.5 + +reward_model: + enable: False + strategy: fsdp + model: + input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical + path: ~/models/FsfairX-LLaMA3-RM-v0.1 + external_lib: ${actor_rollout_ref.model.external_lib} + fsdp_config: + min_num_params: 0 + param_offload: False + micro_batch_size: 64 + max_length: null + +algorithm: + gamma: 1.0 + lam: 1.0 + adv_estimator: gae + kl_penalty: kl # how to estimate kl divergence + kl_ctrl: + type: fixed + kl_coef: 0.001 + +trainer: + total_epochs: 30 + project_name: verl_examples + experiment_name: gsm8k + logger: ['console', 'tracking'] + nnodes: 1 + n_gpus_per_node: 8 + save_freq: -1 + test_freq: 2 + critic_warmup: 0 + default_hdfs_dir: ~/experiments/gsm8k/ppo/${trainer.experiment_name} + default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} diff --git a/examples/split_placement/main_ppo_split.py b/examples/split_placement/main_ppo_split.py new file mode 100644 index 00000000..fb607f9b --- /dev/null +++ b/examples/split_placement/main_ppo_split.py @@ -0,0 +1,199 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. +""" + +from verl import DataProto +import torch +from verl.utils.reward_score import gsm8k, math +from verl.trainer.ppo.ray_trainer import RayPPOTrainer + + +def _select_rm_score_fn(data_source): + if data_source == 'openai/gsm8k': + return gsm8k.compute_score + elif data_source == 'lighteval/MATH': + return math.compute_score + else: + raise NotImplementedError + + +class RewardManager(): + + def __init__(self, tokenizer, num_examine) -> None: + self.tokenizer = tokenizer + self.num_examine = num_examine # the number of batches of decoded responses to print to the console + + def __call__(self, data: DataProto): + """We will expand this function gradually based on the available datasets""" + + # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn + if 'rm_scores' in data.batch.keys(): + return data.batch['rm_scores'] + + reward_tensor = torch.zeros_like(data.batch['responses'], dtype=torch.float32) + + already_print_data_sources = {} + + for i in range(len(data)): + data_item = data[i] # DataProtoItem + + prompt_ids = data_item.batch['prompts'] + + prompt_length = prompt_ids.shape[-1] + + valid_prompt_length = data_item.batch['attention_mask'][:prompt_length].sum() + valid_prompt_ids = prompt_ids[-valid_prompt_length:] + + response_ids = data_item.batch['responses'] + valid_response_length = data_item.batch['attention_mask'][prompt_length:].sum() + valid_response_ids = response_ids[:valid_response_length] + + # decode + sequences = torch.cat((valid_prompt_ids, valid_response_ids)) + sequences_str = self.tokenizer.decode(sequences) + + ground_truth = data_item.non_tensor_batch['reward_model']['ground_truth'] + + # select rm_score + data_source = data_item.non_tensor_batch['data_source'] + compute_score_fn = _select_rm_score_fn(data_source) + + score = compute_score_fn(solution_str=sequences_str, ground_truth=ground_truth) + reward_tensor[i, valid_response_length - 1] = score + + if data_source not in already_print_data_sources: + already_print_data_sources[data_source] = 0 + + if already_print_data_sources[data_source] < self.num_examine: + already_print_data_sources[data_source] += 1 + print(sequences_str) + + return reward_tensor + + +import ray +import hydra +from split_monkey_patch import fit + +@hydra.main(config_path='config', config_name='ppo_trainer_split', version_base=None) +def main(config): + if not ray.is_initialized(): + # this is for local ray cluster + ray.init(runtime_env={'env_vars': {'TOKENIZERS_PARALLELISM': 'true', 'NCCL_DEBUG': 'WARN'}}) + + ray.get(main_task.remote(config)) + + +@ray.remote +def main_task(config): + from verl.utils.fs import copy_local_path_from_hdfs + from transformers import AutoTokenizer + + # print initial config + from pprint import pprint + from omegaconf import OmegaConf + pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values + OmegaConf.resolve(config) + + # download the checkpoint from hdfs + local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path) + + # instantiate tokenizer + tokenizer = AutoTokenizer.from_pretrained(local_path) + from verl.utils import set_pad_token_id + set_pad_token_id(tokenizer) + + # define worker classes + if config.actor_rollout_ref.actor.strategy == 'fsdp': + assert config.actor_rollout_ref.actor.strategy == config.critic.strategy + from verl.trainer.ppo.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker + from single_controller.ray import RayWorkerGroup + ray_worker_group_cls = RayWorkerGroup + + elif config.actor_rollout_ref.actor.strategy == 'megatron': + assert config.actor_rollout_ref.actor.strategy == config.critic.strategy + from verl.trainer.ppo.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker + from single_controller.ray.megatron import NVMegatronRayWorkerGroup + ray_worker_group_cls = NVMegatronRayWorkerGroup + + else: + raise NotImplementedError + + from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role + + role_worker_mapping = { + Role.ActorRollout: ActorRolloutRefWorker, + Role.Critic: CriticWorker, + Role.RefPolicy: ActorRolloutRefWorker + } + + # NOTE: initialze two resource pool + actor_rollout_ref_pool_id = 'actor_rollout_ref_pool' + critic_pool_id = 'critic_pool' + if config.trainer.nnodes // 2 == 0 and config.trainer.n_gpus_per_node // 2 > 0: + resource_pool_spec = { + actor_rollout_ref_pool_id: [config.trainer.n_gpus_per_node // 2] * config.trainer.nnodes, + critic_pool_id: [config.trainer.n_gpus_per_node // 2] * config.trainer.nnodes, + } + else: + resource_pool_spec = { + actor_rollout_ref_pool_id: [config.trainer.n_gpus_per_node] * (config.trainer.nnodes // 2), + critic_pool_id: [config.trainer.n_gpus_per_node] * (config.trainer.nnodes // 2), + } + print(f'resource_pool_spec: {resource_pool_spec}') + mapping = { + Role.ActorRollout: actor_rollout_ref_pool_id, + Role.Critic: critic_pool_id, + Role.RefPolicy: actor_rollout_ref_pool_id, + } + + # we should adopt a multi-source reward function here + # - for rule-based rm, we directly call a reward score + # - for model-based rm, we call a model + # - for code related prompt, we send to a sandbox if there are test cases + # - finally, we combine all the rewards together + # - The reward type depends on the tag of the data + if config.reward_model.enable: + if config.reward_model.strategy == 'fsdp': + from verl.trainer.ppo.workers.fsdp_workers import RewardModelWorker + elif config.reward_model.strategy == 'megatron': + from verl.trainer.ppo.workers.megatron_workers import RewardModelWorker + else: + raise NotImplementedError + role_worker_mapping[Role.RewardModel] = RewardModelWorker + mapping[Role.RewardModel] = critic_pool_id + + reward_fn = RewardManager(tokenizer=tokenizer, num_examine=0) + + # Note that we always use function-based RM for validation + val_reward_fn = RewardManager(tokenizer=tokenizer, num_examine=1) + + resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) + + RayPPOTrainer.fit = fit + trainer = RayPPOTrainer(config=config, + tokenizer=tokenizer, + role_worker_mapping=role_worker_mapping, + resource_pool_manager=resource_pool_manager, + ray_worker_group_cls=ray_worker_group_cls, + reward_fn=reward_fn, + val_reward_fn=val_reward_fn) + trainer.init_workers() + trainer.fit() + + +if __name__ == '__main__': + main() diff --git a/examples/split_placement/run_deepseek7b_llm.sh b/examples/split_placement/run_deepseek7b_llm.sh new file mode 100644 index 00000000..6afd3994 --- /dev/null +++ b/examples/split_placement/run_deepseek7b_llm.sh @@ -0,0 +1,38 @@ +set -x + +python3 main_ppo_split.py \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.train_batch_size=1024 \ + data.val_batch_size=1312 \ + data.max_prompt_length=512 \ + data.max_response_length=512 \ + actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size=16 \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.grad_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=32 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ + actor_rollout_ref.ref.log_prob_micro_batch_size=32 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + critic.optim.lr=1e-5 \ + critic.model.path=deepseek-ai/deepseek-llm-7b-chat \ + critic.model.enable_gradient_checkpointing=False \ + critic.ppo_micro_batch_size=16 \ + critic.model.fsdp_config.param_offload=False \ + critic.model.fsdp_config.grad_offload=False \ + critic.model.fsdp_config.optimizer_offload=False \ + algorithm.kl_ctrl.kl_coef=0.001 \ + trainer.critic_warmup=0 \ + trainer.logger=['console','tracking'] \ + trainer.project_name='verl_example_gsm8k' \ + trainer.experiment_name='deepseek_llm_7b_function_rm' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.total_epochs=15 $@ diff --git a/examples/split_placement/split_monkey_patch.py b/examples/split_placement/split_monkey_patch.py new file mode 100644 index 00000000..19576b11 --- /dev/null +++ b/examples/split_placement/split_monkey_patch.py @@ -0,0 +1,161 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +An naive implementation of split placment example +""" +import os +from pprint import pprint +from single_controller.ray import RayResourcePool, RayWorkerGroup, RayClassWithInitArgs +from verl import DataProto +from verl.trainer.ppo.ray_trainer import compute_advantage, apply_kl_penalty, reduce_metrics, compute_data_metrics, Role, create_colocated_worker_cls +from codetiming import Timer + + +def fit(self): + """ + The training loop of PPO. + The driver process only need to call the compute functions of the worker group through RPC to construct the PPO dataflow. + The light-weight advantage computation is done on the driver process. + """ + from verl.utils.tracking import Tracking + from omegaconf import OmegaConf + + logger = Tracking(project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True)) + + global_steps = 0 + + # perform validation before training + # currently, we only support validation using the reward_function. + if self.val_reward_fn is not None: + val_metrics = self._validate() + pprint(f'Initial validation metrics: {val_metrics}') + + for epoch in range(self.config.trainer.total_epochs): + for batch_dict in self.train_dataloader: + metrics = {} + + batch: DataProto = DataProto.from_single_dict(batch_dict) + # batch = batch.to('cuda') + + # pop those keys for generation + gen_batch = batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids']) + + # generate a batch + with Timer(name='gen', logger=None) as timer: + gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) + metrics['timing/gen'] = timer.last + + batch = batch.union(gen_batch_output) + + if self.use_reference_policy: + # compute reference log_prob + with Timer(name='ref', logger=None) as timer: + ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) + batch = batch.union(ref_log_prob) + metrics['timing/ref'] = timer.last + + # compute values + with Timer(name='values', logger=None) as timer: + values = self.critic_wg.compute_values(batch) + batch = batch.union(values) + metrics['timing/values'] = timer.last + + with Timer(name='adv', logger=None) as timer: + # compute scores. Support both model and function-based. + # We first compute the scores using reward model. Then, we call reward_fn to combine + # the results from reward model and rule-based results. + if self.use_rm: + # we first compute reward model score + reward_tensor = self.rm_wg.compute_rm_score(batch) + batch = batch.union(reward_tensor) + + # we combine with rule-based rm + reward_tensor = self.reward_fn(batch) + batch.batch['token_level_scores'] = reward_tensor + + # compute rewards. apply_kl_penalty if available + batch, kl_metrics = apply_kl_penalty(batch, + kl_ctrl=self.kl_ctrl, + kl_penalty=self.config.algorithm.kl_penalty) + metrics.update(kl_metrics) + + # compute advantages, executed on the driver process + batch = compute_advantage(batch, + self.config.algorithm.gamma, + self.config.algorithm.lam, + adv_estimator=self.config.algorithm.adv_estimator) + metrics['timing/adv'] = timer.last + + # update critic + if self.use_critic: + with Timer(name='update_critic_call', logger=None) as timer: + critic_output = self.critic_wg.update_critic(batch) + metrics['timing/update_critic_call'] = timer.last + + # implement critic warmup + if self.config.trainer.critic_warmup <= global_steps: + # update actor + with Timer(name='update_actor_call', logger=None) as timer: + actor_output = self.actor_rollout_wg.update_actor(batch) + metrics['timing/update_acto_call'] = timer.last + + # NOTE: make sure you set blocking=False in update_actor and update_crtic in the worker class + with Timer(name='update_actor_critic', logger=None) as timer: + # NOTE: get the DataProtoFuture + critic_output = critic_output.get() + critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics']) + metrics.update(critic_output_metrics) + + # NOTE: get the DataProtoFuture + actor_output = actor_output.get() + actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics']) + metrics.update(actor_output_metrics) + metrics['timing/update_actor_critic'] = timer.last + + # validate + if self.val_reward_fn is not None and (global_steps + 1) % self.config.trainer.test_freq == 0: + with Timer(name='testing', logger=None) as timer: + val_metrics: dict = self._validate() + val_metrics = {f'val/{key}': val for key, val in val_metrics.items()} + metrics['timing/testing'] = timer.last + metrics.update(val_metrics) + + # collect metrics + data_metrics = compute_data_metrics(batch=batch) + metrics.update(data_metrics) + + # TODO: make a canonical logger that supports various backend + logger.log(data=metrics, step=global_steps) + + if self.config.trainer.save_freq > 0 and (global_steps + 1) % self.config.trainer.save_freq == 0: + actor_local_path = os.path.join(self.config.trainer.default_local_dir, 'actor', + f'global_step_{global_steps}') + actor_remote_path = os.path.join(self.config.trainer.default_hdfs_dir, 'actor') + self.actor_rollout_wg.save_checkpoint(actor_local_path, actor_remote_path) + + if self.use_critic: + critic_local_path = os.path.join(self.config.trainer.default_local_dir, 'critic', + f'global_step_{global_steps}') + critic_remote_path = os.path.join(self.config.trainer.default_hdfs_dir, 'critic') + self.critic_wg.save_checkpoint(critic_local_path, critic_remote_path) + + global_steps += 1 + + # perform validation after training + if self.val_reward_fn is not None: + val_metrics = self._validate() + pprint(f'Final validation metrics: {val_metrics}') diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 1316e141..cc4a0fd5 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -63,7 +63,7 @@ def create_resource_pool(self): # Due to the Ray issue, we can only support max_colocate_count=1 for now. # This means that each GPU can only have one process. # We can support max_colocate > 1 when applying this pull request: https://github.com/ray-project/ray/pull/44385 - resource_pool = RayResourcePool(process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=1) + resource_pool = RayResourcePool(process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=1, name_prefix=resource_pool_name) self.resource_pool_dict[resource_pool_name] = resource_pool def get_resource_pool(self, role: Role) -> RayResourcePool: From 00540ef4d2160031893f5a62831a2c2ba8fbfd95 Mon Sep 17 00:00:00 2001 From: shengguangming Date: Wed, 11 Dec 2024 18:21:35 +0800 Subject: [PATCH 2/2] lint --- examples/split_placement/main_ppo_split.py | 1 + .../split_placement/split_monkey_patch.py | 20 +++++++++---------- verl/trainer/ppo/ray_trainer.py | 5 ++++- 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/examples/split_placement/main_ppo_split.py b/examples/split_placement/main_ppo_split.py index fb607f9b..524f35e5 100644 --- a/examples/split_placement/main_ppo_split.py +++ b/examples/split_placement/main_ppo_split.py @@ -88,6 +88,7 @@ def __call__(self, data: DataProto): import hydra from split_monkey_patch import fit + @hydra.main(config_path='config', config_name='ppo_trainer_split', version_base=None) def main(config): if not ray.is_initialized(): diff --git a/examples/split_placement/split_monkey_patch.py b/examples/split_placement/split_monkey_patch.py index 19576b11..70ed267d 100644 --- a/examples/split_placement/split_monkey_patch.py +++ b/examples/split_placement/split_monkey_patch.py @@ -32,9 +32,9 @@ def fit(self): from omegaconf import OmegaConf logger = Tracking(project_name=self.config.trainer.project_name, - experiment_name=self.config.trainer.experiment_name, - default_backend=self.config.trainer.logger, - config=OmegaConf.to_container(self.config, resolve=True)) + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True)) global_steps = 0 @@ -89,15 +89,15 @@ def fit(self): # compute rewards. apply_kl_penalty if available batch, kl_metrics = apply_kl_penalty(batch, - kl_ctrl=self.kl_ctrl, - kl_penalty=self.config.algorithm.kl_penalty) + kl_ctrl=self.kl_ctrl, + kl_penalty=self.config.algorithm.kl_penalty) metrics.update(kl_metrics) # compute advantages, executed on the driver process batch = compute_advantage(batch, - self.config.algorithm.gamma, - self.config.algorithm.lam, - adv_estimator=self.config.algorithm.adv_estimator) + self.config.algorithm.gamma, + self.config.algorithm.lam, + adv_estimator=self.config.algorithm.adv_estimator) metrics['timing/adv'] = timer.last # update critic @@ -112,7 +112,7 @@ def fit(self): with Timer(name='update_actor_call', logger=None) as timer: actor_output = self.actor_rollout_wg.update_actor(batch) metrics['timing/update_acto_call'] = timer.last - + # NOTE: make sure you set blocking=False in update_actor and update_crtic in the worker class with Timer(name='update_actor_critic', logger=None) as timer: # NOTE: get the DataProtoFuture @@ -149,7 +149,7 @@ def fit(self): if self.use_critic: critic_local_path = os.path.join(self.config.trainer.default_local_dir, 'critic', - f'global_step_{global_steps}') + f'global_step_{global_steps}') critic_remote_path = os.path.join(self.config.trainer.default_hdfs_dir, 'critic') self.critic_wg.save_checkpoint(critic_local_path, critic_remote_path) diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index cc4a0fd5..95814f57 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -63,7 +63,10 @@ def create_resource_pool(self): # Due to the Ray issue, we can only support max_colocate_count=1 for now. # This means that each GPU can only have one process. # We can support max_colocate > 1 when applying this pull request: https://github.com/ray-project/ray/pull/44385 - resource_pool = RayResourcePool(process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=1, name_prefix=resource_pool_name) + resource_pool = RayResourcePool(process_on_nodes=process_on_nodes, + use_gpu=True, + max_colocate_count=1, + name_prefix=resource_pool_name) self.resource_pool_dict[resource_pool_name] = resource_pool def get_resource_pool(self, role: Role) -> RayResourcePool: