diff --git a/README.md b/README.md
index 5df1912..caac101 100644
--- a/README.md
+++ b/README.md
@@ -21,7 +21,7 @@ veRL is fast with:

-| Documentation | Paper |
+| Documentation | Paper | Slack |

@@ -150,16 +150,22 @@ export PYTHONPATH=$PYTHONPATH:$(pwd)
 ## Getting Started
 Visit our [documentation](https://verl.readthedocs.io/en/latest/index.html) to learn more.
 
-**Running an PPO example should follow:**
-- Preparation
-  - [Installation](https://verl.readthedocs.io/en/latest/preparation/install.html)
+**Quickstart:**
+- [Installation](https://verl.readthedocs.io/en/latest/preparation/install.html)
+- [Quickstart](https://verl.readthedocs.io/en/latest/start/quickstart.html)
+
+**Running a PPO example step-by-step:**
+- Data and Reward Preparation
   - [Prepare Data (Parquet) for Post-Training](https://verl.readthedocs.io/en/latest/preparation/prepare_data.html)
   - [Implement Reward Function for Dataset](https://verl.readthedocs.io/en/latest/preparation/reward_function.html)
-- PPO Example (Run an example)
+- Understanding the PPO Example
   - [PPO Example Architecture](https://verl.readthedocs.io/en/latest/examples/ppo_code_architecture.html)
   - [Config Explanation](https://verl.readthedocs.io/en/latest/examples/config.html)
   - [Run GSM8K Example](https://verl.readthedocs.io/en/latest/examples/gsm8k_example.html)
 
+**Reproducible algorithm baselines:**
+- [PPO](https://verl.readthedocs.io/en/latest/experiment/ppo.html)
+
 **For code explanation and advance usage (extension):**
 - PPO Trainer and Workers
   - [PPO Ray Trainer](https://verl.readthedocs.io/en/latest/workers/ray_trainer.html)
diff --git a/docs/examples/ppo_code_architecture.rst b/docs/examples/ppo_code_architecture.rst
index 1cca830..6dbafae 100644
--- a/docs/examples/ppo_code_architecture.rst
+++ b/docs/examples/ppo_code_architecture.rst
@@ -48,13 +48,13 @@ Define worker classes
     if config.actor_rollout_ref.actor.strategy == 'fsdp':  # for FSDP backend
         assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
-        from verl.trainer.ppo.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
+        from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
         from verl.single_controller.ray import RayWorkerGroup
         ray_worker_group_cls = RayWorkerGroup
 
     elif config.actor_rollout_ref.actor.strategy == 'megatron':  # for Megatron backend
         assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
-        from verl.trainer.ppo.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
+        from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
         from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
         ray_worker_group_cls = NVMegatronRayWorkerGroup  # Ray worker class for Megatron-LM
 
@@ -139,7 +139,7 @@ Defining reward model/function
     # - finally, we combine all the rewards together
     # - The reward type depends on the tag of the data
     if config.reward_model.enable:
-        from verl.trainer.ppo.workers.fsdp_workers import RewardModelWorker
+        from verl.workers.fsdp_workers import RewardModelWorker
         role_worker_mapping[Role.RewardModel] = RewardModelWorker
         mapping[Role.RewardModel] = global_pool_id
 
diff --git a/examples/split_placement/main_ppo_split.py b/examples/split_placement/main_ppo_split.py
index 1c608a7..27a7b35 100644
--- a/examples/split_placement/main_ppo_split.py
+++ b/examples/split_placement/main_ppo_split.py
@@ -119,13 +119,13 @@ def main_task(config):
     # define worker classes
     if config.actor_rollout_ref.actor.strategy == 'fsdp':
         assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
-        from verl.trainer.ppo.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
+        from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
         from verl.single_controller.ray import RayWorkerGroup
         ray_worker_group_cls = RayWorkerGroup
 
     elif config.actor_rollout_ref.actor.strategy == 'megatron':
         assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
-        from verl.trainer.ppo.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
+        from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
         from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
         ray_worker_group_cls = NVMegatronRayWorkerGroup
 
@@ -168,9 +168,9 @@ def main_task(config):
     # - The reward type depends on the tag of the data
     if config.reward_model.enable:
         if config.reward_model.strategy == 'fsdp':
-            from verl.trainer.ppo.workers.fsdp_workers import RewardModelWorker
+            from verl.workers.fsdp_workers import RewardModelWorker
         elif config.reward_model.strategy == 'megatron':
-            from verl.trainer.ppo.workers.megatron_workers import RewardModelWorker
+            from verl.workers.megatron_workers import RewardModelWorker
         else:
             raise NotImplementedError
         role_worker_mapping[Role.RewardModel] = RewardModelWorker
diff --git a/requirements.txt b/requirements.txt
index ca102e9..acc670a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,10 +1,12 @@
-transformers
-hydra-core
-tensordict==0.5.0
-numpy
-pytest
-pybind11
+# vllm==0.6.3 # vllm is installed in image building to avoid ray conflicts
+# TODO: add version info to requirements
+accelerate
 codetiming
-yapf
-wandb
-git+https://github.com/NVIDIA/TransformerEngine.git@stable
\ No newline at end of file
+datasets
+dill
+hydra-core
+numpy
+pybind11
+ray==2.10
+tensordict
+transformers
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 9a97a50..8954ef2 100644
--- a/setup.py
+++ b/setup.py
@@ -20,29 +20,19 @@
 with open(os.path.join(version_folder, 'verl/version/version')) as f:
     __version__ = f.read().strip()
 
-# TODO: add version info to requirements
-install_requires = [
-    'torch==2.4.0',
-    'tensordict',
-    'transformers',
-    'codetiming',
-    'pybind11',
-    'hydra-core',
-    'numpy',
-    'yapf',
-    "dill",
-    "accelerate"
-]
+
+with open('requirements.txt') as f:
+    required = f.read().splitlines()
+    install_requires = [item.strip() for item in required if item.strip()[0] != '#']
 
 install_optional = [
     'vllm==0.6.3',
+    'torch==2.4.0', # required by vllm
 ]
 
 extras_require = {
-    'demo': ['hydra-core', 'transformers', ''],
     'single-controller': ['ray', 'kubernetes'],
-    'single-controller-ray': ['ray'],
-    'test': ['fsspec', 'pytest', 'datasets']
+    'test': ['pytest']
 }
 
 from pathlib import Path
diff --git a/verl/third_party/vllm/vllm_v_0_3_1/llm.py b/verl/third_party/vllm/vllm_v_0_3_1/llm.py
index ab71654..8d24759 100644
--- a/verl/third_party/vllm/vllm_v_0_3_1/llm.py
+++ b/verl/third_party/vllm/vllm_v_0_3_1/llm.py
@@ -27,7 +27,7 @@ from vllm.utils import Counter
 import torch
 from torch.nn.utils.rnn import pad_sequence
 
-from verl.trainer.ppo.rollout.tokenizer import HybridEngineBaseTokenizer
+from verl.workers.rollout.tokenizer import HybridEngineBaseTokenizer
 
 
 class LLM:
@@ -125,7 +125,7 @@ def __init__(
         if not isinstance(tokenizer, tokenizer_cls):
             raise ValueError(
                 f"Unexpected tokenizer type: {type(tokenizer)}. Must be"
-                "one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.trainer.ppo.rollout.HybridEngineBaseTokenizer"
+                "one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.workers.rollout.HybridEngineBaseTokenizer"
             )
         self.llm_engine = LLMEngine.from_engine_args(model, tokenizer, engine_args)
         self.request_counter = Counter()
diff --git a/verl/third_party/vllm/vllm_v_0_4_2/llm.py b/verl/third_party/vllm/vllm_v_0_4_2/llm.py
index b042525..94623a4 100644
--- a/verl/third_party/vllm/vllm_v_0_4_2/llm.py
+++ b/verl/third_party/vllm/vllm_v_0_4_2/llm.py
@@ -29,7 +29,7 @@ from vllm.utils import Counter
 import torch
 from torch.nn.utils.rnn import pad_sequence
 
-from verl.trainer.ppo.rollout.tokenizer import HybridEngineBaseTokenizer
+from verl.workers.rollout.tokenizer import HybridEngineBaseTokenizer
 
 
 class LLM:
@@ -129,7 +129,7 @@ def __init__(
         if not isinstance(tokenizer, tokenizer_cls):
             raise ValueError(
                 f"Unexpected tokenizer type: {type(tokenizer)}. Must be"
-                "one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.trainer.ppo.rollout.HybridEngineBaseTokenizer"
+                "one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.workers.rollout.HybridEngineBaseTokenizer"
             )
         self.llm_engine = LLMEngine.from_engine_args(model, tokenizer, engine_args)
         self.request_counter = Counter()
diff --git a/verl/third_party/vllm/vllm_v_0_5_4/llm.py b/verl/third_party/vllm/vllm_v_0_5_4/llm.py
index 0979cdc..5f56f1e 100644
--- a/verl/third_party/vllm/vllm_v_0_5_4/llm.py
+++ b/verl/third_party/vllm/vllm_v_0_5_4/llm.py
@@ -37,7 +37,7 @@ from vllm.utils import Counter, deprecate_kwargs
 import torch
 from torch.nn.utils.rnn import pad_sequence
 
-from verl.trainer.ppo.rollout.tokenizer import HybridEngineBaseTokenizer
+from verl.workers.rollout.tokenizer import HybridEngineBaseTokenizer
 
 
 class LLM(LLM):
@@ -143,7 +143,7 @@ def __init__(
         if not isinstance(tokenizer, tokenizer_cls):
             raise ValueError(
                 f"Unexpected tokenizer type: {type(tokenizer)}. Must be"
-                "one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.trainer.ppo.rollout.HybridEngineBaseTokenizer"
+                "one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.workers.rollout.HybridEngineBaseTokenizer"
             )
         self.llm_engine = LLMEngine.from_engine_args(model, tokenizer, engine_args)  # TODO: check usagecontext
         self.request_counter = Counter()
diff --git a/verl/third_party/vllm/vllm_v_0_6_3/llm.py b/verl/third_party/vllm/vllm_v_0_6_3/llm.py
index 9351457..cd3d646 100644
--- a/verl/third_party/vllm/vllm_v_0_6_3/llm.py
+++ b/verl/third_party/vllm/vllm_v_0_6_3/llm.py
@@ -19,7 +19,7 @@ import torch.nn as nn
 from torch.nn.utils.rnn import pad_sequence
 from transformers import PretrainedConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
 
-from verl.trainer.ppo.rollout.tokenizer import HybridEngineBaseTokenizer
+from verl.workers.rollout.tokenizer import HybridEngineBaseTokenizer
 from vllm import LLM
 from vllm.outputs import EmbeddingRequestOutput, RequestOutput
 from vllm.utils import Counter
@@ -137,7 +137,7 @@ def __init__(
         if not isinstance(tokenizer, tokenizer_cls):
             raise ValueError(
                 f"Unexpected tokenizer type: {type(tokenizer)}. Must be"
-                "one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.trainer.ppo.rollout.HybridEngineBaseTokenizer"
+                "one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.workers.rollout.HybridEngineBaseTokenizer"
             )
         self.llm_engine = LLMEngine.from_engine_args(model, tokenizer, engine_args)  # TODO: check usagecontext
         self.request_counter = Counter()
diff --git a/verl/trainer/main_generation.py b/verl/trainer/main_generation.py
index 47677d3..b0bc7d7 100644
--- a/verl/trainer/main_generation.py
+++ b/verl/trainer/main_generation.py
@@ -31,7 +31,7 @@
 from verl import DataProto
 from verl.utils.fs import copy_local_path_from_hdfs
 
-from verl.trainer.ppo.workers.fsdp_workers import ActorRolloutRefWorker
+from verl.workers.fsdp_workers import ActorRolloutRefWorker
 from verl.utils.hdfs_io import makedirs
 from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
 
diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py
index 2e664fe..325e9be 100644
--- a/verl/trainer/main_ppo.py
+++ b/verl/trainer/main_ppo.py
@@ -120,13 +120,13 @@ def main_task(config):
     # define worker classes
     if config.actor_rollout_ref.actor.strategy == 'fsdp':
         assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
-        from verl.trainer.ppo.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
+        from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
         from verl.single_controller.ray import RayWorkerGroup
         ray_worker_group_cls = RayWorkerGroup
 
     elif config.actor_rollout_ref.actor.strategy == 'megatron':
         assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
-        from verl.trainer.ppo.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
+        from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
         from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
         ray_worker_group_cls = NVMegatronRayWorkerGroup
 
@@ -159,9 +159,9 @@ def main_task(config):
     # - The reward type depends on the tag of the data
     if config.reward_model.enable:
         if config.reward_model.strategy == 'fsdp':
-            from verl.trainer.ppo.workers.fsdp_workers import RewardModelWorker
+            from verl.workers.fsdp_workers import RewardModelWorker
         elif config.reward_model.strategy == 'megatron':
-            from verl.trainer.ppo.workers.megatron_workers import RewardModelWorker
+            from verl.workers.megatron_workers import RewardModelWorker
         else:
             raise NotImplementedError
         role_worker_mapping[Role.RewardModel] = RewardModelWorker
diff --git a/verl/trainer/ppo/rollout/megatron/__init__.py b/verl/trainer/ppo/rollout/megatron/__init__.py
deleted file mode 100644
index d1b842b..0000000
--- a/verl/trainer/ppo/rollout/megatron/__init__.py
+++ /dev/null
@@ -1,17 +0,0 @@
-# Copyright 2024 Bytedance Ltd. and/or its affiliates
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from .allgather_pp_model import AllGatherPPModel
-from .hybrid_engine_naive_rollout import MegatronHybridEngineNaiveRollout
-from .naive_rollout import MegatronNaiveRollout
diff --git a/verl/trainer/ppo/workers/__init__.py b/verl/workers/__init__.py
similarity index 100%
rename from verl/trainer/ppo/workers/__init__.py
rename to verl/workers/__init__.py
diff --git a/verl/trainer/ppo/actor/__init__.py b/verl/workers/actor/__init__.py
similarity index 100%
rename from verl/trainer/ppo/actor/__init__.py
rename to verl/workers/actor/__init__.py
diff --git a/verl/trainer/ppo/actor/base.py b/verl/workers/actor/base.py
similarity index 100%
rename from verl/trainer/ppo/actor/base.py
rename to verl/workers/actor/base.py
diff --git a/verl/trainer/ppo/actor/dp_actor.py b/verl/workers/actor/dp_actor.py
similarity index 99%
rename from verl/trainer/ppo/actor/dp_actor.py
rename to verl/workers/actor/dp_actor.py
index 885ff00..e0001c9 100644
--- a/verl/trainer/ppo/actor/dp_actor.py
+++ b/verl/workers/actor/dp_actor.py
@@ -22,7 +22,7 @@
 
 from verl import DataProto
 from verl.trainer.ppo import core_algos
-from verl.trainer.ppo.actor import BasePPOActor
+from verl.workers.actor import BasePPOActor
 from verl.utils.py_functional import append_to_dict
 from verl.utils.torch_functional import logprobs_from_logits
diff --git a/verl/trainer/ppo/actor/megatron_actor.py b/verl/workers/actor/megatron_actor.py
similarity index 99%
rename from verl/trainer/ppo/actor/megatron_actor.py
rename to verl/workers/actor/megatron_actor.py
index 8bb31b1..e674a28 100644
--- a/verl/trainer/ppo/actor/megatron_actor.py
+++ b/verl/workers/actor/megatron_actor.py
@@ -38,7 +38,7 @@
 from verl.utils.megatron.pipeline_parallel import (compute_transformers_input_shapes, make_batch_generator)
 from verl import DataProto
 from verl.trainer.ppo import core_algos
-from verl.trainer.ppo.actor import BasePPOActor
+from verl.workers.actor import BasePPOActor
 from verl.utils.py_functional import append_to_dict
 from verl.utils.torch_functional import logprobs_from_logits, broadcast_dict_tensor, split_dict_tensor_into_batches
diff --git a/verl/trainer/ppo/critic/__init__.py b/verl/workers/critic/__init__.py
similarity index 100%
rename from verl/trainer/ppo/critic/__init__.py
rename to verl/workers/critic/__init__.py
diff --git a/verl/trainer/ppo/critic/base.py b/verl/workers/critic/base.py
similarity index 100%
rename from verl/trainer/ppo/critic/base.py
rename to verl/workers/critic/base.py
diff --git a/verl/trainer/ppo/critic/dp_critic.py b/verl/workers/critic/dp_critic.py
similarity index 99%
rename from verl/trainer/ppo/critic/dp_critic.py
rename to verl/workers/critic/dp_critic.py
index e078c62..028b6c9 100644
--- a/verl/trainer/ppo/critic/dp_critic.py
+++ b/verl/workers/critic/dp_critic.py
@@ -25,7 +25,7 @@
 
 from verl import DataProto
 from verl.trainer.ppo import core_algos
-from verl.trainer.ppo.critic import BasePPOCritic
+from verl.workers.critic import BasePPOCritic
 from verl.utils.py_functional import append_to_dict
 from verl.utils.torch_functional import masked_mean
diff --git a/verl/trainer/ppo/critic/megatron_critic.py b/verl/workers/critic/megatron_critic.py
similarity index 99%
rename from verl/trainer/ppo/critic/megatron_critic.py
rename to verl/workers/critic/megatron_critic.py
index 17c3e18..22ae384 100644
--- a/verl/trainer/ppo/critic/megatron_critic.py
+++ b/verl/workers/critic/megatron_critic.py
@@ -25,7 +25,7 @@
 
 from verl import DataProto
 from verl.trainer.ppo import core_algos
-from verl.trainer.ppo.critic import BasePPOCritic
+from verl.workers.critic import BasePPOCritic
 from verl.utils.megatron.pipeline_parallel import (compute_transformers_input_shapes, make_batch_generator)
 from verl.utils.py_functional import append_to_dict
 from verl.utils.torch_dtypes import PrecisionType
diff --git a/verl/trainer/ppo/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py
similarity index 98%
rename from verl/trainer/ppo/workers/fsdp_workers.py
rename to verl/workers/fsdp_workers.py
index 439a36b..45c4197 100644
--- a/verl/trainer/ppo/workers/fsdp_workers.py
+++ b/verl/workers/fsdp_workers.py
@@ -27,7 +27,6 @@
 from verl.single_controller.base.decorator import register, Dispatch
 import verl.utils.torch_functional as verl_F
 from verl import DataProto
-from verl.trainer.ppo.actor import DataParallelPPOActor
 from verl.utils.model import compute_position_id_with_mask
 from verl.utils.fs import copy_local_path_from_hdfs
 from verl.utils.fsdp_utils import get_fsdp_wrap_policy, load_fsdp_grad, offload_fsdp_grad, init_fn, get_init_weight_context_manager
@@ -220,14 +219,14 @@ def _build_model_optimizer(self,
 
     def _build_rollout(self):
         if self.config.rollout.name == 'hf':
-            from verl.trainer.ppo.rollout import HFRollout
-            from verl.trainer.ppo.hybrid_engine import BaseShardingManager
+            from verl.workers.rollout import HFRollout
+            from verl.workers.hybrid_engine import BaseShardingManager
             rollout = HFRollout(module=self.actor_module_fsdp, config=self.config.rollout)
             sharding_manager = BaseShardingManager()
             # TODO: a sharding manager that do nothing?
         elif self.config.rollout.name == 'vllm':
-            from verl.trainer.ppo.rollout.vllm_rollout import vLLMRollout
-            from verl.trainer.ppo.hybrid_engine import FSDPVLLMShardingManager
+            from verl.workers.rollout.vllm_rollout import vLLMRollout
+            from verl.workers.hybrid_engine import FSDPVLLMShardingManager
             log_gpu_memory_usage('Before building vllm rollout', logger=None)
             rollout = vLLMRollout(actor_module=self.actor_module_fsdp,
                                   config=self.config.rollout,
@@ -246,6 +245,7 @@ def _build_rollout(self):
 
     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
     def init_model(self):
+        from verl.workers.actor import DataParallelPPOActor
         # This is used to import external_lib into the huggingface systems
         import_external_libs(self.config.model.get('external_lib', None))
 
@@ -557,7 +557,7 @@ def init_model(self):
         # This is used to import external_lib into the huggingface systems
         import_external_libs(self.config.model.get('external_lib', None))
 
-        from verl.trainer.ppo.critic import DataParallelPPOCritic
+        from verl.workers.critic import DataParallelPPOCritic
         self.critic_module, self.critic_optimizer, self.critic_lr_scheduler = self._build_critic_model_optimizer(
             self.config)
diff --git a/verl/trainer/ppo/hybrid_engine/__init__.py b/verl/workers/hybrid_engine/__init__.py
similarity index 100%
rename from verl/trainer/ppo/hybrid_engine/__init__.py
rename to verl/workers/hybrid_engine/__init__.py
diff --git a/verl/trainer/ppo/hybrid_engine/base.py b/verl/workers/hybrid_engine/base.py
similarity index 100%
rename from verl/trainer/ppo/hybrid_engine/base.py
rename to verl/workers/hybrid_engine/base.py
diff --git a/verl/trainer/ppo/hybrid_engine/fsdp_vllm.py b/verl/workers/hybrid_engine/fsdp_vllm.py
similarity index 100%
rename from verl/trainer/ppo/hybrid_engine/fsdp_vllm.py
rename to verl/workers/hybrid_engine/fsdp_vllm.py
diff --git a/verl/trainer/ppo/hybrid_engine/megatron_vllm.py b/verl/workers/hybrid_engine/megatron_vllm.py
similarity index 100%
rename from verl/trainer/ppo/hybrid_engine/megatron_vllm.py
rename to verl/workers/hybrid_engine/megatron_vllm.py
diff --git a/verl/trainer/ppo/workers/megatron_workers.py b/verl/workers/megatron_workers.py
similarity index 98%
rename from verl/trainer/ppo/workers/megatron_workers.py
rename to verl/workers/megatron_workers.py
index b826905..f76c4d4 100644
--- a/verl/trainer/ppo/workers/megatron_workers.py
+++ b/verl/workers/megatron_workers.py
@@ -23,10 +23,10 @@ import torch.nn as nn
 from omegaconf import DictConfig
 
 from verl.single_controller.base.megatron.worker import MegatronWorker
-from verl.trainer.ppo.actor.megatron_actor import MegatronPPOActor
-from verl.trainer.ppo.critic.megatron_critic import MegatronPPOCritic
-from verl.trainer.ppo.hybrid_engine import AllGatherPPModel
-from verl.trainer.ppo.reward_model.megatron.reward_model import MegatronRewardModel
+from verl.workers.actor.megatron_actor import MegatronPPOActor
+from verl.workers.critic.megatron_critic import MegatronPPOCritic
+from verl.workers.hybrid_engine import AllGatherPPModel
+from verl.workers.reward_model.megatron.reward_model import MegatronRewardModel
 from verl.single_controller.base.decorator import register, Dispatch
 from verl import DataProto
@@ -216,8 +216,8 @@ def megatron_actor_model_provider(pre_process, post_process):
 
     def _build_rollout(self):
         if self.config.rollout.name == 'vllm':
-            from verl.trainer.ppo.rollout.vllm_rollout import vLLMRollout
-            from verl.trainer.ppo.hybrid_engine import MegatronVLLMShardingManager
+            from verl.workers.rollout.vllm_rollout import vLLMRollout
+            from verl.workers.hybrid_engine import MegatronVLLMShardingManager
             from verl.utils.model import normalize_pp_vpp_params
 
             # NOTE(sgm): If the QKV and gate_up projection layer are concate together in actor,
diff --git a/verl/trainer/ppo/reward_model/__init__.py b/verl/workers/reward_model/__init__.py
similarity index 100%
rename from verl/trainer/ppo/reward_model/__init__.py
rename to verl/workers/reward_model/__init__.py
diff --git a/verl/trainer/ppo/reward_model/base.py b/verl/workers/reward_model/base.py
similarity index 100%
rename from verl/trainer/ppo/reward_model/base.py
rename to verl/workers/reward_model/base.py
diff --git a/verl/trainer/ppo/reward_model/megatron/__init__.py b/verl/workers/reward_model/megatron/__init__.py
similarity index 100%
rename from verl/trainer/ppo/reward_model/megatron/__init__.py
rename to verl/workers/reward_model/megatron/__init__.py
diff --git a/verl/trainer/ppo/reward_model/megatron/reward_model.py b/verl/workers/reward_model/megatron/reward_model.py
similarity index 99%
rename from verl/trainer/ppo/reward_model/megatron/reward_model.py
rename to verl/workers/reward_model/megatron/reward_model.py
index 7ea5629..2c4c1b6 100644
--- a/verl/trainer/ppo/reward_model/megatron/reward_model.py
+++ b/verl/workers/reward_model/megatron/reward_model.py
@@ -28,7 +28,7 @@
 from verl import DataProto
 from verl.utils.torch_functional import logprobs_from_logits, broadcast_dict_tensor, split_dict_tensor_into_batches
 from verl.utils.torch_dtypes import PrecisionType
-from verl.trainer.ppo.reward_model.base import BasePPORewardModel
+from verl.workers.reward_model.base import BasePPORewardModel
 from verl.utils.megatron import sequence_parallel as sp_utils
 from megatron.core import parallel_state as mpu
 from megatron.core.pipeline_parallel import get_forward_backward_func
diff --git a/verl/trainer/ppo/rollout/__init__.py b/verl/workers/rollout/__init__.py
similarity index 100%
rename from verl/trainer/ppo/rollout/__init__.py
rename to verl/workers/rollout/__init__.py
diff --git a/verl/trainer/ppo/rollout/base.py b/verl/workers/rollout/base.py
similarity index 100%
rename from verl/trainer/ppo/rollout/base.py
rename to verl/workers/rollout/base.py
diff --git a/verl/trainer/ppo/rollout/hf_rollout.py b/verl/workers/rollout/hf_rollout.py
similarity index 100%
rename from verl/trainer/ppo/rollout/hf_rollout.py
rename to verl/workers/rollout/hf_rollout.py
diff --git a/verl/trainer/ppo/rollout/naive/__init__.py b/verl/workers/rollout/naive/__init__.py
similarity index 100%
rename from verl/trainer/ppo/rollout/naive/__init__.py
rename to verl/workers/rollout/naive/__init__.py
diff --git a/verl/trainer/ppo/rollout/naive/naive_rollout.py b/verl/workers/rollout/naive/naive_rollout.py
similarity index 100%
rename from verl/trainer/ppo/rollout/naive/naive_rollout.py
rename to verl/workers/rollout/naive/naive_rollout.py
diff --git a/verl/trainer/ppo/rollout/tokenizer.py b/verl/workers/rollout/tokenizer.py
similarity index 100%
rename from verl/trainer/ppo/rollout/tokenizer.py
rename to verl/workers/rollout/tokenizer.py
diff --git a/verl/trainer/ppo/rollout/vllm_rollout/__init__.py b/verl/workers/rollout/vllm_rollout/__init__.py
similarity index 100%
rename from verl/trainer/ppo/rollout/vllm_rollout/__init__.py
rename to verl/workers/rollout/vllm_rollout/__init__.py
diff --git a/verl/trainer/ppo/rollout/vllm_rollout/vllm_rollout.py b/verl/workers/rollout/vllm_rollout/vllm_rollout.py
similarity index 99%
rename from verl/trainer/ppo/rollout/vllm_rollout/vllm_rollout.py
rename to verl/workers/rollout/vllm_rollout/vllm_rollout.py
index e66275c..0dc2de0 100644
--- a/verl/trainer/ppo/rollout/vllm_rollout/vllm_rollout.py
+++ b/verl/workers/rollout/vllm_rollout/vllm_rollout.py
@@ -34,7 +34,7 @@
 
 from verl import DataProto
 from verl.utils.torch_functional import get_eos_mask, pad_sequence_to_length
-from verl.trainer.ppo.rollout.base import BaseRollout
+from verl.workers.rollout.base import BaseRollout
 from verl.third_party.vllm import LLM, vllm_version
 from verl.third_party.vllm import parallel_state as vllm_ps
 from vllm import SamplingParams
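
Migration note for downstream code: this patch moves verl.trainer.ppo.{workers,actor,critic,rollout,reward_model,hybrid_engine} under verl.workers.*, while verl.trainer.ppo.core_algos keeps its old path, so out-of-tree scripts importing the renamed modules will break. The sketch below is illustrative only (the helper resolve_worker_classes is not part of veRL): it mirrors the fsdp/megatron dispatch in verl/trainer/main_ppo.py against the new layout, with a fallback to the pre-rename paths for older checkouts.

# Hypothetical migration helper -- not part of this patch.
# Mirrors the strategy dispatch in verl/trainer/main_ppo.py and falls back
# to the pre-rename module paths (verl.trainer.ppo.workers.*) on ImportError.
def resolve_worker_classes(strategy):
    try:
        if strategy == 'fsdp':
            from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
            from verl.single_controller.ray import RayWorkerGroup as worker_group_cls
        elif strategy == 'megatron':
            from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
            from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup as worker_group_cls
        else:
            raise NotImplementedError(f'unknown strategy: {strategy}')
    except ImportError:
        # Older checkout: module layout before this patch.
        if strategy == 'fsdp':
            from verl.trainer.ppo.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
            from verl.single_controller.ray import RayWorkerGroup as worker_group_cls
        else:
            from verl.trainer.ppo.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
            from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup as worker_group_cls
    return ActorRolloutRefWorker, CriticWorker, worker_group_cls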
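
One caveat in the new requirements parsing in setup.py: item.strip()[0] != '#' raises IndexError on a blank line, so requirements.txt must stay free of empty lines. A slightly more defensive variant (an illustrative sketch, not what this patch ships):

# Illustrative hardening of the setup.py snippet above:
# skip blank lines as well as '#' comment lines.
with open('requirements.txt') as f:
    install_requires = [line.strip() for line in f
                        if line.strip() and not line.strip().startswith('#')]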