From 6143651b4cc4109771e2b8b46ad3706ac8303e1e Mon Sep 17 00:00:00 2001
From: shengguangming
Date: Wed, 18 Dec 2024 13:28:02 +0800
Subject: [PATCH] [refact] move workers out of trainers

---
 docs/examples/ppo_code_architecture.rst | 6 +++---
 examples/split_placement/main_ppo_split.py | 8 ++++----
 verl/third_party/vllm/vllm_v_0_3_1/llm.py | 4 ++--
 verl/third_party/vllm/vllm_v_0_4_2/llm.py | 4 ++--
 verl/third_party/vllm/vllm_v_0_5_4/llm.py | 4 ++--
 verl/third_party/vllm/vllm_v_0_6_3/llm.py | 4 ++--
 verl/trainer/main_generation.py | 2 +-
 verl/trainer/main_ppo.py | 8 ++++----
 verl/{trainer => }/workers/__init__.py | 0
 verl/{trainer => }/workers/actor/__init__.py | 0
 verl/{trainer => }/workers/actor/base.py | 0
 verl/{trainer => }/workers/actor/dp_actor.py | 2 +-
 verl/{trainer => }/workers/actor/megatron_actor.py | 2 +-
 verl/{trainer => }/workers/critic/__init__.py | 0
 verl/{trainer => }/workers/critic/base.py | 0
 verl/{trainer => }/workers/critic/dp_critic.py | 2 +-
 verl/{trainer => }/workers/critic/megatron_critic.py | 2 +-
 verl/{trainer => }/workers/fsdp_workers.py | 12 ++++++------
 verl/{trainer => }/workers/hybrid_engine/__init__.py | 0
 verl/{trainer => }/workers/hybrid_engine/base.py | 0
 verl/{trainer => }/workers/hybrid_engine/fsdp_vllm.py | 0
 .../workers/hybrid_engine/megatron_vllm.py | 0
 verl/{trainer => }/workers/megatron_workers.py | 12 ++++++------
 verl/{trainer => }/workers/reward_model/__init__.py | 0
 verl/{trainer => }/workers/reward_model/base.py | 0
 .../workers/reward_model/megatron/__init__.py | 0
 .../workers/reward_model/megatron/reward_model.py | 2 +-
 verl/{trainer => }/workers/rollout/__init__.py | 0
 verl/{trainer => }/workers/rollout/base.py | 0
 verl/{trainer => }/workers/rollout/hf_rollout.py | 0
 verl/{trainer => }/workers/rollout/naive/__init__.py | 0
 .../workers/rollout/naive/naive_rollout.py | 0
 verl/{trainer => }/workers/rollout/tokenizer.py | 0
 .../workers/rollout/vllm_rollout/__init__.py | 0
 .../workers/rollout/vllm_rollout/vllm_rollout.py | 2 +-
 35 files changed, 38 insertions(+), 38 deletions(-)
 rename verl/{trainer => }/workers/__init__.py (100%)
 rename verl/{trainer => }/workers/actor/__init__.py (100%)
 rename verl/{trainer => }/workers/actor/base.py (100%)
 rename verl/{trainer => }/workers/actor/dp_actor.py (99%)
 rename verl/{trainer => }/workers/actor/megatron_actor.py (99%)
 rename verl/{trainer => }/workers/critic/__init__.py (100%)
 rename verl/{trainer => }/workers/critic/base.py (100%)
 rename verl/{trainer => }/workers/critic/dp_critic.py (99%)
 rename verl/{trainer => }/workers/critic/megatron_critic.py (99%)
 rename verl/{trainer => }/workers/fsdp_workers.py (98%)
 rename verl/{trainer => }/workers/hybrid_engine/__init__.py (100%)
 rename verl/{trainer => }/workers/hybrid_engine/base.py (100%)
 rename verl/{trainer => }/workers/hybrid_engine/fsdp_vllm.py (100%)
 rename verl/{trainer => }/workers/hybrid_engine/megatron_vllm.py (100%)
 rename verl/{trainer => }/workers/megatron_workers.py (98%)
 rename verl/{trainer => }/workers/reward_model/__init__.py (100%)
 rename verl/{trainer => }/workers/reward_model/base.py (100%)
 rename verl/{trainer => }/workers/reward_model/megatron/__init__.py (100%)
 rename verl/{trainer => }/workers/reward_model/megatron/reward_model.py (99%)
 rename verl/{trainer => }/workers/rollout/__init__.py (100%)
 rename verl/{trainer => }/workers/rollout/base.py (100%)
 rename verl/{trainer => }/workers/rollout/hf_rollout.py (100%)
 rename verl/{trainer => }/workers/rollout/naive/__init__.py (100%)
 rename verl/{trainer => }/workers/rollout/naive/naive_rollout.py (100%)
 rename verl/{trainer => }/workers/rollout/tokenizer.py (100%)
 rename verl/{trainer => }/workers/rollout/vllm_rollout/__init__.py (100%)
 rename verl/{trainer => }/workers/rollout/vllm_rollout/vllm_rollout.py (99%)

diff --git a/docs/examples/ppo_code_architecture.rst b/docs/examples/ppo_code_architecture.rst
index 3f4bc3e..40f3a6e 100644
--- a/docs/examples/ppo_code_architecture.rst
+++ b/docs/examples/ppo_code_architecture.rst
@@ -48,13 +48,13 @@ Define worker classes
     if config.actor_rollout_ref.actor.strategy == 'fsdp': # for FSDP backend
         assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
-        from verl.trainer.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
+        from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
         from verl.single_controller.ray import RayWorkerGroup
         ray_worker_group_cls = RayWorkerGroup

     elif config.actor_rollout_ref.actor.strategy == 'megatron': # for Megatron backend
         assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
-        from verl.trainer.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
+        from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
         from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
         ray_worker_group_cls = NVMegatronRayWorkerGroup # Ray worker class for Megatron-LM
@@ -139,7 +139,7 @@ Defining reward model/function
     # - finally, we combine all the rewards together
     # - The reward type depends on the tag of the data
     if config.reward_model.enable:
-        from verl.trainer.workers.fsdp_workers import RewardModelWorker
+        from verl.workers.fsdp_workers import RewardModelWorker
         role_worker_mapping[Role.RewardModel] = RewardModelWorker
         mapping[Role.RewardModel] = global_pool_id
diff --git a/examples/split_placement/main_ppo_split.py b/examples/split_placement/main_ppo_split.py
index 0960006..27a7b35 100644
--- a/examples/split_placement/main_ppo_split.py
+++ b/examples/split_placement/main_ppo_split.py
@@ -119,13 +119,13 @@ def main_task(config):
     # define worker classes
     if config.actor_rollout_ref.actor.strategy == 'fsdp':
         assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
-        from verl.trainer.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
+        from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
         from verl.single_controller.ray import RayWorkerGroup
         ray_worker_group_cls = RayWorkerGroup

     elif config.actor_rollout_ref.actor.strategy == 'megatron':
         assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
-        from verl.trainer.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
+        from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
         from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
         ray_worker_group_cls = NVMegatronRayWorkerGroup
@@ -168,9 +168,9 @@ def main_task(config):
     # - The reward type depends on the tag of the data
     if config.reward_model.enable:
         if config.reward_model.strategy == 'fsdp':
-            from verl.trainer.workers.fsdp_workers import RewardModelWorker
+            from verl.workers.fsdp_workers import RewardModelWorker
         elif config.reward_model.strategy == 'megatron':
-            from verl.trainer.workers.megatron_workers import RewardModelWorker
+            from verl.workers.megatron_workers import RewardModelWorker
         else:
             raise NotImplementedError
         role_worker_mapping[Role.RewardModel] = RewardModelWorker
diff --git a/verl/third_party/vllm/vllm_v_0_3_1/llm.py b/verl/third_party/vllm/vllm_v_0_3_1/llm.py
index c5ca62c..8d24759 100644
--- a/verl/third_party/vllm/vllm_v_0_3_1/llm.py
+++ b/verl/third_party/vllm/vllm_v_0_3_1/llm.py
@@ -27,7 +27,7 @@ from vllm.utils import Counter
 import torch
 from torch.nn.utils.rnn import pad_sequence

-from verl.trainer.workers.rollout.tokenizer import HybridEngineBaseTokenizer
+from verl.workers.rollout.tokenizer import HybridEngineBaseTokenizer


 class LLM:
@@ -125,7 +125,7 @@ def __init__(
         if not isinstance(tokenizer, tokenizer_cls):
             raise ValueError(
                 f"Unexpected tokenizer type: {type(tokenizer)}. Must be"
-                "one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.trainer.workers.rollout.HybridEngineBaseTokenizer"
+                "one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.workers.rollout.HybridEngineBaseTokenizer"
             )
         self.llm_engine = LLMEngine.from_engine_args(model, tokenizer, engine_args)
         self.request_counter = Counter()
diff --git a/verl/third_party/vllm/vllm_v_0_4_2/llm.py b/verl/third_party/vllm/vllm_v_0_4_2/llm.py
index 2a30878..94623a4 100644
--- a/verl/third_party/vllm/vllm_v_0_4_2/llm.py
+++ b/verl/third_party/vllm/vllm_v_0_4_2/llm.py
@@ -29,7 +29,7 @@ from vllm.utils import Counter
 import torch
 from torch.nn.utils.rnn import pad_sequence

-from verl.trainer.workers.rollout.tokenizer import HybridEngineBaseTokenizer
+from verl.workers.rollout.tokenizer import HybridEngineBaseTokenizer


 class LLM:
@@ -129,7 +129,7 @@ def __init__(
         if not isinstance(tokenizer, tokenizer_cls):
             raise ValueError(
                 f"Unexpected tokenizer type: {type(tokenizer)}. Must be"
-                "one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.trainer.workers.rollout.HybridEngineBaseTokenizer"
+                "one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.workers.rollout.HybridEngineBaseTokenizer"
             )
         self.llm_engine = LLMEngine.from_engine_args(model, tokenizer, engine_args)
         self.request_counter = Counter()
diff --git a/verl/third_party/vllm/vllm_v_0_5_4/llm.py b/verl/third_party/vllm/vllm_v_0_5_4/llm.py
index f0469ea..5f56f1e 100644
--- a/verl/third_party/vllm/vllm_v_0_5_4/llm.py
+++ b/verl/third_party/vllm/vllm_v_0_5_4/llm.py
@@ -37,7 +37,7 @@ from vllm.utils import Counter, deprecate_kwargs
 import torch
 from torch.nn.utils.rnn import pad_sequence

-from verl.trainer.workers.rollout.tokenizer import HybridEngineBaseTokenizer
+from verl.workers.rollout.tokenizer import HybridEngineBaseTokenizer


 class LLM(LLM):
@@ -143,7 +143,7 @@ def __init__(
         if not isinstance(tokenizer, tokenizer_cls):
             raise ValueError(
                 f"Unexpected tokenizer type: {type(tokenizer)}. Must be"
-                "one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.trainer.workers.rollout.HybridEngineBaseTokenizer"
+                "one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.workers.rollout.HybridEngineBaseTokenizer"
             )
         self.llm_engine = LLMEngine.from_engine_args(model, tokenizer, engine_args)  # TODO: check usagecontext
         self.request_counter = Counter()
diff --git a/verl/third_party/vllm/vllm_v_0_6_3/llm.py b/verl/third_party/vllm/vllm_v_0_6_3/llm.py
index 56d2869..cd3d646 100644
--- a/verl/third_party/vllm/vllm_v_0_6_3/llm.py
+++ b/verl/third_party/vllm/vllm_v_0_6_3/llm.py
@@ -19,7 +19,7 @@ import torch.nn as nn
 from torch.nn.utils.rnn import pad_sequence
 from transformers import PretrainedConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
-from verl.trainer.workers.rollout.tokenizer import HybridEngineBaseTokenizer
+from verl.workers.rollout.tokenizer import HybridEngineBaseTokenizer
 from vllm import LLM
 from vllm.outputs import EmbeddingRequestOutput, RequestOutput
 from vllm.utils import Counter
@@ -137,7 +137,7 @@ def __init__(
         if not isinstance(tokenizer, tokenizer_cls):
             raise ValueError(
                 f"Unexpected tokenizer type: {type(tokenizer)}. Must be"
-                "one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.trainer.workers.rollout.HybridEngineBaseTokenizer"
+                "one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.workers.rollout.HybridEngineBaseTokenizer"
             )
         self.llm_engine = LLMEngine.from_engine_args(model, tokenizer, engine_args)  # TODO: check usagecontext
         self.request_counter = Counter()
diff --git a/verl/trainer/main_generation.py b/verl/trainer/main_generation.py
index 7e65aff..b0bc7d7 100644
--- a/verl/trainer/main_generation.py
+++ b/verl/trainer/main_generation.py
@@ -31,7 +31,7 @@ from verl import DataProto
 from verl.utils.fs import copy_local_path_from_hdfs
-from verl.trainer.workers.fsdp_workers import ActorRolloutRefWorker
+from verl.workers.fsdp_workers import ActorRolloutRefWorker
 from verl.utils.hdfs_io import makedirs
 from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py
index 9d81087..325e9be 100644
--- a/verl/trainer/main_ppo.py
+++ b/verl/trainer/main_ppo.py
@@ -120,13 +120,13 @@ def main_task(config):
     # define worker classes
     if config.actor_rollout_ref.actor.strategy == 'fsdp':
         assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
-        from verl.trainer.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
+        from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
         from verl.single_controller.ray import RayWorkerGroup
         ray_worker_group_cls = RayWorkerGroup

     elif config.actor_rollout_ref.actor.strategy == 'megatron':
         assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
-        from verl.trainer.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
+        from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
         from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
         ray_worker_group_cls = NVMegatronRayWorkerGroup
@@ -159,9 +159,9 @@ def main_task(config):
     # - The reward type depends on the tag of the data
     if config.reward_model.enable:
         if config.reward_model.strategy == 'fsdp':
-            from verl.trainer.workers.fsdp_workers import RewardModelWorker
+            from verl.workers.fsdp_workers import RewardModelWorker
         elif config.reward_model.strategy == 'megatron':
-            from verl.trainer.workers.megatron_workers import RewardModelWorker
+            from verl.workers.megatron_workers import RewardModelWorker
         else:
             raise NotImplementedError
         role_worker_mapping[Role.RewardModel] = RewardModelWorker
diff --git a/verl/trainer/workers/__init__.py b/verl/workers/__init__.py
similarity index 100%
rename from verl/trainer/workers/__init__.py
rename to verl/workers/__init__.py
diff --git a/verl/trainer/workers/actor/__init__.py b/verl/workers/actor/__init__.py
similarity index 100%
rename from verl/trainer/workers/actor/__init__.py
rename to verl/workers/actor/__init__.py
diff --git a/verl/trainer/workers/actor/base.py b/verl/workers/actor/base.py
similarity index 100%
rename from verl/trainer/workers/actor/base.py
rename to verl/workers/actor/base.py
diff --git a/verl/trainer/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py
similarity index 99%
rename from verl/trainer/workers/actor/dp_actor.py
rename to verl/workers/actor/dp_actor.py
index ae1712d..e0001c9 100644
--- a/verl/trainer/workers/actor/dp_actor.py
+++ b/verl/workers/actor/dp_actor.py
@@ -22,7 +22,7 @@ from verl import DataProto
 from verl.trainer.ppo import core_algos
-from verl.trainer.workers.actor import BasePPOActor
+from verl.workers.actor import BasePPOActor
 from verl.utils.py_functional import append_to_dict
 from verl.utils.torch_functional import logprobs_from_logits
diff --git a/verl/trainer/workers/actor/megatron_actor.py b/verl/workers/actor/megatron_actor.py
similarity index 99%
rename from verl/trainer/workers/actor/megatron_actor.py
rename to verl/workers/actor/megatron_actor.py
index 9cbeeb1..e674a28 100644
--- a/verl/trainer/workers/actor/megatron_actor.py
+++ b/verl/workers/actor/megatron_actor.py
@@ -38,7 +38,7 @@ from verl.utils.megatron.pipeline_parallel import (compute_transformers_input_shapes, make_batch_generator)
 from verl import DataProto
 from verl.trainer.ppo import core_algos
-from verl.trainer.workers.actor import BasePPOActor
+from verl.workers.actor import BasePPOActor
 from verl.utils.py_functional import append_to_dict
 from verl.utils.torch_functional import logprobs_from_logits, broadcast_dict_tensor, split_dict_tensor_into_batches
diff --git a/verl/trainer/workers/critic/__init__.py b/verl/workers/critic/__init__.py
similarity index 100%
rename from verl/trainer/workers/critic/__init__.py
rename to verl/workers/critic/__init__.py
diff --git a/verl/trainer/workers/critic/base.py b/verl/workers/critic/base.py
similarity index 100%
rename from verl/trainer/workers/critic/base.py
rename to verl/workers/critic/base.py
diff --git a/verl/trainer/workers/critic/dp_critic.py b/verl/workers/critic/dp_critic.py
similarity index 99%
rename from verl/trainer/workers/critic/dp_critic.py
rename to verl/workers/critic/dp_critic.py
index ab30a7c..028b6c9 100644
--- a/verl/trainer/workers/critic/dp_critic.py
+++ b/verl/workers/critic/dp_critic.py
@@ -25,7 +25,7 @@ from verl import DataProto
 from verl.trainer.ppo import core_algos
-from verl.trainer.workers.critic import BasePPOCritic
+from verl.workers.critic import BasePPOCritic
 from verl.utils.py_functional import append_to_dict
 from verl.utils.torch_functional import masked_mean
diff --git a/verl/trainer/workers/critic/megatron_critic.py b/verl/workers/critic/megatron_critic.py
similarity index 99%
rename from verl/trainer/workers/critic/megatron_critic.py
rename to verl/workers/critic/megatron_critic.py
index 4a7636e..22ae384 100644
--- a/verl/trainer/workers/critic/megatron_critic.py
+++ b/verl/workers/critic/megatron_critic.py
@@ -25,7 +25,7 @@ from verl import DataProto
 from verl.trainer.ppo import core_algos
-from verl.trainer.workers.critic import BasePPOCritic
+from verl.workers.critic import BasePPOCritic
 from verl.utils.megatron.pipeline_parallel import (compute_transformers_input_shapes, make_batch_generator)
 from verl.utils.py_functional import append_to_dict
 from verl.utils.torch_dtypes import PrecisionType
diff --git a/verl/trainer/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py
similarity index 98%
rename from verl/trainer/workers/fsdp_workers.py
rename to verl/workers/fsdp_workers.py
index 3a12cd5..45c4197 100644
--- a/verl/trainer/workers/fsdp_workers.py
+++ b/verl/workers/fsdp_workers.py
@@ -219,14 +219,14 @@ def _build_model_optimizer(self,

     def _build_rollout(self):
         if self.config.rollout.name == 'hf':
-            from verl.trainer.workers.rollout import HFRollout
-            from verl.trainer.workers.hybrid_engine import BaseShardingManager
+            from verl.workers.rollout import HFRollout
+            from verl.workers.hybrid_engine import BaseShardingManager
             rollout = HFRollout(module=self.actor_module_fsdp, config=self.config.rollout)
             sharding_manager = BaseShardingManager()
             # TODO: a sharding manager that do nothing?
         elif self.config.rollout.name == 'vllm':
-            from verl.trainer.workers.rollout.vllm_rollout import vLLMRollout
-            from verl.trainer.workers.hybrid_engine import FSDPVLLMShardingManager
+            from verl.workers.rollout.vllm_rollout import vLLMRollout
+            from verl.workers.hybrid_engine import FSDPVLLMShardingManager
             log_gpu_memory_usage('Before building vllm rollout', logger=None)
             rollout = vLLMRollout(actor_module=self.actor_module_fsdp,
                                   config=self.config.rollout,
@@ -245,7 +245,7 @@ def _build_rollout(self):
     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
     def init_model(self):
-        from verl.trainer.workers.actor import DataParallelPPOActor
+        from verl.workers.actor import DataParallelPPOActor
         # This is used to import external_lib into the huggingface systems
         import_external_libs(self.config.model.get('external_lib', None))
@@ -557,7 +557,7 @@ def init_model(self):
         # This is used to import external_lib into the huggingface systems
         import_external_libs(self.config.model.get('external_lib', None))
-        from verl.trainer.workers.critic import DataParallelPPOCritic
+        from verl.workers.critic import DataParallelPPOCritic
         self.critic_module, self.critic_optimizer, self.critic_lr_scheduler = self._build_critic_model_optimizer(
             self.config)
diff --git a/verl/trainer/workers/hybrid_engine/__init__.py b/verl/workers/hybrid_engine/__init__.py
similarity index 100%
rename from verl/trainer/workers/hybrid_engine/__init__.py
rename to verl/workers/hybrid_engine/__init__.py
diff --git a/verl/trainer/workers/hybrid_engine/base.py b/verl/workers/hybrid_engine/base.py
similarity index 100%
rename from verl/trainer/workers/hybrid_engine/base.py
rename to verl/workers/hybrid_engine/base.py
diff --git a/verl/trainer/workers/hybrid_engine/fsdp_vllm.py b/verl/workers/hybrid_engine/fsdp_vllm.py
similarity index 100%
rename from verl/trainer/workers/hybrid_engine/fsdp_vllm.py
rename to verl/workers/hybrid_engine/fsdp_vllm.py
diff --git a/verl/trainer/workers/hybrid_engine/megatron_vllm.py b/verl/workers/hybrid_engine/megatron_vllm.py
similarity index 100%
rename from verl/trainer/workers/hybrid_engine/megatron_vllm.py
rename to verl/workers/hybrid_engine/megatron_vllm.py
diff --git a/verl/trainer/workers/megatron_workers.py b/verl/workers/megatron_workers.py
similarity index 98%
rename from verl/trainer/workers/megatron_workers.py
rename to verl/workers/megatron_workers.py
index 1ff62fa..f76c4d4 100644
--- a/verl/trainer/workers/megatron_workers.py
+++ b/verl/workers/megatron_workers.py
@@ -23,10 +23,10 @@ import torch.nn as nn
 from omegaconf import DictConfig

 from verl.single_controller.base.megatron.worker import MegatronWorker
-from verl.trainer.workers.actor.megatron_actor import MegatronPPOActor
-from verl.trainer.workers.critic.megatron_critic import MegatronPPOCritic
-from verl.trainer.workers.hybrid_engine import AllGatherPPModel
-from verl.trainer.workers.reward_model.megatron.reward_model import MegatronRewardModel
+from verl.workers.actor.megatron_actor import MegatronPPOActor
+from verl.workers.critic.megatron_critic import MegatronPPOCritic
+from verl.workers.hybrid_engine import AllGatherPPModel
+from verl.workers.reward_model.megatron.reward_model import MegatronRewardModel
 from verl.single_controller.base.decorator import register, Dispatch
 from verl import DataProto
@@ -216,8 +216,8 @@ def megatron_actor_model_provider(pre_process, post_process):

     def _build_rollout(self):
         if self.config.rollout.name == 'vllm':
-            from verl.trainer.workers.rollout.vllm_rollout import vLLMRollout
-            from verl.trainer.workers.hybrid_engine import MegatronVLLMShardingManager
+            from verl.workers.rollout.vllm_rollout import vLLMRollout
+            from verl.workers.hybrid_engine import MegatronVLLMShardingManager
             from verl.utils.model import normalize_pp_vpp_params

             # NOTE(sgm): If the QKV and gate_up projection layer are concate together in actor,
diff --git a/verl/trainer/workers/reward_model/__init__.py b/verl/workers/reward_model/__init__.py
similarity index 100%
rename from verl/trainer/workers/reward_model/__init__.py
rename to verl/workers/reward_model/__init__.py
diff --git a/verl/trainer/workers/reward_model/base.py b/verl/workers/reward_model/base.py
similarity index 100%
rename from verl/trainer/workers/reward_model/base.py
rename to verl/workers/reward_model/base.py
diff --git a/verl/trainer/workers/reward_model/megatron/__init__.py b/verl/workers/reward_model/megatron/__init__.py
similarity index 100%
rename from verl/trainer/workers/reward_model/megatron/__init__.py
rename to verl/workers/reward_model/megatron/__init__.py
diff --git a/verl/trainer/workers/reward_model/megatron/reward_model.py b/verl/workers/reward_model/megatron/reward_model.py
similarity index 99%
rename from verl/trainer/workers/reward_model/megatron/reward_model.py
rename to verl/workers/reward_model/megatron/reward_model.py
index 47315b1..2c4c1b6 100644
--- a/verl/trainer/workers/reward_model/megatron/reward_model.py
+++ b/verl/workers/reward_model/megatron/reward_model.py
@@ -28,7 +28,7 @@ from verl import DataProto
 from verl.utils.torch_functional import logprobs_from_logits, broadcast_dict_tensor, split_dict_tensor_into_batches
 from verl.utils.torch_dtypes import PrecisionType
-from verl.trainer.workers.reward_model.base import BasePPORewardModel
+from verl.workers.reward_model.base import BasePPORewardModel
 from verl.utils.megatron import sequence_parallel as sp_utils
 from megatron.core import parallel_state as mpu
 from megatron.core.pipeline_parallel import get_forward_backward_func
diff --git a/verl/trainer/workers/rollout/__init__.py b/verl/workers/rollout/__init__.py
similarity index 100%
rename from verl/trainer/workers/rollout/__init__.py
rename to verl/workers/rollout/__init__.py
diff --git a/verl/trainer/workers/rollout/base.py b/verl/workers/rollout/base.py
similarity index 100%
rename from verl/trainer/workers/rollout/base.py
rename to verl/workers/rollout/base.py
diff --git a/verl/trainer/workers/rollout/hf_rollout.py b/verl/workers/rollout/hf_rollout.py
similarity index 100%
rename from verl/trainer/workers/rollout/hf_rollout.py
rename to verl/workers/rollout/hf_rollout.py
diff --git a/verl/trainer/workers/rollout/naive/__init__.py b/verl/workers/rollout/naive/__init__.py
similarity index 100%
rename from verl/trainer/workers/rollout/naive/__init__.py
rename to verl/workers/rollout/naive/__init__.py
diff --git a/verl/trainer/workers/rollout/naive/naive_rollout.py b/verl/workers/rollout/naive/naive_rollout.py
similarity index 100%
rename from verl/trainer/workers/rollout/naive/naive_rollout.py
rename to verl/workers/rollout/naive/naive_rollout.py
diff --git a/verl/trainer/workers/rollout/tokenizer.py b/verl/workers/rollout/tokenizer.py
similarity index 100%
rename from verl/trainer/workers/rollout/tokenizer.py
rename to verl/workers/rollout/tokenizer.py
diff --git a/verl/trainer/workers/rollout/vllm_rollout/__init__.py b/verl/workers/rollout/vllm_rollout/__init__.py
similarity index 100%
rename from verl/trainer/workers/rollout/vllm_rollout/__init__.py
rename to verl/workers/rollout/vllm_rollout/__init__.py
diff --git a/verl/trainer/workers/rollout/vllm_rollout/vllm_rollout.py b/verl/workers/rollout/vllm_rollout/vllm_rollout.py
similarity index 99%
rename from verl/trainer/workers/rollout/vllm_rollout/vllm_rollout.py
rename to verl/workers/rollout/vllm_rollout/vllm_rollout.py
index 24d4487..0dc2de0 100644
--- a/verl/trainer/workers/rollout/vllm_rollout/vllm_rollout.py
+++ b/verl/workers/rollout/vllm_rollout/vllm_rollout.py
@@ -34,7 +34,7 @@ from verl import DataProto
 from verl.utils.torch_functional import get_eos_mask, pad_sequence_to_length
-from verl.trainer.workers.rollout.base import BaseRollout
+from verl.workers.rollout.base import BaseRollout
 from verl.third_party.vllm import LLM, vllm_version
 from verl.third_party.vllm import parallel_state as vllm_ps
 from vllm import SamplingParams
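
Usage note (illustrative only, not part of the patch): after this refactor, downstream code resolves worker classes from verl.workers rather than verl.trainer.workers. A minimal sketch of the updated import style, limited to classes that appear in the hunks above:

    # Before this patch:
    #   from verl.trainer.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
    #   from verl.trainer.workers.rollout.tokenizer import HybridEngineBaseTokenizer

    # After this patch:
    from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
    from verl.workers.rollout.tokenizer import HybridEngineBaseTokenizer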