[refact] move workers out of trainers
PeterSH6 committed Dec 18, 2024
1 parent 69e1e5d commit 6143651
Showing 35 changed files with 38 additions and 38 deletions.
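
The change is mechanical: every import of verl.trainer.workers.* becomes verl.workers.*. Downstream code that still uses the old path could be migrated with a one-off rewrite along these lines (a minimal sketch, not part of this commit, assuming plain-text Python sources):

# Hypothetical migration helper (not part of this commit): rewrite the old
# package path to the new one across a checkout.
from pathlib import Path

OLD, NEW = "verl.trainer.workers", "verl.workers"

for path in Path(".").rglob("*.py"):
    text = path.read_text()
    if OLD in text:
        path.write_text(text.replace(OLD, NEW))
        print(f"rewrote {path}")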
docs/examples/ppo_code_architecture.rst (6 changes: 3 additions & 3 deletions)

@@ -48,13 +48,13 @@ Define worker classes
if config.actor_rollout_ref.actor.strategy == 'fsdp': # for FSDP backend
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
- from verl.trainer.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
+ from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
from verl.single_controller.ray import RayWorkerGroup
ray_worker_group_cls = RayWorkerGroup
elif config.actor_rollout_ref.actor.strategy == 'megatron': # for Megatron backend
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
- from verl.trainer.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
+ from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
ray_worker_group_cls = NVMegatronRayWorkerGroup # Ray worker class for Megatron-LM
@@ -139,7 +139,7 @@ Defining reward model/function
# - finally, we combine all the rewards together
# - The reward type depends on the tag of the data
if config.reward_model.enable:
- from verl.trainer.workers.fsdp_workers import RewardModelWorker
+ from verl.workers.fsdp_workers import RewardModelWorker
role_worker_mapping[Role.RewardModel] = RewardModelWorker
mapping[Role.RewardModel] = global_pool_id
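
Taken together, the updated ppo_code_architecture example wires the moved worker classes into the trainer roughly like this; a minimal sketch, assuming Role comes from verl.trainer.ppo.ray_trainer (an import not shown in this diff):

from verl.trainer.ppo.ray_trainer import Role  # assumed import path for Role
from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker, RewardModelWorker

role_worker_mapping = {
    Role.ActorRollout: ActorRolloutRefWorker,  # hybrid actor + rollout + reference worker
    Role.Critic: CriticWorker,
    Role.RewardModel: RewardModelWorker,
}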
examples/split_placement/main_ppo_split.py (8 changes: 4 additions & 4 deletions)

@@ -119,13 +119,13 @@ def main_task(config):
# define worker classes
if config.actor_rollout_ref.actor.strategy == 'fsdp':
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
- from verl.trainer.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
+ from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
from verl.single_controller.ray import RayWorkerGroup
ray_worker_group_cls = RayWorkerGroup

elif config.actor_rollout_ref.actor.strategy == 'megatron':
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
- from verl.trainer.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
+ from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
ray_worker_group_cls = NVMegatronRayWorkerGroup

@@ -168,9 +168,9 @@ def main_task(config):
# - The reward type depends on the tag of the data
if config.reward_model.enable:
if config.reward_model.strategy == 'fsdp':
- from verl.trainer.workers.fsdp_workers import RewardModelWorker
+ from verl.workers.fsdp_workers import RewardModelWorker
elif config.reward_model.strategy == 'megatron':
- from verl.trainer.workers.megatron_workers import RewardModelWorker
+ from verl.workers.megatron_workers import RewardModelWorker
else:
raise NotImplementedError
role_worker_mapping[Role.RewardModel] = RewardModelWorker
verl/third_party/vllm/vllm_v_0_3_1/llm.py (4 changes: 2 additions & 2 deletions)

@@ -27,7 +27,7 @@
from vllm.utils import Counter
import torch
from torch.nn.utils.rnn import pad_sequence
- from verl.trainer.workers.rollout.tokenizer import HybridEngineBaseTokenizer
+ from verl.workers.rollout.tokenizer import HybridEngineBaseTokenizer


class LLM:
@@ -125,7 +125,7 @@ def __init__(
if not isinstance(tokenizer, tokenizer_cls):
raise ValueError(
f"Unexpected tokenizer type: {type(tokenizer)}. Must be"
- "one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.trainer.workers.rollout.HybridEngineBaseTokenizer"
+ "one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.workers.rollout.HybridEngineBaseTokenizer"
)
self.llm_engine = LLMEngine.from_engine_args(model, tokenizer, engine_args)
self.request_counter = Counter()
verl/third_party/vllm/vllm_v_0_4_2/llm.py (4 changes: 2 additions & 2 deletions)

@@ -29,7 +29,7 @@
from vllm.utils import Counter
import torch
from torch.nn.utils.rnn import pad_sequence
- from verl.trainer.workers.rollout.tokenizer import HybridEngineBaseTokenizer
+ from verl.workers.rollout.tokenizer import HybridEngineBaseTokenizer


class LLM:
@@ -129,7 +129,7 @@ def __init__(
if not isinstance(tokenizer, tokenizer_cls):
raise ValueError(
f"Unexpected tokenizer type: {type(tokenizer)}. Must be"
- "one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.trainer.workers.rollout.HybridEngineBaseTokenizer"
+ "one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.workers.rollout.HybridEngineBaseTokenizer"
)
self.llm_engine = LLMEngine.from_engine_args(model, tokenizer, engine_args)
self.request_counter = Counter()
verl/third_party/vllm/vllm_v_0_5_4/llm.py (4 changes: 2 additions & 2 deletions)

@@ -37,7 +37,7 @@
from vllm.utils import Counter, deprecate_kwargs
import torch
from torch.nn.utils.rnn import pad_sequence
- from verl.trainer.workers.rollout.tokenizer import HybridEngineBaseTokenizer
+ from verl.workers.rollout.tokenizer import HybridEngineBaseTokenizer


class LLM(LLM):
@@ -143,7 +143,7 @@ def __init__(
if not isinstance(tokenizer, tokenizer_cls):
raise ValueError(
f"Unexpected tokenizer type: {type(tokenizer)}. Must be"
- "one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.trainer.workers.rollout.HybridEngineBaseTokenizer"
+ "one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.workers.rollout.HybridEngineBaseTokenizer"
)
self.llm_engine = LLMEngine.from_engine_args(model, tokenizer, engine_args) # TODO: check usagecontext
self.request_counter = Counter()
verl/third_party/vllm/vllm_v_0_6_3/llm.py (4 changes: 2 additions & 2 deletions)

@@ -19,7 +19,7 @@
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from transformers import PretrainedConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
- from verl.trainer.workers.rollout.tokenizer import HybridEngineBaseTokenizer
+ from verl.workers.rollout.tokenizer import HybridEngineBaseTokenizer
from vllm import LLM
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.utils import Counter
@@ -137,7 +137,7 @@ def __init__(
if not isinstance(tokenizer, tokenizer_cls):
raise ValueError(
f"Unexpected tokenizer type: {type(tokenizer)}. Must be"
- "one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.trainer.workers.rollout.HybridEngineBaseTokenizer"
+ "one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.workers.rollout.HybridEngineBaseTokenizer"
)
self.llm_engine = LLMEngine.from_engine_args(model, tokenizer, engine_args) # TODO: check usagecontext
self.request_counter = Counter()
verl/trainer/main_generation.py (2 changes: 1 addition & 1 deletion)

@@ -31,7 +31,7 @@

from verl import DataProto
from verl.utils.fs import copy_local_path_from_hdfs
- from verl.trainer.workers.fsdp_workers import ActorRolloutRefWorker
+ from verl.workers.fsdp_workers import ActorRolloutRefWorker
from verl.utils.hdfs_io import makedirs
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup

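
For context, main_generation.py builds a rollout worker group from the moved class roughly as follows. This is a hedged sketch: the constructor arguments shown are assumptions for illustration, not taken from this diff.

import ray
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
from verl.workers.fsdp_workers import ActorRolloutRefWorker

config = ...  # OmegaConf config loaded by Hydra elsewhere in the script (elided)

# Argument names and values below are illustrative assumptions.
ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role='rollout')
resource_pool = RayResourcePool(process_on_nodes=[8])  # e.g. one node, 8 worker processes
wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)
wg.init_model()  # dispatched to every worker in the group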
verl/trainer/main_ppo.py (8 changes: 4 additions & 4 deletions)

@@ -120,13 +120,13 @@ def main_task(config):
# define worker classes
if config.actor_rollout_ref.actor.strategy == 'fsdp':
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
- from verl.trainer.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
+ from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
from verl.single_controller.ray import RayWorkerGroup
ray_worker_group_cls = RayWorkerGroup

elif config.actor_rollout_ref.actor.strategy == 'megatron':
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
- from verl.trainer.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
+ from verl.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker
from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
ray_worker_group_cls = NVMegatronRayWorkerGroup

@@ -159,9 +159,9 @@ def main_task(config):
# - The reward type depends on the tag of the data
if config.reward_model.enable:
if config.reward_model.strategy == 'fsdp':
- from verl.trainer.workers.fsdp_workers import RewardModelWorker
+ from verl.workers.fsdp_workers import RewardModelWorker
elif config.reward_model.strategy == 'megatron':
- from verl.trainer.workers.megatron_workers import RewardModelWorker
+ from verl.workers.megatron_workers import RewardModelWorker
else:
raise NotImplementedError
role_worker_mapping[Role.RewardModel] = RewardModelWorker
File renamed without changes.
File renamed without changes.
File renamed without changes.
verl/{trainer/workers → workers}/actor/dp_actor.py

@@ -22,7 +22,7 @@

from verl import DataProto
from verl.trainer.ppo import core_algos
- from verl.trainer.workers.actor import BasePPOActor
+ from verl.workers.actor import BasePPOActor
from verl.utils.py_functional import append_to_dict
from verl.utils.torch_functional import logprobs_from_logits

verl/{trainer/workers → workers}/actor/megatron_actor.py

@@ -38,7 +38,7 @@
from verl.utils.megatron.pipeline_parallel import (compute_transformers_input_shapes, make_batch_generator)
from verl import DataProto
from verl.trainer.ppo import core_algos
- from verl.trainer.workers.actor import BasePPOActor
+ from verl.workers.actor import BasePPOActor
from verl.utils.py_functional import append_to_dict
from verl.utils.torch_functional import logprobs_from_logits, broadcast_dict_tensor, split_dict_tensor_into_batches

File renamed without changes.
File renamed without changes.
verl/{trainer/workers → workers}/critic/dp_critic.py

@@ -25,7 +25,7 @@

from verl import DataProto
from verl.trainer.ppo import core_algos
- from verl.trainer.workers.critic import BasePPOCritic
+ from verl.workers.critic import BasePPOCritic
from verl.utils.py_functional import append_to_dict
from verl.utils.torch_functional import masked_mean

verl/{trainer/workers → workers}/critic/megatron_critic.py

@@ -25,7 +25,7 @@

from verl import DataProto
from verl.trainer.ppo import core_algos
- from verl.trainer.workers.critic import BasePPOCritic
+ from verl.workers.critic import BasePPOCritic
from verl.utils.megatron.pipeline_parallel import (compute_transformers_input_shapes, make_batch_generator)
from verl.utils.py_functional import append_to_dict
from verl.utils.torch_dtypes import PrecisionType
verl/{trainer/workers → workers}/fsdp_workers.py

@@ -219,14 +219,14 @@ def _build_model_optimizer(self,

def _build_rollout(self):
if self.config.rollout.name == 'hf':
- from verl.trainer.workers.rollout import HFRollout
- from verl.trainer.workers.hybrid_engine import BaseShardingManager
+ from verl.workers.rollout import HFRollout
+ from verl.workers.hybrid_engine import BaseShardingManager
rollout = HFRollout(module=self.actor_module_fsdp, config=self.config.rollout)
sharding_manager = BaseShardingManager()
# TODO: a sharding manager that do nothing?
elif self.config.rollout.name == 'vllm':
- from verl.trainer.workers.rollout.vllm_rollout import vLLMRollout
- from verl.trainer.workers.hybrid_engine import FSDPVLLMShardingManager
+ from verl.workers.rollout.vllm_rollout import vLLMRollout
+ from verl.workers.hybrid_engine import FSDPVLLMShardingManager
log_gpu_memory_usage('Before building vllm rollout', logger=None)
rollout = vLLMRollout(actor_module=self.actor_module_fsdp,
config=self.config.rollout,
@@ -245,7 +245,7 @@ def _build_rollout(self):

@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
- from verl.trainer.workers.actor import DataParallelPPOActor
+ from verl.workers.actor import DataParallelPPOActor
# This is used to import external_lib into the huggingface systems
import_external_libs(self.config.model.get('external_lib', None))

@@ -557,7 +557,7 @@ def init_model(self):
# This is used to import external_lib into the huggingface systems
import_external_libs(self.config.model.get('external_lib', None))

- from verl.trainer.workers.critic import DataParallelPPOCritic
+ from verl.workers.critic import DataParallelPPOCritic
self.critic_module, self.critic_optimizer, self.critic_lr_scheduler = self._build_critic_model_optimizer(
self.config)

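
Because the whole package moved, stale imports of verl.trainer.workers now fail at import time. During a transition they could be kept resolving with a module alias; a minimal sketch, not part of this commit:

# Hypothetical compatibility shim (not in this commit): alias the old package
# path to the new one so existing verl.trainer.workers imports keep resolving.
import sys
import verl.workers

sys.modules['verl.trainer.workers'] = verl.workers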
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
verl/{trainer/workers → workers}/megatron_workers.py

@@ -23,10 +23,10 @@
import torch.nn as nn
from omegaconf import DictConfig
from verl.single_controller.base.megatron.worker import MegatronWorker
- from verl.trainer.workers.actor.megatron_actor import MegatronPPOActor
- from verl.trainer.workers.critic.megatron_critic import MegatronPPOCritic
- from verl.trainer.workers.hybrid_engine import AllGatherPPModel
- from verl.trainer.workers.reward_model.megatron.reward_model import MegatronRewardModel
+ from verl.workers.actor.megatron_actor import MegatronPPOActor
+ from verl.workers.critic.megatron_critic import MegatronPPOCritic
+ from verl.workers.hybrid_engine import AllGatherPPModel
+ from verl.workers.reward_model.megatron.reward_model import MegatronRewardModel

from verl.single_controller.base.decorator import register, Dispatch
from verl import DataProto
@@ -216,8 +216,8 @@ def megatron_actor_model_provider(pre_process, post_process):

def _build_rollout(self):
if self.config.rollout.name == 'vllm':
- from verl.trainer.workers.rollout.vllm_rollout import vLLMRollout
- from verl.trainer.workers.hybrid_engine import MegatronVLLMShardingManager
+ from verl.workers.rollout.vllm_rollout import vLLMRollout
+ from verl.workers.hybrid_engine import MegatronVLLMShardingManager
from verl.utils.model import normalize_pp_vpp_params

# NOTE(sgm): If the QKV and gate_up projection layer are concate together in actor,
File renamed without changes.
File renamed without changes.
verl/{trainer/workers → workers}/reward_model/megatron/reward_model.py

@@ -28,7 +28,7 @@
from verl import DataProto
from verl.utils.torch_functional import logprobs_from_logits, broadcast_dict_tensor, split_dict_tensor_into_batches
from verl.utils.torch_dtypes import PrecisionType
- from verl.trainer.workers.reward_model.base import BasePPORewardModel
+ from verl.workers.reward_model.base import BasePPORewardModel
from verl.utils.megatron import sequence_parallel as sp_utils
from megatron.core import parallel_state as mpu
from megatron.core.pipeline_parallel import get_forward_backward_func
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
verl/{trainer/workers → workers}/rollout/vllm_rollout/vllm_rollout.py

@@ -34,7 +34,7 @@

from verl import DataProto
from verl.utils.torch_functional import get_eos_mask, pad_sequence_to_length
- from verl.trainer.workers.rollout.base import BaseRollout
+ from verl.workers.rollout.base import BaseRollout
from verl.third_party.vllm import LLM, vllm_version
from verl.third_party.vllm import parallel_state as vllm_ps
from vllm import SamplingParams