
Commit

merge with master
eric-haibin-lin committed Dec 18, 2024
2 parents c70cb24 + f535977 · commit 6806b4b
Showing 39 changed files with 66 additions and 85 deletions.
16 changes: 11 additions & 5 deletions README.md
@@ -21,7 +21,7 @@ veRL is fast with:


<p align="center">
-| <a href="https://verl.readthedocs.io/en/latest/index.html"><b>Documentation</b></a> | <a href="https://arxiv.org/abs/2409.19256v2"><b>Paper</b></a> |
+| <a href="https://verl.readthedocs.io/en/latest/index.html"><b>Documentation</b></a> | <a href="https://arxiv.org/abs/2409.19256v2"><b>Paper</b></a> | <a href="https://join.slack.com/t/verlgroup/shared_invite/zt-2w5p9o4c3-yy0x2Q56s_VlGLsJ93A6vA"><b>Slack</b></a> |
<!-- <a href=""><b>Slides</b></a> | -->
</p>

@@ -150,16 +150,22 @@ export PYTHONPATH=$PYTHONPATH:$(pwd)
## Getting Started
Visit our [documentation](https://verl.readthedocs.io/en/latest/index.html) to learn more.

-**Running an PPO example should follow:**
-- Preparation
-- [Installation](https://verl.readthedocs.io/en/latest/preparation/install.html)
+**Quickstart:**
+- [Installation](https://verl.readthedocs.io/en/latest/preparation/install.html)
+- [Quickstart](https://verl.readthedocs.io/en/latest/start/quickstart.html)
+
+**Running a PPO example step-by-step:**
+- Data and Reward Preparation
- [Prepare Data (Parquet) for Post-Training](https://verl.readthedocs.io/en/latest/preparation/prepare_data.html)
- [Implement Reward Function for Dataset](https://verl.readthedocs.io/en/latest/preparation/reward_function.html)
-- PPO Example (Run an example)
+- Understanding the PPO Example
- [PPO Example Architecture](https://verl.readthedocs.io/en/latest/examples/ppo_code_architecture.html)
- [Config Explanation](https://verl.readthedocs.io/en/latest/examples/config.html)
- [Run GSM8K Example](https://verl.readthedocs.io/en/latest/examples/gsm8k_example.html)

+**Reproducible algorithm baselines:**
+- [PPO](https://verl.readthedocs.io/en/latest/experiment/ppo.html)

**For code explanation and advanced usage (extension):**
- PPO Trainer and Workers
- [PPO Ray Trainer](https://verl.readthedocs.io/en/latest/workers/ray_trainer.html)
6 changes: 3 additions & 3 deletions docs/examples/ppo_code_architecture.rst
@@ -48,13 +48,13 @@ Define worker classes
if config.actor_rollout_ref.actor.strategy == 'fsdp': # for FSDP backend
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
-from verl.trainer.ppo.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
+from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
from verl.single_controller.ray import RayWorkerGroup
ray_worker_group_cls = RayWorkerGroup
elif config.actor_rollout_ref.actor.strategy == 'megatron': # for Megatron backend
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
-from verl.trainer.ppo.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
+from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
ray_worker_group_cls = NVMegatronRayWorkerGroup # Ray worker class for Megatron-LM
@@ -139,7 +139,7 @@ Defining reward model/function
# - finally, we combine all the rewards together
# - The reward type depends on the tag of the data
if config.reward_model.enable:
-from verl.trainer.ppo.workers.fsdp_workers import RewardModelWorker
+from verl.workers.fsdp_workers import RewardModelWorker
role_worker_mapping[Role.RewardModel] = RewardModelWorker
mapping[Role.RewardModel] = global_pool_id
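For orientation, here is a minimal sketch of the strategy dispatch that these hunks update, using only the new `verl.workers` module paths shown above. The helper function and its return shape are illustrative assumptions, not verl's actual API, and it assumes verl is installed.

```python
# Hypothetical helper, for illustration only: maps the configured strategy to the
# worker classes under the new verl.workers package (module paths from the diff).
def resolve_worker_classes(strategy: str):
    if strategy == 'fsdp':
        from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
        from verl.single_controller.ray import RayWorkerGroup
        return ActorRolloutRefWorker, CriticWorker, RayWorkerGroup
    elif strategy == 'megatron':
        from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
        from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
        return ActorRolloutRefWorker, CriticWorker, NVMegatronRayWorkerGroup
    raise NotImplementedError(f"unsupported strategy: {strategy}")
```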
8 changes: 4 additions & 4 deletions examples/split_placement/main_ppo_split.py
@@ -119,13 +119,13 @@ def main_task(config):
# define worker classes
if config.actor_rollout_ref.actor.strategy == 'fsdp':
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
-from verl.trainer.ppo.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
+from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
from verl.single_controller.ray import RayWorkerGroup
ray_worker_group_cls = RayWorkerGroup

elif config.actor_rollout_ref.actor.strategy == 'megatron':
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
-from verl.trainer.ppo.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
+from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
ray_worker_group_cls = NVMegatronRayWorkerGroup

@@ -168,9 +168,9 @@ def main_task(config):
# - The reward type depends on the tag of the data
if config.reward_model.enable:
if config.reward_model.strategy == 'fsdp':
-from verl.trainer.ppo.workers.fsdp_workers import RewardModelWorker
+from verl.workers.fsdp_workers import RewardModelWorker
elif config.reward_model.strategy == 'megatron':
-from verl.trainer.ppo.workers.megatron_workers import RewardModelWorker
+from verl.workers.megatron_workers import RewardModelWorker
else:
raise NotImplementedError
role_worker_mapping[Role.RewardModel] = RewardModelWorker
20 changes: 11 additions & 9 deletions requirements.txt
@@ -1,10 +1,12 @@
-transformers
-hydra-core
-tensordict==0.5.0
-numpy
-pytest
-pybind11
+# vllm==0.6.3 # vllm is installed in image building to avoid ray conflicts
+# TODO: add version info to requirements
+accelerate
codetiming
-yapf
-wandb
-git+https://github.com/NVIDIA/TransformerEngine.git@stable
+datasets
+dill
+hydra-core
+numpy
+pybind11
+ray==2.10
+tensordict
+transformers
22 changes: 6 additions & 16 deletions setup.py
@@ -20,29 +20,19 @@
with open(os.path.join(version_folder, 'verl/version/version')) as f:
__version__ = f.read().strip()

-# TODO: add version info to requirements
-install_requires = [
-'torch==2.4.0',
-'tensordict',
-'transformers',
-'codetiming',
-'pybind11',
-'hydra-core',
-'numpy',
-'yapf',
-"dill",
-"accelerate"
-]

+with open('requirements.txt') as f:
+    required = f.read().splitlines()
+    install_requires = [item.strip() for item in required if item.strip()[0] != '#']

install_optional = [
'vllm==0.6.3',
'torch==2.4.0', # required by vllm
]

extras_require = {
'demo': ['hydra-core', 'transformers', ''],
'single-controller': ['ray', 'kubernetes'],
'single-controller-ray': ['ray'],
-'test': ['fsspec', 'pytest', 'datasets']
+'test': ['pytest']
}

from pathlib import Path
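The setup.py hunk above replaces the hard-coded dependency list with a read of requirements.txt. Below is a standalone sketch of the same pattern; the helper name is hypothetical, and unlike the one-liner above it also skips blank lines so `item.strip()[0]` cannot fail on an empty entry.

```python
# Hypothetical standalone version of the requirements-parsing pattern shown above.
def parse_requirements(path: str = 'requirements.txt') -> list[str]:
    with open(path) as f:
        lines = f.read().splitlines()
    # Drop blank lines and comment lines, keep everything else verbatim.
    return [line.strip() for line in lines
            if line.strip() and not line.strip().startswith('#')]


if __name__ == '__main__':
    print(parse_requirements())
```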
4 changes: 2 additions & 2 deletions verl/third_party/vllm/vllm_v_0_3_1/llm.py
@@ -27,7 +27,7 @@
from vllm.utils import Counter
import torch
from torch.nn.utils.rnn import pad_sequence
-from verl.trainer.ppo.rollout.tokenizer import HybridEngineBaseTokenizer
+from verl.workers.rollout.tokenizer import HybridEngineBaseTokenizer


class LLM:
@@ -125,7 +125,7 @@ def __init__(
if not isinstance(tokenizer, tokenizer_cls):
raise ValueError(
f"Unexpected tokenizer type: {type(tokenizer)}. Must be"
"one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.trainer.ppo.rollout.HybridEngineBaseTokenizer"
"one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.workers.rollout.HybridEngineBaseTokenizer"
)
self.llm_engine = LLMEngine.from_engine_args(model, tokenizer, engine_args)
self.request_counter = Counter()
4 changes: 2 additions & 2 deletions verl/third_party/vllm/vllm_v_0_4_2/llm.py
@@ -29,7 +29,7 @@
from vllm.utils import Counter
import torch
from torch.nn.utils.rnn import pad_sequence
-from verl.trainer.ppo.rollout.tokenizer import HybridEngineBaseTokenizer
+from verl.workers.rollout.tokenizer import HybridEngineBaseTokenizer


class LLM:
@@ -129,7 +129,7 @@ def __init__(
if not isinstance(tokenizer, tokenizer_cls):
raise ValueError(
f"Unexpected tokenizer type: {type(tokenizer)}. Must be"
"one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.trainer.ppo.rollout.HybridEngineBaseTokenizer"
"one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.workers.rollout.HybridEngineBaseTokenizer"
)
self.llm_engine = LLMEngine.from_engine_args(model, tokenizer, engine_args)
self.request_counter = Counter()
4 changes: 2 additions & 2 deletions verl/third_party/vllm/vllm_v_0_5_4/llm.py
@@ -37,7 +37,7 @@
from vllm.utils import Counter, deprecate_kwargs
import torch
from torch.nn.utils.rnn import pad_sequence
-from verl.trainer.ppo.rollout.tokenizer import HybridEngineBaseTokenizer
+from verl.workers.rollout.tokenizer import HybridEngineBaseTokenizer


class LLM(LLM):
@@ -143,7 +143,7 @@ def __init__(
if not isinstance(tokenizer, tokenizer_cls):
raise ValueError(
f"Unexpected tokenizer type: {type(tokenizer)}. Must be"
"one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.trainer.ppo.rollout.HybridEngineBaseTokenizer"
"one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.workers.rollout.HybridEngineBaseTokenizer"
)
self.llm_engine = LLMEngine.from_engine_args(model, tokenizer, engine_args) # TODO: check usagecontext
self.request_counter = Counter()
4 changes: 2 additions & 2 deletions verl/third_party/vllm/vllm_v_0_6_3/llm.py
@@ -19,7 +19,7 @@
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from transformers import PretrainedConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
-from verl.trainer.ppo.rollout.tokenizer import HybridEngineBaseTokenizer
+from verl.workers.rollout.tokenizer import HybridEngineBaseTokenizer
from vllm import LLM
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.utils import Counter
@@ -137,7 +137,7 @@ def __init__(
if not isinstance(tokenizer, tokenizer_cls):
raise ValueError(
f"Unexpected tokenizer type: {type(tokenizer)}. Must be"
"one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.trainer.ppo.rollout.HybridEngineBaseTokenizer"
"one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.workers.rollout.HybridEngineBaseTokenizer"
)
self.llm_engine = LLMEngine.from_engine_args(model, tokenizer, engine_args) # TODO: check usagecontext
self.request_counter = Counter()
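The four llm.py hunks above all update the same tokenizer type check. A condensed sketch of that pattern is shown below; the class names and module path come from the diff, while the tuple construction and the standalone function are illustrative assumptions.

```python
# Illustrative sketch of the tokenizer validation repeated in the llm.py hunks above.
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from verl.workers.rollout.tokenizer import HybridEngineBaseTokenizer


def check_tokenizer(tokenizer):
    # Accept either a HuggingFace tokenizer or verl's hybrid-engine tokenizer wrapper.
    tokenizer_cls = (PreTrainedTokenizer, PreTrainedTokenizerFast, HybridEngineBaseTokenizer)
    if not isinstance(tokenizer, tokenizer_cls):
        raise ValueError(
            f"Unexpected tokenizer type: {type(tokenizer)}. Must be "
            "one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, "
            "verl.workers.rollout.HybridEngineBaseTokenizer")
    return tokenizer
```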
2 changes: 1 addition & 1 deletion verl/trainer/main_generation.py
@@ -31,7 +31,7 @@

from verl import DataProto
from verl.utils.fs import copy_local_path_from_hdfs
-from verl.trainer.ppo.workers.fsdp_workers import ActorRolloutRefWorker
+from verl.workers.fsdp_workers import ActorRolloutRefWorker
from verl.utils.hdfs_io import makedirs
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup

8 changes: 4 additions & 4 deletions verl/trainer/main_ppo.py
@@ -120,13 +120,13 @@ def main_task(config):
# define worker classes
if config.actor_rollout_ref.actor.strategy == 'fsdp':
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
-from verl.trainer.ppo.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
+from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
from verl.single_controller.ray import RayWorkerGroup
ray_worker_group_cls = RayWorkerGroup

elif config.actor_rollout_ref.actor.strategy == 'megatron':
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
-from verl.trainer.ppo.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
+from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
ray_worker_group_cls = NVMegatronRayWorkerGroup

@@ -159,9 +159,9 @@ def main_task(config):
# - The reward type depends on the tag of the data
if config.reward_model.enable:
if config.reward_model.strategy == 'fsdp':
-from verl.trainer.ppo.workers.fsdp_workers import RewardModelWorker
+from verl.workers.fsdp_workers import RewardModelWorker
elif config.reward_model.strategy == 'megatron':
-from verl.trainer.ppo.workers.megatron_workers import RewardModelWorker
+from verl.workers.megatron_workers import RewardModelWorker
else:
raise NotImplementedError
role_worker_mapping[Role.RewardModel] = RewardModelWorker
17 changes: 0 additions & 17 deletions verl/trainer/ppo/rollout/megatron/__init__.py

This file was deleted.

File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -22,7 +22,7 @@

from verl import DataProto
from verl.trainer.ppo import core_algos
-from verl.trainer.ppo.actor import BasePPOActor
+from verl.workers.actor import BasePPOActor
from verl.utils.py_functional import append_to_dict
from verl.utils.torch_functional import logprobs_from_logits

@@ -38,7 +38,7 @@
from verl.utils.megatron.pipeline_parallel import (compute_transformers_input_shapes, make_batch_generator)
from verl import DataProto
from verl.trainer.ppo import core_algos
-from verl.trainer.ppo.actor import BasePPOActor
+from verl.workers.actor import BasePPOActor
from verl.utils.py_functional import append_to_dict
from verl.utils.torch_functional import logprobs_from_logits, broadcast_dict_tensor, split_dict_tensor_into_batches

File renamed without changes.
File renamed without changes.
@@ -25,7 +25,7 @@

from verl import DataProto
from verl.trainer.ppo import core_algos
-from verl.trainer.ppo.critic import BasePPOCritic
+from verl.workers.critic import BasePPOCritic
from verl.utils.py_functional import append_to_dict
from verl.utils.torch_functional import masked_mean

@@ -25,7 +25,7 @@

from verl import DataProto
from verl.trainer.ppo import core_algos
-from verl.trainer.ppo.critic import BasePPOCritic
+from verl.workers.critic import BasePPOCritic
from verl.utils.megatron.pipeline_parallel import (compute_transformers_input_shapes, make_batch_generator)
from verl.utils.py_functional import append_to_dict
from verl.utils.torch_dtypes import PrecisionType
@@ -27,7 +27,6 @@
from verl.single_controller.base.decorator import register, Dispatch
import verl.utils.torch_functional as verl_F
from verl import DataProto
-from verl.trainer.ppo.actor import DataParallelPPOActor
from verl.utils.model import compute_position_id_with_mask
from verl.utils.fs import copy_local_path_from_hdfs
from verl.utils.fsdp_utils import get_fsdp_wrap_policy, load_fsdp_grad, offload_fsdp_grad, init_fn, get_init_weight_context_manager
@@ -220,14 +219,14 @@ def _build_model_optimizer(self,

def _build_rollout(self):
if self.config.rollout.name == 'hf':
-from verl.trainer.ppo.rollout import HFRollout
-from verl.trainer.ppo.hybrid_engine import BaseShardingManager
+from verl.workers.rollout import HFRollout
+from verl.workers.hybrid_engine import BaseShardingManager
rollout = HFRollout(module=self.actor_module_fsdp, config=self.config.rollout)
sharding_manager = BaseShardingManager()
# TODO: a sharding manager that do nothing?
elif self.config.rollout.name == 'vllm':
-from verl.trainer.ppo.rollout.vllm_rollout import vLLMRollout
-from verl.trainer.ppo.hybrid_engine import FSDPVLLMShardingManager
+from verl.workers.rollout.vllm_rollout import vLLMRollout
+from verl.workers.hybrid_engine import FSDPVLLMShardingManager
log_gpu_memory_usage('Before building vllm rollout', logger=None)
rollout = vLLMRollout(actor_module=self.actor_module_fsdp,
config=self.config.rollout,
@@ -246,6 +245,7 @@ def _build_rollout(self):

@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
+from verl.workers.actor import DataParallelPPOActor
# This is used to import external_lib into the huggingface systems
import_external_libs(self.config.model.get('external_lib', None))

@@ -557,7 +557,7 @@ def init_model(self):
# This is used to import external_lib into the huggingface systems
import_external_libs(self.config.model.get('external_lib', None))

-from verl.trainer.ppo.critic import DataParallelPPOCritic
+from verl.workers.critic import DataParallelPPOCritic
self.critic_module, self.critic_optimizer, self.critic_lr_scheduler = self._build_critic_model_optimizer(
self.config)

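Similarly, a minimal sketch of the rollout-backend selection updated in this file's `_build_rollout` hunk, again using only module paths visible in the diff; the helper function and returning classes rather than constructed objects are illustrative assumptions.

```python
# Hypothetical helper: maps config.rollout.name to the rollout and sharding-manager
# classes under the new verl.workers package (module paths from the diff above).
def resolve_rollout_classes(rollout_name: str):
    if rollout_name == 'hf':
        from verl.workers.rollout import HFRollout
        from verl.workers.hybrid_engine import BaseShardingManager
        return HFRollout, BaseShardingManager
    elif rollout_name == 'vllm':
        from verl.workers.rollout.vllm_rollout import vLLMRollout
        from verl.workers.hybrid_engine import FSDPVLLMShardingManager
        return vLLMRollout, FSDPVLLMShardingManager
    raise NotImplementedError(f"unsupported rollout backend: {rollout_name}")
```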
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -23,10 +23,10 @@
import torch.nn as nn
from omegaconf import DictConfig
from verl.single_controller.base.megatron.worker import MegatronWorker
-from verl.trainer.ppo.actor.megatron_actor import MegatronPPOActor
-from verl.trainer.ppo.critic.megatron_critic import MegatronPPOCritic
-from verl.trainer.ppo.hybrid_engine import AllGatherPPModel
-from verl.trainer.ppo.reward_model.megatron.reward_model import MegatronRewardModel
+from verl.workers.actor.megatron_actor import MegatronPPOActor
+from verl.workers.critic.megatron_critic import MegatronPPOCritic
+from verl.workers.hybrid_engine import AllGatherPPModel
+from verl.workers.reward_model.megatron.reward_model import MegatronRewardModel

from verl.single_controller.base.decorator import register, Dispatch
from verl import DataProto
@@ -216,8 +216,8 @@ def megatron_actor_model_provider(pre_process, post_process):

def _build_rollout(self):
if self.config.rollout.name == 'vllm':
-from verl.trainer.ppo.rollout.vllm_rollout import vLLMRollout
-from verl.trainer.ppo.hybrid_engine import MegatronVLLMShardingManager
+from verl.workers.rollout.vllm_rollout import vLLMRollout
+from verl.workers.hybrid_engine import MegatronVLLMShardingManager
from verl.utils.model import normalize_pp_vpp_params

# NOTE(sgm): If the QKV and gate_up projection layer are concate together in actor,
File renamed without changes.
File renamed without changes.
@@ -28,7 +28,7 @@
from verl import DataProto
from verl.utils.torch_functional import logprobs_from_logits, broadcast_dict_tensor, split_dict_tensor_into_batches
from verl.utils.torch_dtypes import PrecisionType
-from verl.trainer.ppo.reward_model.base import BasePPORewardModel
+from verl.workers.reward_model.base import BasePPORewardModel
from verl.utils.megatron import sequence_parallel as sp_utils
from megatron.core import parallel_state as mpu
from megatron.core.pipeline_parallel import get_forward_backward_func
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -34,7 +34,7 @@

from verl import DataProto
from verl.utils.torch_functional import get_eos_mask, pad_sequence_to_length
-from verl.trainer.ppo.rollout.base import BaseRollout
+from verl.workers.rollout.base import BaseRollout
from verl.third_party.vllm import LLM, vllm_version
from verl.third_party.vllm import parallel_state as vllm_ps
from vllm import SamplingParams
