From b4a3d6b9c52d9bdfe3b3b8cf57144484a73e38d1 Mon Sep 17 00:00:00 2001 From: HL Date: Wed, 4 Dec 2024 19:17:43 -0800 Subject: [PATCH 01/14] [tokenizer] feat: support tokenizers whose pad_token_id is none (#36) * [tokenizer] feat: support tokenizers whose pad_token_id is none * add test to ci * install test version * update ci * dont use gemma for testing * dont use gemma for testing * add proxy * revert dataset test * add back tests * fix format * fix format * fix deps * use git clone instead of https download * fix path * revert and use one yaml for gpu instead * fix path * cleanup * limit pyarrow version * Revert "limit pyarrow version" This reverts commit b924f79a79088c21636269d11a4ec3095af10c09. * lfs * try lfs * do not clone if exist --- .../workflows/{ray_test.yml => gpu_test.yml} | 10 ++- setup.py | 2 +- tests/verl/utils/dataset/test_rl_dataset.py | 61 +++++++++++++++++++ tests/verl/utils/dataset/test_rm_dataset.py | 42 +++++++++++++ verl/trainer/fsdp_sft_trainer.py | 5 +- verl/trainer/main_generation.py | 4 +- verl/trainer/main_ppo.py | 2 + verl/trainer/ppo/workers/fsdp_workers.py | 6 ++ verl/trainer/ppo/workers/megatron_workers.py | 6 ++ verl/utils/__init__.py | 7 ++- verl/utils/dataset/rl_dataset.py | 37 +---------- verl/utils/dataset/rm_dataset.py | 17 ++---- verl/utils/dataset/sft_dataset.py | 4 +- verl/utils/tokenizer.py | 29 +++++++++ 14 files changed, 174 insertions(+), 58 deletions(-) rename .github/workflows/{ray_test.yml => gpu_test.yml} (68%) create mode 100644 tests/verl/utils/dataset/test_rl_dataset.py create mode 100644 tests/verl/utils/dataset/test_rm_dataset.py create mode 100644 verl/utils/tokenizer.py diff --git a/.github/workflows/ray_test.yml b/.github/workflows/gpu_test.yml similarity index 68% rename from .github/workflows/ray_test.yml rename to .github/workflows/gpu_test.yml index 011ea3a..59d0587 100644 --- a/.github/workflows/ray_test.yml +++ b/.github/workflows/gpu_test.yml @@ -18,15 +18,19 @@ on: jobs: ray: - runs-on: [self-hosted, gpu] # test if the enviroment is ready + runs-on: [self-hosted, gpu] steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: fetch-depth: 0 - name: Install the current repository run: | - pip install -e . - - name: Running some ray test that only need 2 GPUs + pip install -e .[test] + - name: Running dataset tests + run: | + [ ! -d "$HOME/verl-data" ] && git clone --depth 1 https://github.com/eric-haibin-lin/verl-data ~/verl-data + pytest -s -x tests/verl + - name: Running ray tests that need 2 GPUs run: | cd tests/ray pytest -s -x test_rvdz.py test_driverfunc_to_worker.py test_data_transfer.py test_colocated_workers.py test_check_worker_alive.py diff --git a/setup.py b/setup.py index 8289d3c..4b194f1 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,6 @@ 'pybind11', 'hydra-core', 'numpy', - 'pytest', 'yapf', "dill", "accelerate" @@ -43,6 +42,7 @@ 'demo': ['hydra-core', 'transformers', ''], 'single-controller': ['ray', 'kubernetes'], 'single-controller-ray': ['ray'], + 'test': ['fsspec', 'pytest', 'datasets'] } from pathlib import Path diff --git a/tests/verl/utils/dataset/test_rl_dataset.py b/tests/verl/utils/dataset/test_rl_dataset.py new file mode 100644 index 0000000..3ddeb75 --- /dev/null +++ b/tests/verl/utils/dataset/test_rl_dataset.py @@ -0,0 +1,61 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import torch +from torch.utils.data import DataLoader +from transformers import AutoTokenizer + + +def get_gsm8k_data(): + # prepare test dataset + url = "https://github.com/eric-haibin-lin/verl-data/raw/refs/heads/main/gsm8k/train.parquet" + local_folder = os.path.expanduser('~/verl-data/gsm8k/') + local_path = os.path.join(local_folder, 'train.parquet') + os.makedirs(local_folder, exist_ok=True) + # import fsspec + # with fsspec.open(url, mode='rb') as fin, fsspec.open(local_path, mode='wb') as fout: + # content = fin.read() + # fout.write(content) + return local_path + + +def test_rl_dataset(): + from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn + tokenizer = AutoTokenizer.from_pretrained('deepseek-ai/deepseek-coder-1.3b-instruct') + from verl.utils import set_pad_token_id + set_pad_token_id(tokenizer) + local_path = get_gsm8k_data() + dataset = RLHFDataset(parquet_files=local_path, tokenizer=tokenizer, prompt_key='prompt', max_prompt_length=256) + + dataloader = DataLoader(dataset=dataset, batch_size=16, shuffle=True, drop_last=True, collate_fn=collate_fn) + + a = next(iter(dataloader)) + + from verl import DataProto + + tensors = {} + non_tensors = {} + + for key, val in a.items(): + if isinstance(val, torch.Tensor): + tensors[key] = val + else: + non_tensors[key] = val + + data_proto = DataProto.from_dict(tensors=tensors, non_tensors=non_tensors) + + data = dataset[0]['input_ids'] + output = tokenizer.batch_decode([data])[0] + print(f'type: type{output}') + print(f'\n\noutput: {output}') diff --git a/tests/verl/utils/dataset/test_rm_dataset.py b/tests/verl/utils/dataset/test_rm_dataset.py new file mode 100644 index 0000000..c139134 --- /dev/null +++ b/tests/verl/utils/dataset/test_rm_dataset.py @@ -0,0 +1,42 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +from transformers import AutoTokenizer +from verl.utils import set_pad_token_id +from verl.utils.dataset.rm_dataset import RMDataset + + +def get_rm_data(): + # prepare test dataset + url = "https://github.com/eric-haibin-lin/verl-data/raw/refs/heads/main/full_hh_rlhf/rm/test.parquet" + local_folder = os.path.expanduser('~/verl-data/full_hh_rlhf/rm/') + local_path = os.path.join(local_folder, 'test.parquet') + os.makedirs(local_folder, exist_ok=True) + # import fsspec + # with fsspec.open(url, mode='rb') as fin, fsspec.open(local_path, mode='wb') as fout: + # content = fin.read() + # fout.write(content) + return local_path + + +def test_rm_dataset(): + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b") + set_pad_token_id(tokenizer) + local_path = get_rm_data() + dataset = RMDataset(parquet_files=local_path, tokenizer=tokenizer, max_length=512) + data = dataset[0]['input_ids'] + output = tokenizer.batch_decode(data) + assert len(output) > 1 + assert type(output[0]) == str diff --git a/verl/trainer/fsdp_sft_trainer.py b/verl/trainer/fsdp_sft_trainer.py index 43400dd..9af663e 100644 --- a/verl/trainer/fsdp_sft_trainer.py +++ b/verl/trainer/fsdp_sft_trainer.py @@ -24,11 +24,9 @@ os.environ['TOKENIZERS_PARALLELISM'] = 'true' import logging -import functools import re import torch import torch.distributed -import wandb from torch import nn, optim from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision, ShardingStrategy, CPUOffload from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, AutoConfig @@ -66,7 +64,8 @@ def __init__(self, config, device_mesh: DeviceMesh): local_model_path = copy_local_path_from_hdfs(src=self.config.model.partial_pretrain, verbose=True) self.tokenizer = AutoTokenizer.from_pretrained(local_model_path, trust_remote_code=self.config.model.trust_remote_code) - + from verl.utils import set_pad_token_id + set_pad_token_id(self.tokenizer) if self.config.data.chat_template is not None: raise ValueError('Apply Chat template from config is not supported yet.') diff --git a/verl/trainer/main_generation.py b/verl/trainer/main_generation.py index b13e890..0d17073 100644 --- a/verl/trainer/main_generation.py +++ b/verl/trainer/main_generation.py @@ -25,7 +25,6 @@ from verl.utils.model import compute_position_id_with_mask -import torch import pandas as pd from transformers import AutoTokenizer @@ -45,6 +44,9 @@ def main(config): OmegaConf.resolve(config) local_path = copy_local_path_from_hdfs(config.model.path) tokenizer = AutoTokenizer.from_pretrained(local_path) + from verl.utils import set_pad_token_id + set_pad_token_id(tokenizer) + if config.rollout.temperature == 0.: assert config.data.n_samples == 1, 'When temperature=0, n_samples must be 1.' diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py index c1cdb23..cbb1c61 100644 --- a/verl/trainer/main_ppo.py +++ b/verl/trainer/main_ppo.py @@ -113,6 +113,8 @@ def main_task(config): # instantiate tokenizer tokenizer = AutoTokenizer.from_pretrained(local_path) + from verl.utils import set_pad_token_id + set_pad_token_id(tokenizer) # define worker classes if config.actor_rollout_ref.actor.strategy == 'fsdp': diff --git a/verl/trainer/ppo/workers/fsdp_workers.py b/verl/trainer/ppo/workers/fsdp_workers.py index 5e4cbd6..9d47bae 100644 --- a/verl/trainer/ppo/workers/fsdp_workers.py +++ b/verl/trainer/ppo/workers/fsdp_workers.py @@ -35,6 +35,7 @@ from verl.utils.import_utils import import_external_libs from verl.utils.debug import log_gpu_memory_usage import verl.utils.hdfs_io as hdfs_io +from verl.utils import set_pad_token_id logger = logging.getLogger(__file__) logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN')) @@ -107,6 +108,8 @@ def _build_model_optimizer(self, # note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect # TODO(zhangchi.usc1992): 1. support create from random initialized model. 2. Support init with FSDP directly self.tokenizer = AutoTokenizer.from_pretrained(local_path, trust_remote_code=trust_remote_code) + set_pad_token_id(self.tokenizer) + torch_dtype = fsdp_config.get('model_dtype', None) if torch_dtype is None: torch_dtype = torch.float32 if self._is_actor else torch.bfloat16 @@ -466,6 +469,7 @@ def _build_critic_model_optimizer(self, config): tokenizer_path = copy_local_path_from_hdfs(config.model.tokenizer_path) self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=config.model.get('trust_remote_code', False)) + set_pad_token_id(self.tokenizer) from omegaconf import OmegaConf override_config = OmegaConf.to_container(self.config.model.get('override_config', OmegaConf.create())) @@ -675,6 +679,8 @@ def _build_model(self, config): self.tokenizer = AutoTokenizer.from_pretrained(local_path, trust_remote_code=config.model.get( 'trust_remote_code', False)) + set_pad_token_id(self.tokenizer) + set_pad_token_id(self.input_tokenizer) trust_remote_code = config.model.get('trust_remote_code', False) model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code) diff --git a/verl/trainer/ppo/workers/megatron_workers.py b/verl/trainer/ppo/workers/megatron_workers.py index 65b11a8..a481d46 100644 --- a/verl/trainer/ppo/workers/megatron_workers.py +++ b/verl/trainer/ppo/workers/megatron_workers.py @@ -35,6 +35,7 @@ from verl.utils.model import load_megatron_model_weights from verl.utils.megatron_utils import init_model_parallel_config from verl.utils.megatron_utils import offload_megatron_param_and_grad, load_megatron_param_and_grad +from verl.utils import set_pad_token_id from megatron.core import parallel_state as mpu from megatron.core import ModelParallelConfig @@ -136,6 +137,7 @@ def _build_model_optimizer(self, # Step 1: initialize the tokenizer local_path = copy_local_path_from_hdfs(model_path) self.tokenizer = AutoTokenizer.from_pretrained(local_path) + set_pad_token_id(self.tokenizer) # Step 2: get the actor_model_config actor_model_config = AutoConfig.from_pretrained(local_path) @@ -459,6 +461,7 @@ def _build_critic_model_optimizer(self, # Step 1: initialize the tokenizer local_path = copy_local_path_from_hdfs(model_path) self.tokenizer = AutoTokenizer.from_pretrained(local_path) + set_pad_token_id(self.tokenizer) # Step 2: get the actor_model_config critic_model_config = AutoConfig.from_pretrained(local_path) @@ -622,6 +625,7 @@ def _build_rm_model(self, model_path, megatron_config: ModelParallelConfig, over # Step 1: initialize the tokenizer local_path = copy_local_path_from_hdfs(model_path) self.tokenizer = AutoTokenizer.from_pretrained(local_path) + set_pad_token_id(self.tokenizer) # Step 2: get the actor_model_config rm_model_config = AutoConfig.from_pretrained(local_path) @@ -685,11 +689,13 @@ def init_model(self): sft_tokenizer_local_path = copy_local_path_from_hdfs(self.config.model.input_tokenizer) sft_tokenizer = AutoTokenizer.from_pretrained(sft_tokenizer_local_path) + set_pad_token_id(sft_tokenizer) rm_tokenizer_path = self.config.model.get('rm_tokenizer', None) rm_tokenizer = None if rm_tokenizer_path is not None: rm_tokenizer_local_path = copy_local_path_from_hdfs(rm_tokenizer_path) rm_tokenizer = AutoTokenizer.from_pretrained(rm_tokenizer_local_path) + set_pad_token_id(rm_tokenizer) torch_dtype = torch.bfloat16 megatron_config = OmegaConf.create({ diff --git a/verl/utils/__init__.py b/verl/utils/__init__.py index 7a7aadb..e453070 100644 --- a/verl/utils/__init__.py +++ b/verl/utils/__init__.py @@ -10,4 +10,9 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file +# limitations under the License. + +from . import tokenizer +from .tokenizer import * + +__all__ = tokenizer.__all__ \ No newline at end of file diff --git a/verl/utils/dataset/rl_dataset.py b/verl/utils/dataset/rl_dataset.py index 7ebc39a..d4b18f0 100644 --- a/verl/utils/dataset/rl_dataset.py +++ b/verl/utils/dataset/rl_dataset.py @@ -142,39 +142,4 @@ def __getitem__(self, item): if self.return_raw_chat: row_dict['raw_prompt'] = chat.tolist() - return row_dict - - -if __name__ == '__main__': - from transformers import AutoTokenizer - - from torch.utils.data import DataLoader - - local_path = copy_local_path_from_hdfs('~/models/gemma-1.1-7b-it') - tokenizer = AutoTokenizer.from_pretrained(local_path) - - dataset = RLHFDataset(parquet_files='~/data/rlhf/gsm8k/train.parquet', - tokenizer=tokenizer, - prompt_key='prompt', - max_prompt_length=256) - - dataloader = DataLoader(dataset=dataset, batch_size=16, shuffle=True, drop_last=True, collate_fn=collate_fn) - - a = next(iter(dataloader)) - - from verl import DataProto - - tensors = {} - non_tensors = {} - - for key, val in a.items(): - if isinstance(val, torch.Tensor): - tensors[key] = val - else: - non_tensors[key] = val - - data_proto = DataProto.from_dict(tensors=tensors, non_tensors=non_tensors) - - data = dataset[0]['input_ids'] - output = tokenizer.batch_decode([data])[0] - print(f'\n\noutput: {output}') + return row_dict \ No newline at end of file diff --git a/verl/utils/dataset/rm_dataset.py b/verl/utils/dataset/rm_dataset.py index cc0c6b2..29d3cd9 100644 --- a/verl/utils/dataset/rm_dataset.py +++ b/verl/utils/dataset/rm_dataset.py @@ -18,9 +18,11 @@ import pandas as pd import torch -from torch.utils.data import Dataset, DataLoader +from torch.utils.data import Dataset from transformers import AutoTokenizer +from verl.utils import set_pad_token_id + def download_files_distributed(download_fn): import torch.distributed @@ -53,6 +55,7 @@ def __init__(self, self.cache_dir = os.path.expanduser(cache_dir) if isinstance(tokenizer, str): tokenizer = AutoTokenizer.from_pretrained(tokenizer) + set_pad_token_id(tokenizer) self.tokenizer = tokenizer self.prompt_key = prompt_key @@ -138,14 +141,4 @@ def __getitem__(self, item): return { 'input_ids': input_ids, 'attention_mask': attention_mask, - } - - -if __name__ == '__main__': - tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", add_bos_token=False) - - dataset = RMDataset(parquet_files='~/data/full_hh_rlhf/rm/train.parquet', tokenizer=tokenizer, max_length=512) - data = dataset[0]['input_ids'] - output = tokenizer.batch_decode(data) - print(output[0]) - print(output[1]) + } \ No newline at end of file diff --git a/verl/utils/dataset/sft_dataset.py b/verl/utils/dataset/sft_dataset.py index fa64e2b..c39d7be 100644 --- a/verl/utils/dataset/sft_dataset.py +++ b/verl/utils/dataset/sft_dataset.py @@ -28,6 +28,7 @@ from verl.utils.fs import copy_local_path_from_hdfs from verl.utils.model import compute_position_id_with_mask +from verl.utils import set_pad_token_id class SFTDataset(Dataset): @@ -51,6 +52,7 @@ def __init__(self, self.parquet_files = parquet_files if isinstance(tokenizer, str): tokenizer = AutoTokenizer.from_pretrained(tokenizer) + set_pad_token_id(tokenizer) self.tokenizer: PreTrainedTokenizer = tokenizer self.prompt_key = prompt_key @@ -148,7 +150,7 @@ def __getitem__(self, item): if __name__ == '__main__': local_model_path = copy_local_path_from_hdfs('~/models/gemma-2b-it') tokenizer = AutoTokenizer.from_pretrained(local_model_path) - + set_pad_token_id(tokenizer) dataset = SFTDataset(parquet_files='~/data/gsm8k/train.parquet', tokenizer=tokenizer, prompt_key='question', diff --git a/verl/utils/tokenizer.py b/verl/utils/tokenizer.py new file mode 100644 index 0000000..55de19f --- /dev/null +++ b/verl/utils/tokenizer.py @@ -0,0 +1,29 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utils for tokenization.""" + +__all__ = ['set_pad_token_id'] + + +def set_pad_token_id(tokenizer): + """Set pad_token_id to eos_token_id if it is None. + + Args: + tokenizer (transformers.PreTrainedTokenizer): The tokenizer to be set. + + """ + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token From c592a8be4cf3b05ec7b354e0d062dd7686d16ba3 Mon Sep 17 00:00:00 2001 From: Guangming Sheng Date: Fri, 6 Dec 2024 14:23:41 +0800 Subject: [PATCH 02/14] [rollout] feat: support vLLM v0.6.3 and fix hf rollout import issue (#33) * [feat] support vllm spmd version in v0.6.3 * [misc] fix hf_weight loader * [misc] rollout: update vllm version and fix hf import * lint * [misc] fix init * [doc] feat: modify doc to support vllm v6 --- README.md | 2 +- docs/preparation/install.rst | 2 +- requirements.txt | 6 +- setup.py | 2 +- verl/third_party/vllm/__init__.py | 8 +- .../third_party/vllm/vllm_v_0_6_3/__init__.py | 13 + .../vllm/vllm_v_0_6_3/arg_utils.py | 78 ++++ verl/third_party/vllm/vllm_v_0_6_3/config.py | 105 +++++ .../vllm_v_0_6_3/dtensor_weight_loaders.py | 380 ++++++++++++++++ .../vllm/vllm_v_0_6_3/hf_weight_loader.py | 41 ++ verl/third_party/vllm/vllm_v_0_6_3/llm.py | 200 +++++++++ .../vllm/vllm_v_0_6_3/llm_engine_sp.py | 408 ++++++++++++++++++ .../vllm_v_0_6_3/megatron_weight_loaders.py | 308 +++++++++++++ .../vllm/vllm_v_0_6_3/model_loader.py | 332 ++++++++++++++ .../vllm/vllm_v_0_6_3/model_runner.py | 182 ++++++++ .../vllm/vllm_v_0_6_3/parallel_state.py | 312 ++++++++++++++ .../vllm/vllm_v_0_6_3/spmd_gpu_executor.py | 256 +++++++++++ .../vllm/vllm_v_0_6_3/tokenizer.py | 40 ++ verl/third_party/vllm/vllm_v_0_6_3/worker.py | 333 ++++++++++++++ verl/trainer/ppo/hybrid_engine/__init__.py | 2 + .../ppo/rollout/vllm_rollout/vllm_rollout.py | 4 +- 21 files changed, 3004 insertions(+), 10 deletions(-) create mode 100644 verl/third_party/vllm/vllm_v_0_6_3/__init__.py create mode 100644 verl/third_party/vllm/vllm_v_0_6_3/arg_utils.py create mode 100644 verl/third_party/vllm/vllm_v_0_6_3/config.py create mode 100644 verl/third_party/vllm/vllm_v_0_6_3/dtensor_weight_loaders.py create mode 100644 verl/third_party/vllm/vllm_v_0_6_3/hf_weight_loader.py create mode 100644 verl/third_party/vllm/vllm_v_0_6_3/llm.py create mode 100644 verl/third_party/vllm/vllm_v_0_6_3/llm_engine_sp.py create mode 100644 verl/third_party/vllm/vllm_v_0_6_3/megatron_weight_loaders.py create mode 100644 verl/third_party/vllm/vllm_v_0_6_3/model_loader.py create mode 100644 verl/third_party/vllm/vllm_v_0_6_3/model_runner.py create mode 100644 verl/third_party/vllm/vllm_v_0_6_3/parallel_state.py create mode 100644 verl/third_party/vllm/vllm_v_0_6_3/spmd_gpu_executor.py create mode 100644 verl/third_party/vllm/vllm_v_0_6_3/tokenizer.py create mode 100644 verl/third_party/vllm/vllm_v_0_6_3/worker.py diff --git a/README.md b/README.md index 922cf87..21cde82 100644 --- a/README.md +++ b/README.md @@ -69,7 +69,7 @@ The following dependencies are required for all backends. pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu121 # install vllm -pip3 install vllm==0.5.4 +pip3 install vllm==0.6.3 # or you can install 0.5.4, 0.4.2 and 0.3.1 pip3 install ray==2.10 # other version may have bug # flash attention 2 diff --git a/docs/preparation/install.rst b/docs/preparation/install.rst index 1c623c1..9a932e9 100644 --- a/docs/preparation/install.rst +++ b/docs/preparation/install.rst @@ -45,7 +45,7 @@ found in :doc:`FSDP Workers<../workers/fsdp_workers>`. pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu121 # install vllm - pip3 install vllm==0.5.4 + pip3 install vllm==0.6.3 # or you can install 0.5.4, 0.4.2 and 0.3.1 pip3 install ray==2.10 # other version may have bug # flash attention 2 diff --git a/requirements.txt b/requirements.txt index 823cfda..ca102e9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,12 +1,10 @@ transformers hydra-core -tensordict < 0.3.1 +tensordict==0.5.0 numpy pytest -deepspeed pybind11 codetiming yapf wandb -git+https://github.com/NVIDIA/TransformerEngine.git@stable -# vllm==0.5.4 # vllm is installed in image building to avoid ray conflicts \ No newline at end of file +git+https://github.com/NVIDIA/TransformerEngine.git@stable \ No newline at end of file diff --git a/setup.py b/setup.py index 4b194f1..9a97a50 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,7 @@ ] install_optional = [ - 'vllm==0.5.4', + 'vllm==0.6.3', ] extras_require = { diff --git a/verl/third_party/vllm/__init__.py b/verl/third_party/vllm/__init__.py index 9eee28f..290c837 100644 --- a/verl/third_party/vllm/__init__.py +++ b/verl/third_party/vllm/__init__.py @@ -40,6 +40,12 @@ def get_version(pkg): from .vllm_v_0_5_4.llm import LLM from .vllm_v_0_5_4.llm import LLMEngine from .vllm_v_0_5_4 import parallel_state +elif package_version == '0.6.3': + vllm_version = '0.6.3' + from .vllm_v_0_6_3.llm import LLM + from .vllm_v_0_6_3.llm import LLMEngine + from .vllm_v_0_6_3 import parallel_state else: raise ValueError( - f'vllm version {package_version} not supported. Currently supported versions are 0.3.1, 0.4.2, and 0.5.4.') + f'vllm version {package_version} not supported. Currently supported versions are 0.3.1, 0.4.2, 0.5.4 and 0.6.3.' + ) diff --git a/verl/third_party/vllm/vllm_v_0_6_3/__init__.py b/verl/third_party/vllm/vllm_v_0_6_3/__init__.py new file mode 100644 index 0000000..1ce90c5 --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_3/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/verl/third_party/vllm/vllm_v_0_6_3/arg_utils.py b/verl/third_party/vllm/vllm_v_0_6_3/arg_utils.py new file mode 100644 index 0000000..bc4685c --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_3/arg_utils.py @@ -0,0 +1,78 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/engine/arg_utils.py + +import os +from dataclasses import dataclass + +from transformers import PretrainedConfig +from vllm.config import EngineConfig +from vllm.engine.arg_utils import EngineArgs + +from .config import LoadConfig, ModelConfig + + +@dataclass +class EngineArgs(EngineArgs): + model_hf_config: PretrainedConfig = None # for verl + + def __post_init__(self): + pass + + def create_model_config(self) -> ModelConfig: + return ModelConfig( + hf_config=self.model_hf_config, + tokenizer_mode=self.tokenizer_mode, + trust_remote_code=self.trust_remote_code, + dtype=self.dtype, + seed=self.seed, + revision=self.revision, + code_revision=self.code_revision, + rope_scaling=self.rope_scaling, + rope_theta=self.rope_theta, + tokenizer_revision=self.tokenizer_revision, + max_model_len=self.max_model_len, + quantization=self.quantization, + quantization_param_path=self.quantization_param_path, + enforce_eager=self.enforce_eager, + max_context_len_to_capture=self.max_context_len_to_capture, + max_seq_len_to_capture=self.max_seq_len_to_capture, + max_logprobs=self.max_logprobs, + disable_sliding_window=self.disable_sliding_window, + skip_tokenizer_init=self.skip_tokenizer_init, + served_model_name=self.served_model_name, + limit_mm_per_prompt=self.limit_mm_per_prompt, + use_async_output_proc=not self.disable_async_output_proc, + override_neuron_config=self.override_neuron_config, + config_format=self.config_format, + mm_processor_kwargs=self.mm_processor_kwargs, + ) + + def create_load_config(self) -> LoadConfig: + return LoadConfig( + load_format=self.load_format, + download_dir=self.download_dir, + model_loader_extra_config=self.model_loader_extra_config, + ignore_patterns=self.ignore_patterns, + ) + + def create_engine_config(self) -> EngineConfig: + engine_config = super().create_engine_config() + + # NOTE[VERL]: Use the world_size set by torchrun + world_size = int(os.getenv("WORLD_SIZE", "-1")) + assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN" + engine_config.parallel_config.world_size = world_size + + return engine_config diff --git a/verl/third_party/vllm/vllm_v_0_6_3/config.py b/verl/third_party/vllm/vllm_v_0_6_3/config.py new file mode 100644 index 0000000..d7cee45 --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_3/config.py @@ -0,0 +1,105 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py + +import enum +import json +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, List, Optional, Union + +from transformers import PretrainedConfig + +# Add for verl +from vllm.config import ModelConfig +from vllm.logger import init_logger +from vllm.utils import is_hip + +if TYPE_CHECKING: + from vllm.model_executor.model_loader.loader import BaseModelLoader + +logger = init_logger(__name__) + + +class LoadFormat(str, enum.Enum): + AUTO = "auto" + MEGATRON = "megatron" + HF = "hf" + DTENSOR = "dtensor" + DUMMY_HF = "dummy_hf" + DUMMY_MEGATRON = "dummy_megatron" + DUMMY_DTENSOR = "dummy_dtensor" + + +class ModelConfig(ModelConfig): + + def __init__(self, hf_config: PretrainedConfig, *args, **kwargs) -> None: + super().__init__(model=hf_config._name_or_path, tokenizer=hf_config._name_or_path, *args, **kwargs) + self.hf_config = hf_config + + +@dataclass +class LoadConfig: + """ + download_dir: Directory to download and load the weights, default to the + default cache directory of huggingface. + load_format: The format of the model weights to load: + "auto" will try to load the weights in the safetensors format and + fall back to the pytorch bin format if safetensors format is + not available. + "pt" will load the weights in the pytorch bin format. + "safetensors" will load the weights in the safetensors format. + "npcache" will load the weights in pytorch format and store + a numpy cache to speed up the loading. + "dummy" will initialize the weights with random values, which is + mainly for profiling. + "tensorizer" will use CoreWeave's tensorizer library for + fast weight loading. + "bitsandbytes" will load nf4 type weights. + ignore_patterns: The list of patterns to ignore when loading the model. + Default to "original/**/*" to avoid repeated loading of llama's + checkpoints. + + """ + + load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO + download_dir: Optional[str] = None + model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict) + ignore_patterns: Optional[Union[List[str], str]] = None + + def __post_init__(self): + model_loader_extra_config = self.model_loader_extra_config or {} + if isinstance(model_loader_extra_config, str): + self.model_loader_extra_config = json.loads(model_loader_extra_config) + self._verify_load_format() + + if self.ignore_patterns is not None and len(self.ignore_patterns) > 0: + logger.info("Ignoring the following patterns when downloading weights: %s", self.ignore_patterns) + else: + self.ignore_patterns = ["original/**/*"] + + def _verify_load_format(self) -> None: + if not isinstance(self.load_format, str): + return + + load_format = self.load_format.lower() + self.load_format = LoadFormat(load_format) + + rocm_not_supported_load_format: List[str] = [] + if is_hip() and load_format in rocm_not_supported_load_format: + rocm_supported_load_format = [ + f for f in LoadFormat.__members__ if (f not in rocm_not_supported_load_format) + ] + raise ValueError(f"load format '{load_format}' is not supported in ROCm. " + f"Supported load formats are " + f"{rocm_supported_load_format}") diff --git a/verl/third_party/vllm/vllm_v_0_6_3/dtensor_weight_loaders.py b/verl/third_party/vllm/vllm_v_0_6_3/dtensor_weight_loaders.py new file mode 100644 index 0000000..a3042ca --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_3/dtensor_weight_loaders.py @@ -0,0 +1,380 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/model_loader + +from typing import Dict + +import torch.nn as nn +from torch.distributed._tensor import DTensor +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.utils import is_pp_missing_parameter + + +def gemma_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(vllm_model.named_parameters()) + for name, loaded_weight in actor_weights.items(): + for param_name, shard_name, shard_id in stacked_params_mapping: + if shard_name not in name: + continue + stacked_name = name.replace(shard_name, param_name) + # Skip loading extra bias for GPTQ models. + if stacked_name.endswith(".bias") and stacked_name not in params_dict: + continue + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[stacked_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) + break + else: + # lm_head is not used in vllm as it is tied with embed_token. + # To prevent errors, skip loading lm_head.weight. + if "lm_head.weight" in name: + continue + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) + + +def gptbigcode_dtensor_load_weights(actor_weights: Dict, vllm_model: nn.Module): + params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) + for name, loaded_weight in actor_weights.items(): + if "lm_head.weight" in name: + continue + if ".attn.bias" in name: + # Skip attention mask. + # NOTE: "c_attn.bias" should not be skipped. + continue + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) + + +def starcoder2_dtensor_load_weights(actor_weights: Dict, vllm_model: nn.Module): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) + for name, loaded_weight in actor_weights.items(): + if "rotary_emb.inv_freq" in name: + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) + break + else: + if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name: + continue + param = params_dict[name] + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) + + +def llama_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(vllm_model.named_parameters()) + for name, loaded_weight in actor_weights.items(): + if "rotary_emb.inv_freq" in name: + continue + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc. + if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight) + + +def qwen2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) + for name, loaded_weight in actor_weights.items(): + if "rotary_emb.inv_freq" in name: + continue + if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) + + +def qwen2vl_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) + for name, loaded_weight in actor_weights.items(): + if "rotary_emb.inv_freq" in name: + continue + if vllm_model.config.tie_word_embeddings and "lm_head.weight" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) + + +from vllm.model_executor.layers.fused_moe import FusedMoE + + +def deepseekv2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=vllm_model.config.n_routed_experts, + ) + + params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) + for name, loaded_weight in actor_weights.items(): + if "rotary_emb.inv_freq" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if ("mlp.experts." in name) and name not in params_dict: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, vllm_model): + continue + + param = params_dict[name] + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype), shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name, vllm_model): + continue + + param = params_dict[name] + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader( + param, + local_loaded_weight.to(dtype=param.dtype), + weight_name, + shard_id=shard_id, + expert_id=expert_id, + ) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, vllm_model): + continue + + param = params_dict[name] + local_loaded_weight = redistribute_dtensor(param_name=name, loaded_weights=loaded_weight) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, local_loaded_weight.to(dtype=param.dtype)) + + +def gpt2_dtensor_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + pass + + +def redistribute_dtensor(param_name: str, loaded_weights: DTensor, parallelize_plan: Dict = None): + param_name = _process_parameter_names(name=param_name) + if parallelize_plan is not None: + assert ( + param_name + in parallelize_plan.keys()), f"param name: {param_name} not in parallelize_plan :{parallelize_plan.keys()}" + placement = parallelize_plan[param_name] + local_loaded_weights = loaded_weights.redistribute(device_mesh=loaded_weights.device_mesh, + placements=placement).to_local() + else: + local_loaded_weights = loaded_weights.full_tensor() + return local_loaded_weights + + +def _process_parameter_names(name): + # Remove '.weight' if it exists at the end of the string + if name.endswith(".weight"): + name = name[:-7] + + # Remove 'model.layers.x.' or 'model.' prefix + if "model.layers" in name: + parts = name.split(".") + # Reconstruct the string without 'model.layers.x.' + name = ".".join(parts[3:]) # parts[0] is 'model', parts[1] is 'layers', parts[2] is 'x' + elif name.startswith("model."): + name = name[6:] # Remove 'model.' + + return name + + +__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__ = { + "GPT2LMHeadModel": gpt2_dtensor_weight_loader, + "LlamaForCausalLM": llama_dtensor_weight_loader, + "LLaMAForCausalLM": llama_dtensor_weight_loader, + "MistralForCausalLM": llama_dtensor_weight_loader, # mistral is the same as llama in vLLM + "InternLMForCausalLM": llama_dtensor_weight_loader, + "AquilaModel": llama_dtensor_weight_loader, + "AquilaForCausalLM": llama_dtensor_weight_loader, + "Phi3ForCausalLM": llama_dtensor_weight_loader, + "GemmaForCausalLM": gemma_dtensor_weight_loader, + "Gemma2ForCausalLM": gemma_dtensor_weight_loader, + "GPTBigCodeForCausalLM": gptbigcode_dtensor_load_weights, + "Starcoder2ForCausalLM": starcoder2_dtensor_load_weights, + "Qwen2ForCausalLM": qwen2_dtensor_weight_loader, + "DeepseekV2ForCausalLM": deepseekv2_dtensor_weight_loader, + "Qwen2VLForConditionalGeneration": qwen2vl_dtensor_weight_loader, +} + + +# the actor model is .state_dict() +# Load dtensor weights +def load_dtensor_weights(actor_weights: Dict, vllm_model: nn.Module): + weight_loader = _get_model_weight_loader(vllm_model.__class__.__name__) + weight_loader(actor_weights, vllm_model) + # NOTE(sgm) to reduce peak memory usage, we offload vllm model to cpu + # after init, and we need this after sync model weights for in first iter. + vllm_model = vllm_model.cuda() + + +def _get_model_weight_loader(arch: str): + if arch in __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__: + return __MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__[arch] + raise ValueError(f"Model architectures {arch} are not supported for now. " + f"Supported architectures: {__MODEL_DTENSOR_WEIGHT_LOADER_REGISTRY__.keys()}") + + +# NOTE(sgm): we use per-parameter weight loader in each vllm sub +def update_dtensor_weight_loader(): + pass diff --git a/verl/third_party/vllm/vllm_v_0_6_3/hf_weight_loader.py b/verl/third_party/vllm/vllm_v_0_6_3/hf_weight_loader.py new file mode 100644 index 0000000..a3e5b22 --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_3/hf_weight_loader.py @@ -0,0 +1,41 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/model_loader + +from typing import Dict + +import torch.nn as nn +from vllm.model_executor.model_loader.utils import set_default_torch_dtype + + +def update_hf_weight_loader(): + print("no hf weight loader need to be updated") + return + + +def load_hf_weights(actor_weights: Dict, vllm_model: nn.Module): + assert isinstance(actor_weights, Dict) + with set_default_torch_dtype(next(vllm_model.parameters()).dtype): # TODO + if vllm_model.config.tie_word_embeddings and "lm_head.weight" in actor_weights.keys(): + del actor_weights["lm_head.weight"] + vllm_model.load_weights(actor_weights.items()) + for _, module in vllm_model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + quant_method.process_weights_after_loading(module) + # FIXME: Remove this after Mixtral is updated + # to use quant_method. + if hasattr(module, "process_weights_after_loading"): + module.process_weights_after_loading() + vllm_model = vllm_model.cuda() diff --git a/verl/third_party/vllm/vllm_v_0_6_3/llm.py b/verl/third_party/vllm/vllm_v_0_6_3/llm.py new file mode 100644 index 0000000..9351457 --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_3/llm.py @@ -0,0 +1,200 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py + +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch.nn.utils.rnn import pad_sequence +from transformers import PretrainedConfig, PreTrainedTokenizer, PreTrainedTokenizerFast +from verl.trainer.ppo.rollout.tokenizer import HybridEngineBaseTokenizer +from vllm import LLM +from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.utils import Counter + +from .arg_utils import EngineArgs +from .llm_engine_sp import LLMEngine + + +class LLM(LLM): + """An LLM for generating texts from given prompts and sampling parameters. + + This class includes a tokenizer, a language model (possibly distributed + across multiple GPUs), and GPU memory space allocated for intermediate + states (aka KV cache). Given a batch of prompts and sampling parameters, + this class generates texts from the model, using an intelligent batching + mechanism and efficient memory management. + + NOTE: This class is intended to be used for offline inference. For online + serving, use the `AsyncLLMEngine` class instead. + NOTE: For the comprehensive list of arguments, see `EngineArgs`. + + Args: + model: A HuggingFace Transformers model instance. + tokenizer: A HuggingFace Transformers tokenizer instance. + tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer + if available, and "slow" will always use the slow tokenizer. + trust_remote_code: Trust remote code (e.g., from HuggingFace) when + downloading the model and tokenizer. + tensor_parallel_size: The number of GPUs to use for distributed + execution with tensor parallelism. + dtype: The data type for the model weights and activations. Currently, + we support `float32`, `float16`, and `bfloat16`. If `auto`, we use + the `torch_dtype` attribute specified in the model config file. + However, if the `torch_dtype` in the config is `float32`, we will + use `float16` instead. + quantization: The method used to quantize the model weights. Currently, + we support "awq". If None, we assume the model weights are not + quantized and use `dtype` to determine the data type of the weights. + revision: The specific model version to use. It can be a branch name, + a tag name, or a commit id. + tokenizer_revision: The specific tokenizer version to use. It can be a + branch name, a tag name, or a commit id. + seed: The seed to initialize the random number generator for sampling. + gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to + reserve for the model weights, activations, and KV cache. Higher + values will increase the KV cache size and thus improve the model's + throughput. However, if the value is too high, it may cause out-of- + memory (OOM) errors. + swap_space: The size (GiB) of CPU memory per GPU to use as swap space. + This can be used for temporarily storing the states of the requests + when their `best_of` sampling parameters are larger than 1. If all + requests will have `best_of=1`, you can safely set this to 0. + Otherwise, too small values may cause out-of-memory (OOM) errors. + enforce_eager: Whether to enforce eager execution. If True, we will + disable CUDA graph and always execute the model in eager mode. + If False, we will use CUDA graph and eager execution in hybrid. + max_context_len_to_capture: Maximum context len covered by CUDA graphs. + When a sequence has context length larger than this, we fall back + to eager mode. + disable_custom_all_reduce: See ParallelConfig + """ + + def __init__( + self, + model: Union[nn.Module, Dict], # model itself or its parameter dict + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast, HybridEngineBaseTokenizer], + model_hf_config: PretrainedConfig, + tokenizer_mode: str = "auto", + trust_remote_code: bool = False, + skip_tokenizer_init: bool = False, + tensor_parallel_size: int = 1, + dtype: str = "auto", + quantization: Optional[str] = None, + revision: Optional[str] = None, + tokenizer_revision: Optional[str] = None, + seed: int = 0, + gpu_memory_utilization: float = 0.9, + swap_space: int = 4, + cpu_offload_gb: float = 0, + enforce_eager: bool = False, + max_context_len_to_capture: Optional[int] = None, + max_seq_len_to_capture: int = 8192, + disable_custom_all_reduce: bool = False, + load_format="auto", + **kwargs, + ) -> None: + if "disable_log_stats" not in kwargs: + kwargs["disable_log_stats"] = True + removed_vision_keys = ("image_token_id", "image_feature_size", "image_input_shape", "image_input_type") + if any(k in kwargs for k in removed_vision_keys): + raise TypeError("There is no need to pass vision-related arguments anymore.") + engine_args = EngineArgs( + model_hf_config=model_hf_config, + # tokenizer=tokenizer, + tokenizer_mode=tokenizer_mode, + skip_tokenizer_init=skip_tokenizer_init, + trust_remote_code=trust_remote_code, + tensor_parallel_size=tensor_parallel_size, + dtype=dtype, + quantization=quantization, + revision=revision, + tokenizer_revision=tokenizer_revision, + seed=seed, + gpu_memory_utilization=gpu_memory_utilization, + swap_space=swap_space, + cpu_offload_gb=cpu_offload_gb, + enforce_eager=enforce_eager, + max_context_len_to_capture=max_context_len_to_capture, + max_seq_len_to_capture=max_seq_len_to_capture, + disable_custom_all_reduce=disable_custom_all_reduce, + load_format=load_format, + **kwargs, + ) + tokenizer_cls = (PreTrainedTokenizer, PreTrainedTokenizerFast, HybridEngineBaseTokenizer) + if not isinstance(tokenizer, tokenizer_cls): + raise ValueError( + f"Unexpected tokenizer type: {type(tokenizer)}. Must be" + "one of the following: PreTrainedTokenizer, PreTrainedTokenizerFast, verl.trainer.ppo.rollout.HybridEngineBaseTokenizer" + ) + self.llm_engine = LLMEngine.from_engine_args(model, tokenizer, engine_args) # TODO: check usagecontext + self.request_counter = Counter() + + def init_cache_engine(self): + self.llm_engine.init_cache_engine() + + def free_cache_engine(self): + self.llm_engine.free_cache_engine() + + def get_tokenizer(self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + return self.llm_engine.tokenizer + + def set_tokenizer( + self, + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + ) -> None: + self.llm_engine.tokenizer = tokenizer + + def _run_engine(self, *, use_tqdm: bool) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: + outputs = super()._run_engine(use_tqdm=use_tqdm) + return self._post_process_outputs(outputs) + + # # NOTE(shengguangming): add for verl + # # TODO(sgm): we can optimize it by making the dataloader yield List[int] without padding. + # def _pre_process_inputs(self, prompt_token_ids: torch.Tensor) -> List[int]: + # # remove the left padding in the prompt token_id + # pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id + # non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0] + # token_ids = prompt_token_ids[non_pad_index:].tolist() + # return token_ids + + # NOTE(shengguangming): add for verl + def _post_process_outputs(self, request_outputs: List[RequestOutput]) -> Tuple[torch.Tensor, torch.Tensor]: + output_token_ids = [] + logprobs = [] + for request_output in request_outputs: # List[RequestOutput] + outputs = request_output.outputs + for output in outputs: # List[CompletionOutput], usually len == 1 + output_token_ids.append(torch.tensor(output.token_ids)) + # TODO(shengguangming): can be optimzied by rewrite the Sampler._get_logprobs() logits + logprobs_dicts = output.logprobs + if logprobs_dicts is not None: + logprob = [] + for logprobs_dict, id in zip(logprobs_dicts, output.token_ids): + logprob.append(logprobs_dict[id].logprob) + logprobs.append(torch.tensor(logprob)) + + pad_token_id = (self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None + else self.llm_engine.tokenizer.eos_token_id) + output_token_ids = pad_sequence(output_token_ids, batch_first=True, padding_value=pad_token_id) + if len(logprobs) > 0: + logprobs = pad_sequence(logprobs, batch_first=True, padding_value=pad_token_id) + return output_token_ids, logprobs + + def sync_model_weights(self, actor_weights: Dict[str, torch.Tensor], load_format: str) -> None: + self.llm_engine.sync_model_weights(actor_weights=actor_weights, load_format=load_format) + + def offload_model_weights(self) -> None: + self.llm_engine.offload_model_weights() diff --git a/verl/third_party/vllm/vllm_v_0_6_3/llm_engine_sp.py b/verl/third_party/vllm/vllm_v_0_6_3/llm_engine_sp.py new file mode 100644 index 0000000..10b112b --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_3/llm_engine_sp.py @@ -0,0 +1,408 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/engine/llm_engine.py + +from functools import partial +from typing import Callable, Dict, Optional, Type, Union + +import torch +import torch.nn as nn +from vllm.config import ( + CacheConfig, + DecodingConfig, + DeviceConfig, + EngineConfig, + LoadConfig, + LoRAConfig, + ModelConfig, + ObservabilityConfig, + ParallelConfig, + PromptAdapterConfig, + SchedulerConfig, + SpeculativeConfig, +) +from vllm.core.scheduler import Scheduler +from vllm.engine.arg_utils import EngineArgs +from vllm.engine.llm_engine import LLMEngine, SchedulerContext, SchedulerOutputState, _load_generation_config_dict +from vllm.engine.metrics_types import StatLoggerBase +from vllm.engine.output_processor.interfaces import SequenceGroupOutputProcessor +from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.executor.executor_base import ExecutorBase +from vllm.inputs import INPUT_REGISTRY, InputRegistry +from vllm.inputs.preprocess import InputPreprocessor +from vllm.logger import init_logger +from vllm.sequence import Sequence +from vllm.tracing import init_tracer +from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.usage.usage_lib import UsageContext, is_usage_stats_enabled, usage_message +from vllm.utils import Counter, weak_bind +from vllm.version import __version__ as VLLM_VERSION + +from .arg_utils import EngineArgs +from .config import LoadConfig, ModelConfig +from .tokenizer import TokenizerGroup + +logger = init_logger(__name__) +_LOCAL_LOGGING_INTERVAL_SEC = 5 + + +class LLMEngine(LLMEngine): + """An LLM engine that receives requests and generates texts. + + This is the main class for the vLLM engine. It receives requests + from clients and generates texts from the LLM. It includes a tokenizer, a + language model (possibly distributed across multiple GPUs), and GPU memory + space allocated for intermediate states (aka KV cache). This class utilizes + iteration-level scheduling and efficient memory management to maximize the + serving throughput. + + The :class:`~vllm.LLM` class wraps this class for offline batched inference + and the :class:`AsyncLLMEngine` class wraps this class for online serving. + + The config arguments are derived from :class:`~vllm.EngineArgs`. (See + :ref:`engine_args`) + + Args: + model_config: The configuration related to the LLM model. + cache_config: The configuration related to the KV cache memory + management. + parallel_config: The configuration related to distributed execution. + scheduler_config: The configuration related to the request scheduler. + device_config: The configuration related to the device. + lora_config (Optional): The configuration related to serving multi-LoRA. + speculative_config (Optional): The configuration related to speculative + decoding. + executor_class: The model executor class for managing distributed + execution. + prompt_adapter_config (Optional): The configuration related to serving + prompt adapters. + log_stats: Whether to log statistics. + usage_context: Specified entry point, used for usage info collection. + """ + + def __init__( + self, + # NOTE(sgm): first two arguments are added for verl + model: Union[nn.Module, Dict], # model itself or its parameter dict + tokenizer: nn.Module, + # NOTE(sgm): vllm original arguments + model_config: ModelConfig, + cache_config: CacheConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + speculative_config: Optional[SpeculativeConfig], + decoding_config: Optional[DecodingConfig], + observability_config: Optional[ObservabilityConfig], + prompt_adapter_config: Optional[PromptAdapterConfig], + executor_class: Type[ExecutorBase], + log_stats: bool, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, + input_registry: InputRegistry = INPUT_REGISTRY, + use_cached_outputs: bool = False, + ) -> None: + logger.info( + "Initializing an LLM engine (v%s) with config: " + "model=%r, speculative_config=%r, tokenizer=%r, " + "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " + "override_neuron_config=%s, " + "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, " + "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " + "download_dir=%r, load_format=%s, tensor_parallel_size=%d, " + "pipeline_parallel_size=%d, " + "disable_custom_all_reduce=%s, quantization=%s, " + "enforce_eager=%s, kv_cache_dtype=%s, " + "quantization_param_path=%s, device_config=%s, " + "decoding_config=%r, observability_config=%r, " + "seed=%d, served_model_name=%s, use_v2_block_manager=%s, " + "num_scheduler_steps=%d, chunked_prefill_enabled=%s " + "multi_step_stream_outputs=%s, enable_prefix_caching=%s, " + "use_async_output_proc=%s, use_cached_outputs=%s, " + "mm_processor_kwargs=%s)", + VLLM_VERSION, + model_config.model, + speculative_config, + model_config.tokenizer, + model_config.skip_tokenizer_init, + model_config.tokenizer_mode, + model_config.revision, + model_config.override_neuron_config, + model_config.rope_scaling, + model_config.rope_theta, + model_config.tokenizer_revision, + model_config.trust_remote_code, + model_config.dtype, + model_config.max_model_len, + load_config.download_dir, + load_config.load_format, + parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size, + parallel_config.disable_custom_all_reduce, + model_config.quantization, + model_config.enforce_eager, + cache_config.cache_dtype, + model_config.quantization_param_path, + device_config.device, + decoding_config, + observability_config, + model_config.seed, + model_config.served_model_name, + scheduler_config.use_v2_block_manager, + scheduler_config.num_scheduler_steps, + scheduler_config.chunked_prefill_enabled, + scheduler_config.multi_step_stream_outputs, + cache_config.enable_prefix_caching, + model_config.use_async_output_proc, + use_cached_outputs, + model_config.mm_processor_kwargs, + ) + # TODO(woosuk): Print more configs in debug mode. + self.model_config = model_config + self.cache_config = cache_config + self.lora_config = lora_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.speculative_config = speculative_config + self.load_config = load_config + self.decoding_config = decoding_config or DecodingConfig() + self.prompt_adapter_config = prompt_adapter_config + self.observability_config = observability_config or ObservabilityConfig() + self.log_stats = log_stats + self.use_cached_outputs = use_cached_outputs + + if not self.model_config.skip_tokenizer_init: + self.tokenizer = self._init_tokenizer(tokenizer) + self.detokenizer = Detokenizer(self.tokenizer) + tokenizer_group = self.get_tokenizer_group() + else: + self.tokenizer = None + self.detokenizer = None + tokenizer_group = None + + # Ensure that the function doesn't contain a reference to self, + # to avoid engine GC issues + def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: + assert tokenizer_group, "tokenizer_group cannot be None, " "make sure skip_tokenizer_init is False" + return tokenizer_group.get_lora_tokenizer(sequence.lora_request) + + self.seq_counter = Counter() + self.generation_config_fields = _load_generation_config_dict(model_config) + + self.input_preprocessor = InputPreprocessor(model_config, self.tokenizer) + + self.input_registry = input_registry + self.input_processor = input_registry.create_input_processor(model_config) + + self.model_executor = executor_class( + model=model, # add for spmd_gpu_executor + model_config=model_config, + cache_config=cache_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + device_config=device_config, + lora_config=lora_config, + speculative_config=speculative_config, + load_config=load_config, + prompt_adapter_config=prompt_adapter_config, + observability_config=self.observability_config, + ) + + if not self.model_config.embedding_mode: + self._initialize_kv_caches() + + # If usage stat is enabled, collect relevant info. + if is_usage_stats_enabled(): + from vllm.model_executor.model_loader import get_architecture_class_name + + usage_message.report_usage( + get_architecture_class_name(model_config), + usage_context, + extra_kvs={ + # Common configuration + "dtype": str(model_config.dtype), + "tensor_parallel_size": parallel_config.tensor_parallel_size, + "block_size": cache_config.block_size, + "gpu_memory_utilization": cache_config.gpu_memory_utilization, + # Quantization + "quantization": model_config.quantization, + "kv_cache_dtype": str(cache_config.cache_dtype), + # Feature flags + "enable_lora": bool(lora_config), + "enable_prompt_adapter": bool(prompt_adapter_config), + "enable_prefix_caching": cache_config.enable_prefix_caching, + "enforce_eager": model_config.enforce_eager, + "disable_custom_all_reduce": parallel_config.disable_custom_all_reduce, + }, + ) + + if self.tokenizer: + # Ping the tokenizer to ensure liveness if it runs in a + # different process. + self.tokenizer.ping() + + self.cached_scheduler_outputs = [ + SchedulerOutputState() for _ in range(self.parallel_config.pipeline_parallel_size) + ] + + self.scheduler_contexts = [ + SchedulerContext(multi_step_stream_outputs=self.scheduler_config.multi_step_stream_outputs) + for _ in range(self.parallel_config.pipeline_parallel_size) + ] + + if model_config.use_async_output_proc: + process_model_outputs = weak_bind(self._process_model_outputs) + + self.async_callbacks = [ + partial(process_model_outputs, ctx=self.scheduler_contexts[v_id]) + for v_id in range(self.parallel_config.pipeline_parallel_size) + ] + else: + self.async_callbacks = [] + + # Currently used by AsyncLLMEngine to ensure quick append + # of request outputs to asyncio queues + self.process_request_outputs_callback: Optional[Callable] = None + + # Create the scheduler. + # NOTE: the cache_config here have been updated with the numbers of + # GPU and CPU blocks, which are profiled in the distributed executor. + self.scheduler = [ + Scheduler( + scheduler_config, + cache_config, + lora_config, + parallel_config.pipeline_parallel_size, + self.async_callbacks[v_id] if model_config.use_async_output_proc else None, + ) for v_id in range(parallel_config.pipeline_parallel_size) + ] + + # Metric Logging. + if self.log_stats: + if stat_loggers is not None: + self.stat_loggers = stat_loggers + else: + # Lazy import for prometheus multiprocessing. + # We need to set PROMETHEUS_MULTIPROC_DIR environment variable + # before prometheus_client is imported. + # See https://prometheus.github.io/client_python/multiprocess/ + from vllm.engine.metrics import LoggingStatLogger, PrometheusStatLogger + + self.stat_loggers = { + "logging": + LoggingStatLogger(local_interval=_LOCAL_LOGGING_INTERVAL_SEC), + "prometheus": + PrometheusStatLogger( + local_interval=_LOCAL_LOGGING_INTERVAL_SEC, + labels=dict(model_name=model_config.served_model_name), + max_model_len=self.model_config.max_model_len, + ), + } + self.stat_loggers["prometheus"].info("cache_config", self.cache_config) + + self.tracer = None + if self.observability_config.otlp_traces_endpoint: + self.tracer = init_tracer("vllm.llm_engine", self.observability_config.otlp_traces_endpoint) + + # Create sequence output processor, e.g. for beam search or + # speculative decoding. + self.output_processor = SequenceGroupOutputProcessor.create_output_processor( + self.scheduler_config, + self.detokenizer, + self.scheduler, + self.seq_counter, + get_tokenizer_for_seq, + stop_checker=StopChecker( + self.scheduler_config.max_model_len, + get_tokenizer_for_seq, + ), + ) + + # TODO(sgm): add for verl but we may not tokenizer in Rollout + def _init_tokenizer(self, tokenizer, **tokenizer_init_kwargs): + init_kwargs = dict(enable_lora=bool(self.lora_config), + max_num_seqs=self.scheduler_config.max_num_seqs, + max_input_length=None) + init_kwargs.update(tokenizer_init_kwargs) + return TokenizerGroup(tokenizer, **init_kwargs) + + def init_cache_engine(self): + # TODO: check whether we should rebuild the CUDAGraph every iter when offload/load KVCache + # Re-capture CUDAGraph would be time-consuming + self.model_executor.init_cache_engine() + + def free_cache_engine(self): + self.model_executor.free_cache_engine() + + # NOTE(sgm): currently, we only support GPU executor + # The GPUExecutor remove the Ray dependency + @classmethod + def _get_executor_cls(cls, engine_config: EngineConfig) -> Type[ExecutorBase]: + distributed_executor_backend = engine_config.parallel_config.distributed_executor_backend + # Initialize the cluster and specify the executor class.] + assert (engine_config.device_config.device_type == "cuda" + ), "Currently, the vllm in verl only support running on GPU" + + # print('Waiting for debugger'); import os,debugpy; debugpy.listen(('localhost', 5678 + int(os.getenv('RANK', '0')))); debugpy.wait_for_client() + if engine_config.parallel_config.world_size == 1: + engine_config.load_config.load_format = "dummy_hf" + + from .spmd_gpu_executor import SPMDGPUExecutor + + executor_class = SPMDGPUExecutor + + return executor_class + + @classmethod + def from_engine_args( + cls, + model, + tokenizer, + engine_args: EngineArgs, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, + ) -> "LLMEngine": + """Creates an LLM engine from the engine arguments.""" + # Create the engine configs. + engine_config = engine_args.create_engine_config() + executor_class = cls._get_executor_cls(engine_config) + # Initialize the cluster and specify the executor class. + assert (engine_config.device_config.device_type == "cuda" + ), "Currently, the vllm in verl only support running on GPU" + + from .spmd_gpu_executor import SPMDGPUExecutor + + executor_class = SPMDGPUExecutor + + # Create the LLM engine. + engine = cls( + model, + tokenizer, + **engine_config.to_dict(), + executor_class=executor_class, + log_stats=not engine_args.disable_log_stats, + usage_context=usage_context, + stat_loggers=stat_loggers, + ) + return engine + + def sync_model_weights(self, actor_weights: Dict[str, torch.Tensor], load_format: str) -> None: + self.model_executor.sync_model_weights(actor_weights=actor_weights, load_format=load_format) + + def offload_model_weights(self) -> None: + self.model_executor.offload_model_weights() diff --git a/verl/third_party/vllm/vllm_v_0_6_3/megatron_weight_loaders.py b/verl/third_party/vllm/vllm_v_0_6_3/megatron_weight_loaders.py new file mode 100644 index 0000000..7fd6c0e --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_3/megatron_weight_loaders.py @@ -0,0 +1,308 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/model_loader + +from typing import Dict + +import torch +import torch.nn as nn +from vllm.model_executor.layers.linear import * +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead, VocabParallelEmbedding +from vllm.model_executor.models import ModelRegistry + + +# NOTE(shengguangming): replace the origin weight loader function in the class +def parallel_weight_loader(self, param: torch.Tensor, loaded_weight: torch.Tensor) -> None: + """Parallel Linear weight loader.""" + assert (param.size() == loaded_weight.size( + )), "the parameter size is not align with the loaded weight size, param size: {}, loaded_weight size: {}".format( + param.size(), loaded_weight.size()) + assert (param.data.dtype == loaded_weight.data.dtype + ), "if we want to shared weights, the data type should also be the same" + + param.data = loaded_weight.data + + +def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: + """Default weight loader.""" + assert param.size() == loaded_weight.size() + assert (param.data.dtype == loaded_weight.data.dtype + ), "if we want to shared weights, the data type should also be the same" + + param.data = loaded_weight.data + + +def gpt2_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + params_dict = dict(vllm_model.named_parameters(remove_duplicate=False)) + for name, loaded_weight in actor_weights.items(): + if "lm_head.weight" in name: + # GPT-2 ties the weights of the embedding layer and the final + # linear layer. + continue + if ".attn.bias" in name or ".attn.masked_bias" in name: + # Skip attention mask. + # NOTE: "c_attn.bias" should not be skipped. + continue + if not name.startswith("transformer."): + name = "transformer." + name + param = params_dict[name] + # The HF's GPT-2 implementation uses Conv1D instead of Linear. + # Because of this, we need to transpose the weights. + # Note(zhuohan): the logic below might break quantized models. + for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]: + if conv1d_weight_name not in name: + continue + if not name.endswith(".weight"): + continue + # TODO: check megatron + loaded_weight = loaded_weight.t() + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +def llama_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + # NOTE(shengguangming): the megatron llama may have this prefix + params_dict = dict(vllm_model.named_parameters()) + for name, loaded_weight in actor_weights.items(): + if "rotary_emb.inv_freq" in name: + continue + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +def llama_megatron_core_te_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + params_mapping = [ + # (megatron core gpt model name, vllm model name) + ("embedding.word_embeddings", "model.embed_tokens"), + ("self_attention.linear_qkv.layer_norm_weight", "input_layernorm.weight"), + ("self_attention.linear_qkv.layer_norm_bias", "input_layernorm.bias"), + ("self_attention.linear_qkv", "self_attn.qkv_proj"), + ("self_attention.linear_qkv", "self_attn.qkv_proj"), + ("self_attention.linear_proj", "self_attn.o_proj"), + ("pre_mlp_layernorm", "post_attention_layernorm"), + ("mlp.linear_fc1.layer_norm_weight", "post_attention_layernorm.weight"), + ("mlp.linear_fc1.layer_norm_bias", "post_attention_layernorm.bias"), + ("mlp.linear_fc1", "mlp.gate_up_proj"), + ("mlp.linear_fc2", "mlp.down_proj"), + ("decoder.final_layernorm", "model.norm"), + ("output_layer", "lm_head"), + ] + # NOTE(shengguangming): the megatron llama may have this prefix + params_dict = dict(vllm_model.named_parameters()) + for name, loaded_weight in actor_weights.items(): + name = _replace_name(name, params_mapping) + if name.endswith(".bias") and name not in params_dict: + continue + if "rotary_emb.inv_freq" in name: + continue + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +def llama_megatron_core_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + params_mapping = [ + # (megatron core gpt model name, vllm model name) + ("embedding.word_embeddings", "model.embed_tokens"), + ("self_attention.linear_qkv", "self_attn.qkv_proj"), + ("self_attention.linear_proj", "self_attn.o_proj"), + ( + "input_layernorm", + "input_layernorm", + ), + ("pre_mlp_layernorm", "post_attention_layernorm"), + ("mlp.linear_fc1", "mlp.gate_up_proj"), + ("mlp.linear_fc2", "mlp.down_proj"), + ("decoder.final_layernorm", "model.norm"), + ("output_layer", "lm_head"), + ] + # NOTE(shengguangming): the megatron llama may have this prefix + params_dict = dict(vllm_model.named_parameters()) + for name, loaded_weight in actor_weights.items(): + name = _replace_name(name, params_mapping) + if name.endswith(".bias") and name not in params_dict: + continue + if "rotary_emb.inv_freq" in name: + continue + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +def _replace_name(megatron_name, name_mapping): + for m_name, v_name in name_mapping: + if m_name not in megatron_name: + continue + if "layers" in megatron_name: # deal with decoder layers + megatron_name = megatron_name.replace("decoder", "model") + megatron_name_list = megatron_name.split(".") + if "layer_norm_weight" in megatron_name_list or "layer_norm_bias" in megatron_name_list: + param_name_list = megatron_name_list[:3] + param_name_list.append(v_name) + param_name = ".".join(param_name_list) + else: + param_name_list = megatron_name_list[:3] + weight_or_bias = megatron_name_list[-1] + param_name_list.append(v_name) + param_name_list.append(weight_or_bias) + param_name = ".".join(param_name_list) + return param_name + else: + param_name = megatron_name.replace(m_name, v_name) + return param_name + + +def llama_megatron_core_te_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + params_mapping = [ + # (megatron core gpt model name, vllm model name) + ("embedding.word_embeddings", "model.embed_tokens"), + ("self_attention.linear_qkv.layer_norm_weight", "input_layernorm.weight"), + ("self_attention.linear_qkv.layer_norm_bias", "input_layernorm.bias"), + ("self_attention.linear_qkv", "self_attn.qkv_proj"), + ("self_attention.linear_qkv", "self_attn.qkv_proj"), + ("self_attention.linear_proj", "self_attn.o_proj"), + ("pre_mlp_layernorm", "post_attention_layernorm"), + ("mlp.linear_fc1.layer_norm_weight", "post_attention_layernorm.weight"), + ("mlp.linear_fc1.layer_norm_bias", "post_attention_layernorm.bias"), + ("mlp.linear_fc1", "mlp.gate_up_proj"), + ("mlp.linear_fc2", "mlp.down_proj"), + ("decoder.final_layernorm", "model.norm"), + ("output_layer", "lm_head"), + ] + # NOTE(shengguangming): the megatron llama may have this prefix + params_dict = dict(vllm_model.named_parameters()) + for name, loaded_weight in actor_weights.items(): + name = _replace_name(name, params_mapping) + if name.endswith(".bias") and name not in params_dict: + continue + if "rotary_emb.inv_freq" in name: + continue + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +def llama_megatron_core_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + params_mapping = [ + # (megatron core gpt model name, vllm model name) + ("embedding.word_embeddings", "model.embed_tokens"), + ("self_attention.linear_qkv", "self_attn.qkv_proj"), + ("self_attention.linear_proj", "self_attn.o_proj"), + ( + "input_layernorm", + "input_layernorm", + ), + ("pre_mlp_layernorm", "post_attention_layernorm"), + ("mlp.linear_fc1", "mlp.gate_up_proj"), + ("mlp.linear_fc2", "mlp.down_proj"), + ("decoder.final_layernorm", "model.norm"), + ("output_layer", "lm_head"), + ] + # NOTE(shengguangming): the megatron llama may have this prefix + params_dict = dict(vllm_model.named_parameters()) + for name, loaded_weight in actor_weights.items(): + name = _replace_name(name, params_mapping) + if name.endswith(".bias") and name not in params_dict: + continue + if "rotary_emb.inv_freq" in name: + continue + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +def _replace_name(megatron_name, name_mapping): + for m_name, v_name in name_mapping: + if m_name not in megatron_name: + continue + if "layers" in megatron_name: # deal with decoder layers + megatron_name = megatron_name.replace("decoder", "model") + megatron_name_list = megatron_name.split(".") + if "layer_norm_weight" in megatron_name_list or "layer_norm_bias" in megatron_name_list: + param_name_list = megatron_name_list[:3] + param_name_list.append(v_name) + param_name = ".".join(param_name_list) + else: + param_name_list = megatron_name_list[:3] + weight_or_bias = megatron_name_list[-1] + param_name_list.append(v_name) + param_name_list.append(weight_or_bias) + param_name = ".".join(param_name_list) + return param_name + else: + param_name = megatron_name.replace(m_name, v_name) + return param_name + + +def mistral_megatron_weight_loader(actor_weights: Dict, vllm_model: nn.Module) -> nn.Module: + # TODO: need to implement a general way to deal with prefix + params_dict = dict(vllm_model.named_parameters()) + for name, loaded_weight in actor_weights.items(): + if "rotary_emb.inv_freq" in name: + continue + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +__LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY__ = { + ColumnParallelLinear: parallel_weight_loader, + MergedColumnParallelLinear: parallel_weight_loader, + QKVParallelLinear: parallel_weight_loader, + RowParallelLinear: parallel_weight_loader, + VocabParallelEmbedding: parallel_weight_loader, + ParallelLMHead: parallel_weight_loader, + # "ScaledActivation.weight_loader": ScaledActivation, # TODO(shengguangming): latest commit in vllm fix awq for this function and add load_weights + # "default_weight_loader": default_weight_loader +} + +# for layer_class, weight_loader in __LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY__.items(): +# # setattr(layer_class, 'megatron_weight_loader', weight_loader) +# layer_class.weight_loader = weight_loader + +__MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__ = { + "GPT2LMHeadModel": gpt2_weight_loader, + "LlamaForCausalLM": llama_megatron_weight_loader, # use te backend for open-source megatron + "LLaMAForCausalLM": llama_megatron_weight_loader, + "MistralForCausalLM": mistral_megatron_weight_loader, +} + + +# the actor model is .state_dict() +# Load megatron weights +def load_megatron_weights(actor_weights: Dict, vllm_model: nn.Module): + weight_loader = _get_model_weight_loader(vllm_model.__class__.__name__) + weight_loader(actor_weights, vllm_model) + # NOTE(sgm) to reduce peak memory usage, we offload vllm model to cpu + # after init, and we need this after sync model weights for in first iter. + vllm_model = vllm_model.cuda() + + +def _get_model_weight_loader(arch: str): + if arch in __MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__: + return __MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY__[arch] + raise ValueError(f"Model architectures {arch} are not supported for now. " + f"Supported architectures: {ModelRegistry.get_supported_archs()}") + + +def update_megatron_weight_loader(): + for layer_class, weight_loader in __LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY__.items(): + layer_class.weight_loader = weight_loader diff --git a/verl/third_party/vllm/vllm_v_0_6_3/model_loader.py b/verl/third_party/vllm/vllm_v_0_6_3/model_loader.py new file mode 100644 index 0000000..2f32a91 --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_3/model_loader.py @@ -0,0 +1,332 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models +"""Utilities for selecting and loading models.""" +from typing import Dict, Optional, Union + +import torch +import torch.nn as nn +from transformers import PreTrainedModel +from vllm.config import CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig +from vllm.distributed.communication_op import tensor_model_parallel_all_gather +from vllm.model_executor.model_loader import BaseModelLoader +from vllm.model_executor.model_loader.loader import _initialize_model +from vllm.model_executor.model_loader.utils import set_default_torch_dtype + +from .config import LoadConfig, LoadFormat, ModelConfig +from .dtensor_weight_loaders import load_dtensor_weights, update_dtensor_weight_loader +from .hf_weight_loader import update_hf_weight_loader +from .megatron_weight_loaders import load_megatron_weights, update_megatron_weight_loader + + +def get_model( + actor_model: Union[PreTrainedModel, Dict], + model_config: ModelConfig, + load_config: LoadConfig, + device_config: DeviceConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + lora_config: Optional[LoRAConfig], + cache_config: CacheConfig = None, +) -> nn.Module: + loader = get_model_loader(load_config) + if load_config.load_format.startswith("dummy"): + return loader.load_model( + model_config=model_config, + device_config=device_config, + lora_config=lora_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + cache_config=cache_config, + ) + else: + return loader.load_model( + actor_model=actor_model, + model_config=model_config, + device_config=device_config, + lora_config=lora_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + cache_config=cache_config, + ) + + +def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: + """Get a model loader based on the load format.""" + + if isinstance(load_config.load_format, type): + return load_config.load_format(load_config) + + if load_config.load_format == LoadFormat.AUTO: + update_megatron_weight_loader() + return MegatronLoader(load_config) + + # NOTE(sgm): change the weight_loader function in runtime + if load_config.load_format == LoadFormat.MEGATRON: + update_megatron_weight_loader() + return MegatronLoader(load_config) + + if load_config.load_format == LoadFormat.HF: + update_hf_weight_loader() + return HFLoader(load_config) + + if load_config.load_format == LoadFormat.DTENSOR: + update_dtensor_weight_loader() + return DTensorLoader(load_config) + + if load_config.load_format == LoadFormat.DUMMY_HF: + update_hf_weight_loader() + return DummyModelLoader(load_config) + + if load_config.load_format == LoadFormat.DUMMY_MEGATRON: + update_megatron_weight_loader() + return DummyModelLoader(load_config) + + if load_config.load_format == LoadFormat.DUMMY_DTENSOR: + update_dtensor_weight_loader() + return DummyModelLoader(load_config) + + raise ValueError("load format not supported in verl: {}, only support {} and {}".format( + load_config.load_format, LoadFormat.MEGATRON, LoadFormat.HF)) + + +class DummyModelLoader(BaseModelLoader): + """Model loader that will set model weights to random values.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for " + f"load format {load_config.load_format}") + + def download_model(self, model_config: ModelConfig) -> None: + pass + + def load_model( + self, + *, + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + ) -> nn.Module: + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(model_config, self.load_config, lora_config, cache_config, scheduler_config) + # NOTE(woosuk): For accurate performance evaluation, we assign + # random values to the weights. + # initialize_dummy_weights(model) + return model.eval() + + +class MegatronLoader(BaseModelLoader): + """Model loader that can load the model weights from partitioned megatron model.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for " + f"load format {load_config.load_format}") + + def download_model(self, model_config: ModelConfig) -> None: + pass # Nothing to download + + def _get_weights_iterator(actor_model: Union[PreTrainedModel, Dict]): + # NOTE(shengguangming) Load the weights from the actor model + pass + # if isinstance(actor_model, nn.Module): + # load_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), vllm_model=model) + # else: + # load_weights(actor_weights=actor_model, vllm_model=model) + # return actor_model + + def load_model( + self, + actor_model: Union[PreTrainedModel, Dict], + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + ) -> nn.Module: + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(model_config, self.load_config, lora_config, cache_config, scheduler_config) + + # TODO(sgm): This is a hack, we need to register the load_weight() func for each model in vllm + if isinstance(actor_model, nn.Module): + load_megatron_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), + vllm_model=model) + else: + load_megatron_weights(actor_weights=actor_model, vllm_model=model) + + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + quant_method.process_weights_after_loading(module) + # FIXME: Remove this after Mixtral is updated + # to use quant_method. + if hasattr(module, "process_weights_after_loading"): + module.process_weights_after_loading() + # NOTE(sgm) Some weights are point to gpu, but still need this. + model = model.cuda() # NOTE (zhangchi.usc1992) We need this for vllm to profile memory usage + return model.eval() + + +class HFLoader(BaseModelLoader): + """Model loader that can load the model weights from model's full params.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for " + f"load format {load_config.load_format}") + + def _get_weights_iterator(self, actor_model: Union[PreTrainedModel, Dict]): + if isinstance(actor_model, Dict): + return actor_model.items() + elif isinstance(actor_model, nn.Module): + return dict(actor_model.named_parameters()).items() + else: + raise ValueError(f"actor model should be Dict or nn.Module, but get {type(actor_model)}") + + def load_model( + self, + actor_model: Union[PreTrainedModel, Dict], + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + ) -> nn.Module: + with set_default_torch_dtype(model_config.dtype): + # with torch.device(device_config.device): + # NOTE(sgm): init the model in cpu + model = _initialize_model(model_config, self.load_config, lora_config, cache_config, scheduler_config) + model.load_weights(self._get_weights_iterator(actor_model)) + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + quant_method.process_weights_after_loading(module) + # FIXME: Remove this after Mixtral is updated + # to use quant_method. + if hasattr(module, "process_weights_after_loading"): + module.process_weights_after_loading() + # NOTE(sgm) Some weights are point to gpu, but still need this. + model = model.cuda() # NOTE (zhangchi.usc1992) We need this for vllm to profile memory usage + return model.eval() + + +class DTensorLoader(BaseModelLoader): + """Model loader that can load the model weights from partitioned megatron model.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for " + f"load format {load_config.load_format}") + + def _get_weights_iterator(actor_model: Union[PreTrainedModel, Dict]): + # NOTE(shengguangming) Load the weights from the actor model + pass + # if isinstance(actor_model, nn.Module): + # load_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), vllm_model=model) + # else: + # load_weights(actor_weights=actor_model, vllm_model=model) + # return actor_model + + def load_model( + self, + actor_model: Union[PreTrainedModel, Dict], + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + ) -> nn.Module: + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(model_config, self.load_config, lora_config, cache_config, scheduler_config) + + # TODO(sgm): This is a hack, we need to register the load_weight() func for each model in vllm + if isinstance(actor_model, nn.Module): + load_dtensor_weights(actor_weights=dict(actor_model.named_parameters(remove_duplicate=False)), + vllm_model=model) + else: + load_dtensor_weights(actor_weights=actor_model, vllm_model=model) + + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + quant_method.process_weights_after_loading(module) + # FIXME: Remove this after Mixtral is updated + # to use quant_method. + if hasattr(module, "process_weights_after_loading"): + module.process_weights_after_loading() + # NOTE(sgm) Some weights are point to gpu, but still need this. + model = model.cuda() # NOTE (zhangchi.usc1992) We need this for vllm to profile memory usage + return model.eval() + + +# FIXME(sgm): hack the _get_logits function in vllm v0.4.2 +# as they use ray, the _get_logits result will only need to return to the driver node, +# therefore gather is enough. However, we use SPMD instead of a central scheduler, +# all_gather is required (aligned with v0.2.6) +def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor, + embedding_bias: Optional[torch.Tensor]) -> torch.Tensor: + # Get the logits for the next tokens. + logits = torch.matmul(hidden_states, embedding.t()) + if embedding_bias is not None: + logits += embedding_bias + logits = tensor_model_parallel_all_gather(logits) + # Remove paddings in vocab (if any). + if logits is not None: + logits = logits[:, :self.org_vocab_size] + return logits + + +from vllm.model_executor.layers.logits_processor import LogitsProcessor + + +def logitsprocessor_init( + self, + vocab_size: int, + org_vocab_size: Optional[int] = None, + scale: float = 1.0, + logits_as_input: bool = False, + soft_cap: Optional[float] = None, +) -> None: + """ + Args: + scale: A scaling factor to apply to the logits. + """ + super(LogitsProcessor, self).__init__() + self.scale = scale + self.vocab_size = vocab_size + # Whether the input is logits (default is hidden states). + self.logits_as_input = logits_as_input + # original vocabulary size (without LoRA). + self.org_vocab_size = org_vocab_size or vocab_size + # Soft cap the logits. Used in Gemma 2. + self.soft_cap = soft_cap + # Whether to use gather or all-gather to gather the logits. + self.use_gather = False + + +LogitsProcessor.__init__ = logitsprocessor_init # use all_gather diff --git a/verl/third_party/vllm/vllm_v_0_6_3/model_runner.py b/verl/third_party/vllm/vllm_v_0_6_3/model_runner.py new file mode 100644 index 0000000..b0cceff --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_3/model_runner.py @@ -0,0 +1,182 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/worker/model_runner.py + +import warnings +from enum import IntEnum +from typing import Dict, Optional, Union + +import torch +import torch.nn as nn +import vllm.envs as envs +from vllm.compilation.levels import CompilationLevel +from vllm.config import ( + CacheConfig, + DeviceConfig, + LoadConfig, + LoRAConfig, + ModelConfig, + ObservabilityConfig, + ParallelConfig, + PromptAdapterConfig, + SchedulerConfig, +) +from vllm.inputs import INPUT_REGISTRY, InputRegistry +from vllm.logger import init_logger +from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager +from vllm.model_executor.models.interfaces import supports_lora +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.prompt_adapter.worker_manager import LRUCacheWorkerPromptAdapterManager +from vllm.utils import DeviceMemoryProfiler, is_hip, supports_dynamo +from vllm.worker.model_runner import ModelRunner + +from .config import LoadConfig, ModelConfig +from .model_loader import get_model + +logger = init_logger(__name__) + + +# How batches are constructed. +class BatchType(IntEnum): + # Every batch is prefill. + PREFILL = 0 + # Every batch is decode. + DECODE = 1 + # Batch is a mixture of prefill and decode. + MIXED = 2 + + +class ModelRunner(ModelRunner): + + def __init__( + self, + model: Union[nn.Module, Dict], # [verl] model itself or its parameter dict + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, + return_hidden_states: bool = False, + observability_config: Optional[ObservabilityConfig] = None, + input_registry: InputRegistry = INPUT_REGISTRY, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + ): + + super().__init__( + model_config, + parallel_config, + scheduler_config, + device_config, + cache_config, + load_config, + lora_config, + kv_cache_dtype, + is_driver_worker=True, # a hack + prompt_adapter_config=prompt_adapter_config, + return_hidden_states=return_hidden_states, + observability_config=observability_config, + input_registry=input_registry, + mm_registry=mm_registry, + ) + + # NOTE(sgm): add for verl + self.model = model # this will be replaced by get_model() + + def load_model(self) -> None: + logger.info("Starting to load model %s...", self.model_config.model) + with DeviceMemoryProfiler() as m: + self.model = get_model( + self.model, + model_config=self.model_config, + device_config=self.device_config, + load_config=self.load_config, + lora_config=self.lora_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + cache_config=self.cache_config, + ) + + self.model_memory_usage = m.consumed_memory + logger.info("Loading model weights took %.4f GB", self.model_memory_usage / float(2**30)) + + if self.lora_config: + assert supports_lora(self.model), f"{self.model.__class__.__name__} does not support LoRA yet." + + if supports_multimodal(self.model): + logger.warning("Regarding multimodal models, vLLM currently " + "only supports adding LoRA to language model.") + # It's necessary to distinguish between the max_position_embeddings + # of VLMs and LLMs. + if hasattr(self.model.config, "max_position_embeddings"): + max_pos_embeddings = self.model.config.max_position_embeddings + else: + max_pos_embeddings = self.model.config.text_config.max_position_embeddings + + self.lora_manager = LRUCacheWorkerLoRAManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, + self.vocab_size, + self.lora_config, + self.device, + self.model.embedding_modules, + self.model.embedding_padding_modules, + max_position_embeddings=max_pos_embeddings, + ) + self.model = self.lora_manager.create_lora_manager(self.model) + + if self.prompt_adapter_config: + self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, + self.device, + self.prompt_adapter_config, + ) + self.model = self.prompt_adapter_manager.create_prompt_adapter_manager(self.model) + + if self.kv_cache_dtype == "fp8" and is_hip(): + # Currently only ROCm accepts kv-cache scaling factors + # via quantization_param_path and this will be deprecated + # in the future. + if self.model_config.quantization_param_path is not None: + if callable(getattr(self.model, "load_kv_cache_scales", None)): + warnings.warn( + "Loading kv cache scaling factor from JSON is " + "deprecated and will be removed. Please include " + "kv cache scaling factors in the model checkpoint.", + FutureWarning, + stacklevel=2, + ) + self.model.load_kv_cache_scales(self.model_config.quantization_param_path) + logger.info("Loaded KV cache scaling factors from %s", self.model_config.quantization_param_path) + else: + raise RuntimeError( + "Using FP8 KV cache and scaling factors provided but " + "model %s does not support loading scaling factors.", + self.model.__class__, + ) + else: + logger.warning("Using FP8 KV cache but no scaling factors " + "provided. Defaulting to scaling factors of 1.0. " + "This may lead to less accurate results!") + + if envs.VLLM_TORCH_COMPILE_LEVEL == CompilationLevel.DYNAMO_AS_IS and supports_dynamo(): + from vllm.plugins import get_torch_compile_backend + + backend = get_torch_compile_backend() or "eager" + self.model = torch.compile(self.model, fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, backend=backend) diff --git a/verl/third_party/vllm/vllm_v_0_6_3/parallel_state.py b/verl/third_party/vllm/vllm_v_0_6_3/parallel_state.py new file mode 100644 index 0000000..0150c1c --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_3/parallel_state.py @@ -0,0 +1,312 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Adapted from +# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +"""Model and data parallel groups.""" +import os +from typing import Optional + +import torch +import torch.distributed +import vllm.distributed.parallel_state as ps +from vllm.distributed.parallel_state import ( + get_pp_group, + get_world_group, + init_distributed_environment, + init_model_parallel_group, +) +from vllm.logger import init_logger + +logger = init_logger(__name__) +""" +This version is strongly tied with Megatron to implement HybridEngine and weight sharing between vllm and Megatron. +- We assume the Megatron tp+dp+pp world is already established before calling this function. + +""" + +# Device mesh for using DTensor +_DEVICE_MESH = None + +# Tensor model parallel group that the current rank belongs to. +_TP = None +# Pipeline model parallel group that the current rank belongs to. +_PP = None + + +# This method is for initializing the ParallelGroup when using HybridEngine +def initialize_parallel_state( + distributed_init_method: str = "env://", + backend: str = "nccl", + tensor_model_parallel_size: int = 1, + num_tp_per_train_tp: int = 1, + pipeline_model_parallel_size: int = 1, +): + # torch.distributed.all_reduce does not free the input tensor until + # the synchronization point. This causes the memory usage to grow + # as the number of all_reduce calls increases. This env var disables + # this behavior. + # Related issue: + # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + + # NOTE(sgm): Modify for verl, Env vars will be set by TORCHRUN. + rank = int(os.getenv("RANK", "-1")) + local_rank = int(os.getenv("LOCAL_RANK", "0")) + + # Use the world_size set by TORCHRUN + world_size = int(os.getenv("WORLD_SIZE", "-1")) + assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN" + init_distributed_environment(world_size, rank, distributed_init_method, local_rank, backend) + if torch.distributed.get_world_size() > 1: + # NOTE: build a sepearate inference group with infer tp & micro dp + initialize_model_parallel_for_vllm( + tensor_model_parallel_size=tensor_model_parallel_size, + num_tensor_model_parallel_groups_per_train_tp=num_tp_per_train_tp, + ) + else: + initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend) + + +def ensure_model_parallel_initialized( + tensor_model_parallel_size: int, + pipeline_model_parallel_size: int = 1, + backend: Optional[str] = None, +) -> None: + """Helper to initialize model parallel groups if they are not initialized, + or ensure tensor-parallel and pipeline-parallel sizes are equal to expected + values if the model parallel groups are initialized. + """ + # get the backend of _DEVICE_WORLD_GROUP + backend = backend or torch.distributed.get_backend(get_world_group().device_group) + if not model_parallel_is_initialized(): + initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, backend) + return + + assert get_tensor_model_parallel_world_size() == tensor_model_parallel_size, ( + "tensor parallel group already initialized, but of unexpected size: " + f"{get_tensor_model_parallel_world_size()=} vs. " + f"{tensor_model_parallel_size=}") + pp_world_size = get_pp_group().world_size + assert pp_world_size == pipeline_model_parallel_size, ( + "pipeline parallel group already initialized, but of unexpected size: " + f"{pp_world_size=} vs. " + f"{pipeline_model_parallel_size=}") + + +# TODO(sgm): deviate from the v0.5.4, not pp now +def model_parallel_is_initialized(): + """Check if tensor and pipeline parallel groups are initialized.""" + return ps._TP is not None + # and _PIPELINE_MODEL_PARALLEL_GROUP is not None) + + +def initialize_model_parallel_for_vllm( + tensor_model_parallel_size: int, + num_tensor_model_parallel_groups_per_train_tp: int = 1, + pipeline_model_parallel_size: int = 1, +) -> None: + pass + + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + + assert isinstance(tensor_model_parallel_size, int) + + # assert num_tensor_model_parallel_groups_per_train_tp == 1 and not different_tp_group + # assert num_tensor_model_parallel_groups_per_train_tp > 1 and different_tp_group + + # Build the tensor model-parallel groups. + assert ps._TP is None, "tensor model parallel group is already initialized" + + global _TP + + world_size: int = torch.distributed.get_world_size() + + rank = torch.distributed.get_rank() + + backend = torch.distributed.get_backend() + + num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size + + if num_tensor_model_parallel_groups_per_train_tp == 1: + # if tensor_model_parallel_size == train_tensor_parallel_size: + # using the same tp group as Megatron/vllm + assert _TP is None, "tensor model parallel group is already initialized" + group_ranks = [] + for i in range(num_tensor_model_parallel_groups): + ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size) + group_ranks.append(ranks) + _TP = init_model_parallel_group( + group_ranks=group_ranks, + local_rank=get_world_group().local_rank, + backend=backend, + use_custom_allreduce=False, # TODO: check why True is not work in Ray trainer + use_message_queue_broadcaster=True, + ) + ps._TP = _TP + # _MICRO_DATA_PARALLEL_GROUP is move to hybrid engine + else: + # initialize a micro_dp group and a tp group + # assume training tp=4, infer tp=2, then, weight is partitioned as + # [1], [2], [3], [4] for training and [1,2], [1,2], [3,4], [3,4] for inference + + # Build the inference tp groups + # train_tp = train_tensor_parallel_size + train_tp = num_tensor_model_parallel_groups_per_train_tp * tensor_model_parallel_size + # num_tensor_model_parallel_groups_per_train_tp = train_tp // tensor_model_parallel_size + assert _TP is None, "tensor model parallel group is already initialized" + group_ranks = [] + for i in range(num_tensor_model_parallel_groups // num_tensor_model_parallel_groups_per_train_tp): + start = train_tp * i + end = train_tp * (i + 1) + for j in range(num_tensor_model_parallel_groups_per_train_tp): + ranks = list(range(start, end, num_tensor_model_parallel_groups_per_train_tp)) + for i in range(len(ranks)): + ranks[i] += j + group_ranks.append(ranks) + _TP = init_model_parallel_group( + group_ranks=group_ranks, + local_rank=get_world_group().local_rank, + backend=backend, + use_custom_allreduce=False, # TODO: check why True is not work in Ray trainer + use_message_queue_broadcaster=True, + ) + ps._TP = _TP + + # Build the pipeline model-parallel groups. + # global _PIPELINE_MODEL_PARALLEL_GROUP + # global _PIPELINE_GLOBAL_RANKS + # assert ps._PIPELINE_MODEL_PARALLEL_GROUP is None, ("pipeline model parallel group is already initialized") + + # ps._PIPELINE_MODEL_PARALLEL_GROUP = mpu.get_pipeline_model_parallel_group() + # ps._PIPELINE_GLOBAL_RANKS = mpu.get_pipeline_model_parallel_ranks() + + # TODO: init using device mesh (not support hybrid engine now) + # Build the pipeline model-parallel groups. + num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size + global _PP + assert _PP is None, "pipeline model parallel group is already initialized" + group_ranks = [] + for i in range(num_pipeline_model_parallel_groups): + ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) + group_ranks.append(ranks) + # pipeline parallel does not need custom allreduce + _PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False) + ps._PP = _PP # for verl + + +def initialize_model_parallel( + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + backend: Optional[str] = None, +) -> None: + """ + NOTE: This method is a hack from the open-sourced version without + asertion of world_size = tp * pp + + Initialize model parallel groups. + + Arguments: + tensor_model_parallel_size: number of GPUs used for tensor model + parallelism. + pipeline_model_parallel_size: number of GPUs used for pipeline model + parallelism. + + Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we + use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize + the model pipeline. The present function will + create 4 tensor model-parallel groups and 2 pipeline model-parallel groups: + 4 tensor model-parallel groups: + [g0, g1], [g2, g3], [g4, g5], [g6, g7] + 2 pipeline model-parallel groups: + [g0, g2, g4, g6], [g1, g3, g5, g7] + Note that for efficiency, the caller should make sure adjacent ranks + are on the same DGX box. For example if we are using 2 DGX-1 boxes + with a total of 16 GPUs, rank 0 to 7 belong to the first box and + ranks 8 to 15 belong to the second box. + """ + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + world_size: int = torch.distributed.get_world_size() + backend = backend or torch.distributed.get_backend(ps.get_world_group().device_group) + + # NOTE(sgm) we don't assert world_size == tp * pp + # DP is not managed by vllm but by the VeRL WorkerGroup + # if (world_size != + # tensor_model_parallel_size * pipeline_model_parallel_size): + # raise RuntimeError( + # f"world_size ({world_size}) is not equal to " + # f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " + # f"pipeline_model_parallel_size ({pipeline_model_parallel_size})") + + num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size + rank = torch.distributed.get_rank() + global _TP + assert _TP is None, "tensor model parallel group is already initialized" + group_ranks = [] + for i in range(num_tensor_model_parallel_groups): + ranks = list(range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)) + group_ranks.append(ranks) + + # message queue broadcaster is only used in tensor model parallel group + _TP = init_model_parallel_group( + group_ranks, + get_world_group().local_rank, + backend, + use_custom_allreduce=False, # TODO: check why True is not work in Ray trainer + use_message_queue_broadcaster=True, + ) + ps._TP = _TP + + # TODO: init using device mesh (not support hybrid engine now) + # Build the pipeline model-parallel groups. + num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size + global _PP + assert _PP is None, "pipeline model parallel group is already initialized" + group_ranks = [] + for i in range(num_pipeline_model_parallel_groups): + ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) + group_ranks.append(ranks) + # pipeline parallel does not need custom allreduce + _PP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, use_custom_allreduce=False) + ps._PP = _PP # for verl + + +""" +Device mesh utilities +""" + + +def get_device_mesh(): + assert _DEVICE_MESH is not None, "device mesh is not initialized" + return _DEVICE_MESH + + +""" +Tensor model parallel utilities +""" + + +def get_tensor_model_parallel_group(): + """Get the tensor model parallel group the caller rank belongs to.""" + assert _TP is not None, "tensor model parallel group is not initialized" + return _TP.device_group + + +def get_tensor_model_parallel_world_size(): + """Return world size for the tensor model parallel group.""" + return torch.distributed.get_world_size(group=get_tensor_model_parallel_group()) + + +def get_tensor_model_parallel_rank(): + """Return my rank for the tensor model parallel group.""" + return torch.distributed.get_rank(group=get_tensor_model_parallel_group()) + + +def get_tensor_model_parallel_src_rank(): + """Calculate the global rank corresponding to the first local rank + in the tensor model parallel group.""" + global_rank = torch.distributed.get_rank() + local_world_size = get_tensor_model_parallel_world_size() + return (global_rank // local_world_size) * local_world_size diff --git a/verl/third_party/vllm/vllm_v_0_6_3/spmd_gpu_executor.py b/verl/third_party/vllm/vllm_v_0_6_3/spmd_gpu_executor.py new file mode 100644 index 0000000..229a424 --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_3/spmd_gpu_executor.py @@ -0,0 +1,256 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/executor/gpu_executor.py + +import os +import socket +from typing import Dict, List, Optional, Set, Tuple + +import torch +from vllm.config import ( + CacheConfig, + DeviceConfig, + LoRAConfig, + ObservabilityConfig, + ParallelConfig, + PromptAdapterConfig, + SchedulerConfig, + SpeculativeConfig, +) +from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest + +from .config import LoadConfig, ModelConfig + +logger = init_logger(__name__) + + +class SPMDGPUExecutor(ExecutorBase): + """SPMD-based multi-GPU executor implementations.""" + + def __init__( + self, + model, # pytorch model itself or its parameter dict + model_config: ModelConfig, + cache_config: CacheConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + speculative_config: Optional[SpeculativeConfig], + prompt_adapter_config: Optional[PromptAdapterConfig], + observability_config: Optional[ObservabilityConfig], + ) -> None: + self.model_config = model_config + self.cache_config = cache_config + self.lora_config = lora_config + self.load_config = load_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.speculative_config = speculative_config + self.prompt_adapter_config = prompt_adapter_config + self.observability_config = observability_config + + distributed_init_method = initialize_cluster(parallel_config) + self._init_executor(model, distributed_init_method) + + # TODO(sgm): verl not support speculative decode now + def _init_executor(self, model, distributed_init_method) -> None: + assert not self.speculative_config, "Speculative decoding not yet supported for multi-GPU backend." + + # Create the parallel worker for each GPU. + self._init_workers_sp(model, distributed_init_method) + + def _init_workers_sp(self, model, distributed_init_method: str): + # Lazy import the Worker to avoid importing torch.cuda/xformers + # before CUDA_VISIBLE_DEVICES is set in the Worker + from .worker import Worker # pylint: disable=import-outside-toplevel + + rank = int(os.getenv("RANK")) + local_rank = int(os.getenv("LOCAL_RANK")) + print(f"local rank {local_rank}") + + # see https://github.com/NVIDIA/nccl/issues/1234 + os.environ["NCCL_CUMEM_ENABLE"] = "0" + + self.worker = Worker( + model, + self.model_config, + self.parallel_config, + self.scheduler_config, + self.device_config, + self.cache_config, + self.load_config, + local_rank, + rank, + distributed_init_method, + lora_config=self.lora_config, + speculative_config=None, + prompt_adapter_config=self.speculative_config, + is_driver_worker=True, + model_runner_cls=None, # use the default one + ) + + # NOTE(shengguangming): torch.distributed.init_process_group will be called inside the init_model() + self.worker.init_device() + self.worker.load_model() + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available KV blocks. + + This invokes `determine_num_available_blocks` on each worker and takes + the min of the results, guaranteeing that the selected cache sizes are + compatible with all workers. + + Returns: + - tuple[num_gpu_blocks, num_cpu_blocks] + """ + # Get the maximum number of blocks that can be allocated on GPU and CPU. + num_blocks = self.worker.determine_num_available_blocks() + + # NOTE(shengguangming): Now we don't use a shared centralized controler but each process will + # have its own scheduler + num_gpu_blocks = num_blocks[0] + num_cpu_blocks = num_blocks[1] + + return num_gpu_blocks, num_cpu_blocks + + def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: + """Initialize the KV cache in all workers.""" + + # NOTE: We log here to avoid multiple logs when number of workers is + # greater than one. We could log in the engine, but not all executors + # have GPUs. + logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, num_cpu_blocks) + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + if torch.distributed.get_rank() == 0: + print( + f"before init cache memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB" + ) + self.worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks) + if torch.distributed.get_rank() == 0: + print( + f"after init cache memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB" + ) + + # NOTE(sgm): This will not profile & capture the model(CUDAGraph) when rebuilding KVCache + def init_cache_engine(self) -> None: + self.worker._init_cache_engine() + + def free_cache_engine(self) -> None: + self.worker.free_cache_engine() + + def execute_model(self, execute_model_req) -> List[SamplerOutput]: + all_outputs = self.worker.execute_model(execute_model_req=execute_model_req) + + # NOTE(sgm): + # Each GPU in vllm under verl has its own spmd_gpu_executor, therefore all GPUs should return the outputs + # In vllm with ray, only the driver worker returns the sampling results. + return all_outputs + + def add_lora(self, lora_request: LoRARequest) -> bool: + assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." + return self.worker.add_lora(lora_request=lora_request) + + def remove_lora(self, lora_id: int) -> bool: + assert lora_id > 0, "lora_id must be greater than 0." + return self.worker.remove_lora(lora_id=lora_id) + + def list_loras(self) -> Set[int]: + return self.worker.list_loras() + + def check_health(self) -> None: + # SPMDExecutor will always be healthy as long as + # it's running. + return + + # NOTE(sgm) add for verl to pass the abstract class test, not used + from vllm.prompt_adapter.request import PromptAdapterRequest + + def add_prompt_adapter(self, prompt_adapter_request: PromptAdapterRequest) -> bool: + assert prompt_adapter_request.prompt_adapter_id > 0, "prompt_adapter_id must be greater than 0." + return self.worker.add_prompt_adapter(prompt_adapter_request) + + def list_prompt_adapters(self) -> Set[int]: + return self.worker.list_prompt_adapters() + + def pin_lora(self, lora_id: int) -> bool: + assert lora_id > 0, "lora_id must be greater than 0." + return self.worker.pin_lora(lora_id) + + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + assert prompt_adapter_id > 0, "prompt_adapter_id must be greater than 0." + return self.worker.pin_prompt_adapter(prompt_adapter_id) + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + assert prompt_adapter_id > 0, "prompt_adapter_id must be greater than 0." + return self.worker.remove_prompt_adapter(prompt_adapter_id) + + # NOTE(sgm): add for verl + def offload_model_weights(self) -> None: + self.worker.offload_model_weights() + + def sync_model_weights(self, actor_weights: Dict[str, torch.Tensor], load_format: str) -> None: + self.worker.sync_model_weights(actor_weights=actor_weights, load_format=load_format) + + +def initialize_cluster( + parallel_config: ParallelConfig, + engine_use_ray: bool = False, + ray_address: Optional[str] = None, +) -> Tuple[str, Optional[None]]: + """Initialize the distributed cluster probably with Ray. + + Args: + parallel_config: The configurations for parallel execution. + + Returns: + The `distributed_init_method` is the address for initializing the + distributed backend. + """ + + # Initialize cluster locally. + port = get_open_port() + # We need to setup the distributed init method to make sure + # the distributed megatron code (e.g., get world size) works correctly. + # distributed_init_method = f"tcp://localhost:{port}" + distributed_init_method = "env://" + return distributed_init_method + + +def get_open_port(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +# TODO(sgm): not implemented async executor yet +class SPMDGPUExecutorAsync(SPMDGPUExecutor, ExecutorAsyncBase): + + async def execute_model_async(self, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + """Executes one model step on the given sequences.""" + raise NotImplementedError + + async def check_health_async(self) -> None: + """Checks if the executor is healthy. If not, it should raise an + exception.""" + self.check_health() diff --git a/verl/third_party/vllm/vllm_v_0_6_3/tokenizer.py b/verl/third_party/vllm/vllm_v_0_6_3/tokenizer.py new file mode 100644 index 0000000..b0b4d0e --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_3/tokenizer.py @@ -0,0 +1,40 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/tokenizer_group/tokenizer_group.py + +from typing import Optional + +from transformers import PreTrainedTokenizer +from vllm.transformers_utils.tokenizer_group import TokenizerGroup +from vllm.utils import LRUCache + + +class TokenizerGroup(TokenizerGroup): + """A group of tokenizers that can be used for LoRA adapters.""" + + def __init__(self, tokenizer: PreTrainedTokenizer, enable_lora: bool, max_num_seqs: int, + max_input_length: Optional[int]): + self.enable_lora = enable_lora + self.max_input_length = max_input_length + self.tokenizer = tokenizer + self.lora_tokenizers = LRUCache[PreTrainedTokenizer](capacity=max_num_seqs) if enable_lora else None + + # FIXME(sgm): for simplicity, we assign the special token here + @property + def pad_token_id(self): + return self.tokenizer.pad_token_id + + @property + def eos_token_id(self): + return self.tokenizer.eos_token_id diff --git a/verl/third_party/vllm/vllm_v_0_6_3/worker.py b/verl/third_party/vllm/vllm_v_0_6_3/worker.py new file mode 100644 index 0000000..cb1a7ab --- /dev/null +++ b/verl/third_party/vllm/vllm_v_0_6_3/worker.py @@ -0,0 +1,333 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2023 The vLLM team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/worker/worker.py +"""A GPU worker class.""" +import gc +import os +from typing import Dict, List, Optional, Tuple, Type, Union + +import torch +import torch.distributed +import torch.nn as nn +from vllm.config import ( + CacheConfig, + DeviceConfig, + LoRAConfig, + ParallelConfig, + PromptAdapterConfig, + SchedulerConfig, + SpeculativeConfig, +) + +# TODO(sgm): check why vllm has similar file in vllm.model_executor.parallel_utils.parallel_state +from vllm.distributed import get_tensor_model_parallel_group, init_distributed_environment, set_custom_all_reduce +from vllm.model_executor import set_random_seed +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest, IntermediateTensors +from vllm.worker.cache_engine import CacheEngine +from vllm.worker.embedding_model_runner import EmbeddingModelRunner +from vllm.worker.model_runner import GPUModelRunnerBase +from vllm.worker.model_runner_base import ModelRunnerInputBase +from vllm.worker.worker import Worker, _check_if_gpu_supports_dtype +from vllm.worker.worker_base import WorkerInput + +from .config import LoadConfig, LoadFormat, ModelConfig +from .dtensor_weight_loaders import load_dtensor_weights +from .hf_weight_loader import load_hf_weights +from .megatron_weight_loaders import load_megatron_weights +from .model_runner import ModelRunner +from .parallel_state import ensure_model_parallel_initialized + + +class Worker(Worker): + """A worker class that executes (a partition of) the model on a GPU. + + Each worker is associated with a single GPU. The worker is responsible for + maintaining the KV cache and executing the model on the GPU. In case of + distributed inference, each worker is assigned a partition of the model. + """ + + def __init__( + self, + model: Union[nn.Module, Dict], # model itself or its parameter dict + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + lora_config: Optional[LoRAConfig] = None, + speculative_config: Optional[SpeculativeConfig] = None, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, + is_driver_worker: bool = False, + model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None, + ) -> None: + # self.model = model # will be replaced in the init_model + self.model_config = model_config + self.parallel_config = parallel_config + self.parallel_config.rank = rank + self.scheduler_config = scheduler_config + self.device_config = device_config + self.cache_config = cache_config + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + self.lora_config = lora_config + self.load_config = load_config + self.prompt_adapter_config = prompt_adapter_config + self.is_driver_worker = is_driver_worker # TODO: we don't need driver + # if parallel_config and is_driver_worker: + # assert rank % parallel_config.tensor_parallel_size == 0, \ + # "Driver worker should be rank 0 of tensor parallel group." + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + + init_cached_hf_modules() + + # Return hidden states from target model if the draft model is an + # mlp_speculator + speculative_args = ( + {} if speculative_config is None or (speculative_config.draft_model_config.model == model_config.model) or + (speculative_config.draft_model_config.hf_config.model_type not in ["medusa", "mlp_speculator"]) else { + "return_hidden_states": True + }) + + # TODO(sgm): set correct model runner class + ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner + if model_runner_cls is not None: + ModelRunnerClass = model_runner_cls + elif self.model_config.embedding_mode: + ModelRunnerClass = EmbeddingModelRunner + self.model_runner: GPUModelRunnerBase = ModelRunnerClass( + model, # [VERL]: add for verl + model_config, + parallel_config, + scheduler_config, + device_config, + cache_config, + load_config=load_config, + lora_config=self.lora_config, + kv_cache_dtype=self.cache_config.cache_dtype, + is_driver_worker=is_driver_worker, + prompt_adapter_config=prompt_adapter_config, + **speculative_args, + ) + + # Uninitialized cache engine. Will be initialized by + # initialize_cache. + self.cache_engine: List[CacheEngine] = None + # Initialize gpu_cache as embedding models don't initialize kv_caches + self.gpu_cache: Optional[List[List[torch.Tensor]]] = None + + # NOTE(sgm): [VERL] For offloading inference engine params + self.cpu_model = None + + def init_device(self) -> None: + if self.device_config.device.type == "cuda": + # torch.distributed.all_reduce does not free the input tensor until + # the synchronization point. This causes the memory usage to grow + # as the number of all_reduce calls increases. This env var disables + # this behavior. + # Related issue: + # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + + # NOTE(sgm): Modify for verl, Env vars will be set by TORCHRUN. + self.rank = self.rank if self.rank is not None else int(os.getenv("RANK", "-1")) + local_rank = int(os.getenv("LOCAL_RANK", "0")) + self.device = torch.device(f"cuda:{local_rank}") + if self.rank < 0: + raise ValueError("Invalid or unspecified rank.") + torch.cuda.set_device(self.device) + + # Use the world_size set by TORCHRUN + world_size = int(os.getenv("WORLD_SIZE", "-1")) + assert world_size != -1, "The world_size is set to -1, not initialized by TORCHRUN" + self.parallel_config.world_size = world_size + + _check_if_gpu_supports_dtype(self.model_config.dtype) + torch.cuda.empty_cache() + self.init_gpu_memory = torch.cuda.mem_get_info()[0] + else: + raise RuntimeError(f"Not support device type: {self.device_config.device}") + + # Initialize the distributed environment. + init_worker_distributed_environment(self.parallel_config, self.rank, self.distributed_init_method, + self.local_rank) + # Set random seed. + set_random_seed(self.model_config.seed) + # self.model = get_model(actor_model=self.model, model_config=self.model_config) + + @torch.inference_mode() + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Profiles the peak memory usage of the model to determine how many + KV blocks may be allocated without OOMs. + + The engine will first conduct a profiling of the existing memory usage. + Then, it calculate the maximum possible number of GPU and CPU blocks + that can be allocated with the remaining free memory. + + .. tip:: + You may limit the usage of GPU memory + by adjusting the `gpu_memory_utilization` parameter. + """ + # Profile the memory usage of the model and get the maximum number of + # cache blocks that can be allocated with the remaining free memory. + torch.cuda.empty_cache() + # torch.cuda.reset_peak_memory_stats() + + # Execute a forward pass with dummy inputs to profile the memory usage + # of the model. + self.model_runner.profile_run() + + # Calculate the number of blocks that can be allocated with the + # profiled peak memory. + torch.cuda.synchronize() + free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() + peak_memory = total_gpu_memory - free_gpu_memory + + assert peak_memory > 0, ("Error in memory profiling. This happens when the GPU memory was " + "not properly cleaned up before initializing the vLLM instance.") + + cache_block_size = self.get_cache_block_size_bytes() + + # NOTE(sgm) [VERL] use the remaining memory + num_gpu_blocks = int((free_gpu_memory * self.cache_config.gpu_memory_utilization) // cache_block_size) + # num_gpu_blocks = int((total_gpu_memory * self.cache_config.gpu_memory_utilization - peak_memory) // cache_block_size) + + num_cpu_blocks = int(self.cache_config.swap_space_bytes // cache_block_size) + num_gpu_blocks = max(num_gpu_blocks, 0) + num_cpu_blocks = max(num_cpu_blocks, 0) + if self.model_runner.lora_manager: + self.model_runner.remove_all_loras() + + # NOTE(sgm): Add for [VERL], synchronize number of blocks with all the rank + num_gpu_blocks = torch.tensor([num_gpu_blocks], device="cuda") + num_cpu_blocks = torch.tensor([num_cpu_blocks], device="cuda") + + torch.distributed.all_reduce(num_gpu_blocks, + op=torch.distributed.ReduceOp.MIN, + group=get_tensor_model_parallel_group().device_group) + torch.distributed.all_reduce(num_cpu_blocks, + op=torch.distributed.ReduceOp.MIN, + group=get_tensor_model_parallel_group().device_group) + num_gpu_blocks = num_gpu_blocks.item() + num_cpu_blocks = num_cpu_blocks.item() + gc.collect() + torch.cuda.empty_cache() + return num_gpu_blocks, num_cpu_blocks + + def _init_cache_engine(self): + if self.cache_engine is None and self.gpu_cache is None: + super()._init_cache_engine() + + def free_cache_engine(self): + # ensure `enforce_eager=True` + self.cache_engine = None + self.gpu_cache = None + + # NOTE(sgm): [VERL]: adapt from _execute_model_spmd() + def execute_model(self, + execute_model_req: ExecuteModelRequest, + intermediate_tensors: Optional[IntermediateTensors] = None) -> Optional[List[SamplerOutput]]: + """ + Execute model in Single Program Multiple Data (SPMD) fashion. + All workers take the same request, prepare the input and + execute the model. + """ + assert execute_model_req is not None, ("_execute_model_spmd() requires each worker to take in an " + "ExecuteModelRequest") + worker_input: WorkerInput = self.prepare_worker_input(execute_model_req=execute_model_req) + model_input: ModelRunnerInputBase = self.model_runner.prepare_model_input( + execute_model_req.seq_group_metadata_list) + + # verl.worker.workerbase.WorkerBase + # swap cache + super().execute_worker(worker_input) + + # If there is no input, we don't need to execute the model. + if worker_input.num_seq_groups == 0: + return [] + + return self.model_runner.execute_model( + model_input, + self.kv_cache[worker_input.virtual_engine] if self.kv_cache is not None else None, + intermediate_tensors, + ) + + # assume the input is .state_dict() + def sync_model_weights(self, actor_weights: Dict, load_format: str): + if load_format in [LoadFormat.MEGATRON, LoadFormat.AUTO]: + load_megatron_weights(actor_weights, self.model_runner.model) + elif load_format == LoadFormat.HF: + # full model state dict without no sharding + load_hf_weights(actor_weights, self.model_runner.model) + elif load_format == LoadFormat.DTENSOR: + load_dtensor_weights(actor_weights, self.model_runner.model) + + def offload_model_weights(self) -> None: + if self.cpu_model == None: + self.cpu_model = {} + for name, params in self.model_runner.model.named_parameters(): + self.cpu_model[name] = torch.empty_like(params, device="cpu") + params.data = self.cpu_model[name] + else: + for name, params in self.model_runner.model.named_parameters(): + params.data = self.cpu_model[name] + + +def init_worker_distributed_environment( + parallel_config: ParallelConfig, + rank: int, + distributed_init_method: Optional[str] = "env://", + local_rank: int = -1, +) -> None: + """Initialize the distributed environment.""" + set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) + + # NOTE(sgm) use tcp://localhost:xxxx will hang in HF setting without megatron + init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank) + + ensure_model_parallel_initialized( + tensor_model_parallel_size=parallel_config.tensor_parallel_size, + pipeline_model_parallel_size=parallel_config.pipeline_parallel_size, + ) + + # TODO(sgm): check whether need this + # if pynccl_utils.is_initialized(): + # pynccl_world_size = pynccl_utils.get_world_size() + # if pynccl_world_size != parallel_config.world_size: + # raise RuntimeError( + # "pynccl is already initialized but the pynccl world " + # "size does not match parallel_config.world_size " + # f"({pynccl_world_size} vs. {parallel_config.world_size}).") + # elif parallel_config.world_size > 1: + # # NOTE(woosuk): We don't initialize pynccl process group when world size + # # is 1. + # # NOTE(kaichao): By default, pynccl is initialized for tp group. + # pynccl_utils.init_process_group( + # group=get_tensor_model_parallel_cpu_group()) + + # # Initialize a custom fast all-reduce implementation. + # if not parallel_config.disable_custom_all_reduce: + # init_custom_ar() + + # A small all_reduce for warmup. + torch.distributed.all_reduce(torch.zeros(1).cuda()) + # if pynccl_utils.is_initialized(): + # pynccl_utils.all_reduce(torch.zeros(1).cuda()) diff --git a/verl/trainer/ppo/hybrid_engine/__init__.py b/verl/trainer/ppo/hybrid_engine/__init__.py index aebff5b..3713733 100644 --- a/verl/trainer/ppo/hybrid_engine/__init__.py +++ b/verl/trainer/ppo/hybrid_engine/__init__.py @@ -14,6 +14,8 @@ from verl.utils.import_utils import is_vllm_available, is_megatron_core_available +from .base import BaseShardingManager + AllGatherPPModel = None if is_megatron_core_available() and is_vllm_available(): diff --git a/verl/trainer/ppo/rollout/vllm_rollout/vllm_rollout.py b/verl/trainer/ppo/rollout/vllm_rollout/vllm_rollout.py index a631e81..e66275c 100644 --- a/verl/trainer/ppo/rollout/vllm_rollout/vllm_rollout.py +++ b/verl/trainer/ppo/rollout/vllm_rollout/vllm_rollout.py @@ -82,7 +82,7 @@ def __init__(self, actor_module: nn.Module, config: DictConfig, tokenizer, model os.environ['MEGATRON_IMPORT_TIMERS'] = '0' train_tp = kwargs.get('train_tp', None) num_tp_per_train_tp = train_tp // tensor_parallel_size - if vllm_version == '0.4.2' or vllm_version == '0.5.4': + if vllm_version in ('0.4.2', '0.5.4', '0.6.3'): vllm_ps.initialize_parallel_state(tensor_model_parallel_size=tensor_parallel_size, num_tp_per_train_tp=num_tp_per_train_tp) @@ -109,7 +109,7 @@ def __init__(self, actor_module: nn.Module, config: DictConfig, tokenizer, model ) # we may detokenize the result all together later - if vllm_version == '0.4.2' or vllm_version == '0.5.4': + if vllm_version in ('0.4.2', '0.5.4', '0.6.3'): kwargs['detokenize'] = False # supporting adding any sampling params from the config file From 50ac7252f97cf39997344237bca74162b2dc6d79 Mon Sep 17 00:00:00 2001 From: HL Date: Sun, 8 Dec 2024 22:10:36 -0800 Subject: [PATCH 03/14] [distro] feat: add docker support (#41) * [distro] feat: add docker support * update docker tag * update description --- README.md | 99 ++++++++++++++++++++++----------- docker/Dockerfile.ngc.vllm | 31 +++++++++++ docker/Dockerfile.vemlp.vllm.te | 41 ++++++++++++++ 3 files changed, 137 insertions(+), 34 deletions(-) create mode 100644 docker/Dockerfile.ngc.vllm create mode 100644 docker/Dockerfile.vemlp.vllm.te diff --git a/README.md b/README.md index 21cde82..83969c3 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,3 @@ -
- -
-

veRL: Volcano Engine Reinforcement Learning for LLM

veRL (HybridFlow) is a flexible, efficient and industrial-level RL(HF) training framework designed for large language models (LLMs). veRL is the open-source version of [HybridFlow](https://arxiv.org/abs/2409.19256v2) paper. @@ -29,66 +25,100 @@ veRL is fast with:

+## Installation Guide + +Below are the steps to install veRL in your environment. + +### Requirements +- **Python**: Version >= 3.9 +- **CUDA**: Version >= 12.1 + +veRL supports various backends. Currently, the following configurations are available: +- **FSDP** and **Megatron-LM** for training. +- **vLLM** for rollout generation. + +**Training backends** + +We recommend using **FSDP** backend to investigate, research and prototype different models, datasets and RL algorithms. The guide for using FSDP backend can be found in [PyTorch FSDP Backend](https://verl.readthedocs.io/en/latest/workers/fsdp_workers.html) + +For users who pursue better scalability, we recommend using **Megatron-LM** backend. Currently, we support Megatron-LM@core_v0.4.0 and we fix some internal issues of Megatron-LM. Here's the additional installation guide. The guide for using Megatron-LM backend can be found in [Megatron-LM Backend](https://verl.readthedocs.io/en/latest/workers/megatron_workers.html) + +### Installation Options +#### 1. From Docker Image -## Installation +We provide pre-built Docker images for quick setup. -For installing the latest version of veRL, the best way is to clone and install it from source. Then you can modify our code to customize your own post-training jobs. +Image and tag: `verlai/verl:vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te1.7-v0.0.3` + +1. Launch the desired Docker image: ```bash -# install verl together with some lightweight dependencies in setup.py -git clone https://github.com/volcengine/verl.git -cd verl -pip3 install -e . +docker run --runtime=nvidia -it --rm --shm-size="10g" --cap-add=SYS_ADMIN -v ``` -You can also install veRL using `pip3 install` +2. Inside the container, install veRL: ```bash -# directly install from pypi -pip3 install verl +# install the nightly version +git clone https://github.com/volcengine/verl && cd verl && pip3 install -e . +# or install from pypi via `pip3 install verl` ``` -### Dependencies +4. Setup Megatron (optional) -veRL requires Python >= 3.9 and CUDA >= 12.1. +If you want to enable training with Megatron, Megatron code must be added to PYTHONPATH: -veRL support various backend, we currently release FSDP and Megatron-LM for actor training and vLLM for rollout generation. +```bash +cd .. +git clone -b core_v0.4.0 https://github.com/NVIDIA/Megatron-LM.git +cp verl/patches/megatron_v4.patch Megatron-LM/ +cd Megatron-LM && git apply megatron_v4.patch +pip3 install -e . +export PYTHONPATH=$PYTHONPATH:$(pwd) +``` + +You can also get the Megatron code after verl's patch via +```bash +git clone -b core_v0.4.0_verl https://github.com/eric-haibin-lin/Megatron-LM +``` + +#### 2. From Custom Environments + +
If you prefer setting up veRL in your custom environment, expand this section and follow the steps below. + +Using **conda** is recommended for managing dependencies. -To install the dependencies, we recommend using conda: +1. Create a conda environment: ```bash conda create -n verl python==3.9 conda activate verl ``` -The following dependencies are required for all backends. +2. Install common dependencies (required for all backends) ```bash # install torch [or you can skip this step and let vllm to install the correct version for you] -pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu121 +pip3 install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu121 # install vllm pip3 install vllm==0.6.3 # or you can install 0.5.4, 0.4.2 and 0.3.1 -pip3 install ray==2.10 # other version may have bug +pip3 install ray # flash attention 2 pip3 install flash-attn --no-build-isolation ``` -**FSDP** +3. Install veRL -We recommend using FSDP backend to investigate, research and prototype different models, datasets and RL algorithms. - -The pros, cons and extension guide for using FSDP backend can be found in [PyTorch FSDP Backend](https://verl.readthedocs.io/en/latest/workers/fsdp_workers.html) - -**Megatron-LM** - -For users who pursue better scalability, we recommend using Megatron-LM backend. Please install the above dependencies first. - -Currently, we support Megatron-LM@core_v0.4.0 and we fix some internal issues of Megatron-LM. Here's the additional installation guide. +```bash +# install the nightly version +git clone https://github.com/volcengine/verl && cd verl && pip3 install -e . +# or install from pypi via `pip3 install verl` +``` -The pros, cons and extension guide for using Megatron-LM backend can be found in [Megatron-LM Backend](https://verl.readthedocs.io/en/latest/workers/megatron_workers.html) +4. Setup Megatron (optional) ```bash # FOR Megatron-LM Backend @@ -103,13 +133,14 @@ pip3 install git+https://github.com/NVIDIA/TransformerEngine.git@v1.7 # megatron core v0.4.0 cd .. git clone -b core_v0.4.0 https://github.com/NVIDIA/Megatron-LM.git -cd Megatron-LM -cp ../verl/patches/megatron_v4.patch . -git apply megatron_v4.patch +cp verl/patches/megatron_v4.patch Megatron-LM/ +cd Megatron-LM && git apply megatron_v4.patch pip3 install -e . export PYTHONPATH=$PYTHONPATH:$(pwd) ``` +
+ ## Getting Started Visit our [documentation](https://verl.readthedocs.io/en/latest/index.html) to learn more. diff --git a/docker/Dockerfile.ngc.vllm b/docker/Dockerfile.ngc.vllm new file mode 100644 index 0000000..e6ecd98 --- /dev/null +++ b/docker/Dockerfile.ngc.vllm @@ -0,0 +1,31 @@ +FROM nvcr.io/nvidia/pytorch:24.05-py3 + +# uninstall nv-pytorch fork +RUN pip3 uninstall pytorch-quantization \ + pytorch-triton \ + torch \ + torch-tensorrt \ + torchvision \ + xgboost transformer_engine flash_attn \ + apex megatron-core -y + +RUN pip3 install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu124 + +# make sure torch version is kept +RUN pip3 install --no-cache-dir \ + "torch==2.4.0" \ + accelerate \ + codetiming \ + datasets \ + dill \ + hydra-core \ + numpy \ + pybind11 \ + tensordict \ + "transformers<=4.46.0" + +# ray is installed via vllm +RUN pip3 install --no-cache-dir vllm==0.6.3 + +# we choose flash-attn v2.7.0 or v2.7.2 which contain pre-built wheels +RUN pip3 install --no-cache-dir --no-build-isolation flash-attn==2.7.0.post2 diff --git a/docker/Dockerfile.vemlp.vllm.te b/docker/Dockerfile.vemlp.vllm.te new file mode 100644 index 0000000..a03d7c1 --- /dev/null +++ b/docker/Dockerfile.vemlp.vllm.te @@ -0,0 +1,41 @@ +# docker buildx build --platform linux/x86_64 -t "verlai/verl:$TAG" -f docker/$FILE . + +# the one in docker.io is an alias for the one veturbo +# FROM vemlp-cn-beijing.cr.volces.com/veturbo/pytorch:2.4-cu124 +FROM docker.io/haibinlin/verl:v0.0.5-th2.4.0-cu124-base + +# only config pip index with https://pypi.tuna.tsinghua.edu.cn/simple if needed +# unset for now +RUN pip3 config unset global.index-url + +# transformers 4.47.0 contains the following bug: +# AttributeError: 'Gemma2Attention' object has no attribute '_flash_attn_uses_top_left_mask' +RUN pip3 install --no-cache-dir \ + torch==2.4.0 \ + accelerate \ + codetiming \ + dill \ + hydra-core \ + numpy \ + pybind11 \ + tensordict \ + "transformers <= 4.46.0" + +RUN pip3 install --no-cache-dir flash-attn==2.7.0.post2 --no-build-isolation + +# vllm depends on ray, and veRL does not support ray > 2.37 +RUN pip3 install --no-cache-dir vllm==0.6.3 ray==2.10 + +# install apex +RUN MAX_JOBS=4 pip3 install -v --disable-pip-version-check --no-cache-dir --no-build-isolation \ + --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" \ + git+https://github.com/NVIDIA/apex + +# install Transformer Engine +# - flash-attn pinned to 2.5.3 by TransformerEngine, switch to eric-haibin-lin/TransformerEngine.git@v1.7.0 to relax version req +# - install with: MAX_JOBS=1 NINJA_FLAGS="-j1" TE_BUILD_WITH_NINJA=0 to avoid OOM +# - cudnn is required by TransformerEngine +# RUN CUDNN_PATH=/opt/conda/lib/python3.11/site-packages/nvidia/cudnn \ +# pip3 install git+https://github.com/eric-haibin-lin/TransformerEngine.git@v1.7.0 +RUN MAX_JOBS=1 NINJA_FLAGS="-j1" pip3 install flash-attn==2.5.3 --no-cache-dir --no-build-isolation +RUN MAX_JOBS=1 NINJA_FLAGS="-j1" pip3 install git+https://github.com/NVIDIA/TransformerEngine.git@v1.7 From ed2eaf4e272efdbea195575319177b912e2a1a3f Mon Sep 17 00:00:00 2001 From: HL Date: Tue, 10 Dec 2024 23:24:21 -0800 Subject: [PATCH 04/14] [misc] docs: add neurips links --- README.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/README.md b/README.md index 83969c3..faa3d1f 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,12 @@ veRL is fast with:

+## News + +- [2024/12] The team will present Post-training LLMs: From Algorithms to Infrastructure at NeurIPS 2024. + - [Slides](https://github.com/eric-haibin-lin/verl-data/tree/neurips), notebooks, and video be available soon +- [2024/08] HybridFlow (verl) is accepted to EuroSys 2025. + ## Installation Guide Below are the steps to install veRL in your environment. From 6e8667bd66022c648536d105e34ffa79d04780fb Mon Sep 17 00:00:00 2001 From: Guangming Sheng Date: Wed, 11 Dec 2024 22:41:22 +0800 Subject: [PATCH 05/14] [example] add a split placement tutorial (#43) * [example] add a split placement tutorial * lint --- examples/split_placement/README.md | 61 ++++++ .../config/ppo_trainer_split.yaml | 131 ++++++++++++ examples/split_placement/main_ppo_split.py | 200 ++++++++++++++++++ .../split_placement/run_deepseek7b_llm.sh | 38 ++++ .../split_placement/split_monkey_patch.py | 161 ++++++++++++++ verl/trainer/ppo/ray_trainer.py | 5 +- 6 files changed, 595 insertions(+), 1 deletion(-) create mode 100644 examples/split_placement/README.md create mode 100644 examples/split_placement/config/ppo_trainer_split.yaml create mode 100644 examples/split_placement/main_ppo_split.py create mode 100644 examples/split_placement/run_deepseek7b_llm.sh create mode 100644 examples/split_placement/split_monkey_patch.py diff --git a/examples/split_placement/README.md b/examples/split_placement/README.md new file mode 100644 index 0000000..a5e4ffd --- /dev/null +++ b/examples/split_placement/README.md @@ -0,0 +1,61 @@ +# Split Placement Example +Here we introduce how to run the naive implementation of the split placement of PPO algorithm. +We will release the complete version of flexible placement in the near future. + + For quickstart, you can only follow Step 2 to modify the code and then follow Step 4 to execute the split placement example. + +### Step 1: Placing the models to different GPUs +Specify the placement and resource allocation. In the example, we place the actor and reference in the first half of the GPUs while map the critic and reward model (if any) to the second half of the GPUs. +```python +actor_rollout_ref_pool_id = 'actor_rollout_ref_pool' +critic_pool_id = 'critic_pool' +if config.trainer.nnodes // 2 == 0 and config.trainer.n_gpus_per_node // 2 > 0: + resource_pool_spec = { + actor_rollout_ref_pool_id: [config.trainer.n_gpus_per_node // 2] * config.trainer.nnodes, + critic_pool_id: [config.trainer.n_gpus_per_node // 2] * config.trainer.nnodes, + } +else: + resource_pool_spec = { + actor_rollout_ref_pool_id: [config.trainer.n_gpus_per_node] * (config.trainer.nnodes // 2), + critic_pool_id: [config.trainer.n_gpus_per_node] * (config.trainer.nnodes // 2), + } +print(f'resource_pool_spec: {resource_pool_spec}') +mapping = { + Role.ActorRollout: actor_rollout_ref_pool_id, + Role.Critic: critic_pool_id, + Role.RefPolicy: actor_rollout_ref_pool_id, +} +mapping[Role.RewardModel] = critic_pool_id +``` + +### Step 2: Make the models executed asynchronously +Based on the model placement, we need to make the models executed asynchronously. + +To do so, you need to turn off the `blocking` flag (i.e., `blocking=False`) in our decorator of some model operations. +For example, we hope the actor update and critic update can be executed in parallel, then we need to make the following modification in `fsdp_workers.py` + +``` +@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO, blocking=False) +def update_actor(self, data: DataProto): + ... + +@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO, blocking=False) +def update_critic(self, data: DataProto): + ... +``` + +We can also parallelize the computation of `ref_log_prob` and `values` and `rewards` in the split placement. For simplicity of the tutorial, we + +### Step 3: Execute these operation in parallel in the single controller process +To implement the parallel execution of the actor and critic update, the only thing we need to modify in the `ray_trainer.py` is to `get` the concurrent `futures` on the single controller process. + +```python +critic_output = critic_output.get() +actor_output = actor_output.get() +``` + +### Step 4: Run the split placement example + +``` +bash run_deepseek7b_llm.sh +``` \ No newline at end of file diff --git a/examples/split_placement/config/ppo_trainer_split.yaml b/examples/split_placement/config/ppo_trainer_split.yaml new file mode 100644 index 0000000..bd6bcf2 --- /dev/null +++ b/examples/split_placement/config/ppo_trainer_split.yaml @@ -0,0 +1,131 @@ +data: + tokenizer: null + train_files: ~/data/rlhf/gsm8k/train.parquet + val_files: ~/data/rlhf/gsm8k/test.parquet + prompt_key: prompt + max_prompt_length: 512 + max_response_length: 512 + train_batch_size: 1024 + val_batch_size: 1312 + return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs + return_raw_chat: False + +actor_rollout_ref: + hybrid_engine: True + model: + path: ~/models/deepseek-llm-7b-chat + external_lib: null + override_config: {} + enable_gradient_checkpointing: False + actor: + strategy: fsdp # This is for backward-compatibility + ppo_mini_batch_size: 256 + ppo_micro_batch_size: 64 + grad_clip: 1.0 + clip_ratio: 0.2 + entropy_coeff: 0.001 + ppo_epochs: 1 + shuffle: True + optim: + lr: 1e-6 + lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime + min_lr_ratio: null # only useful for warmup with cosine + warmup_style: constant # select from constant/cosine + total_training_steps: -1 # must be override by program + fsdp_config: + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + param_offload: False + grad_offload: False + optimizer_offload: False + ref: + fsdp_config: + param_offload: False + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + log_prob_micro_batch_size: 128 + rollout: + name: vllm + temperature: 1.0 + top_k: -1 # 0 for hf rollout, -1 for vllm rollout + top_p: 1 + prompt_length: ${data.max_prompt_length} # not use for opensource + response_length: ${data.max_response_length} + # for vllm rollout + dtype: bfloat16 # should align with FSDP + gpu_memory_utilization: 0.5 + ignore_eos: False + enforce_eager: True + free_cache_engine: True + load_format: dummy_dtensor + tensor_model_parallel_size: 2 + max_num_batched_tokens: 8192 + max_num_seqs: 1024 + log_prob_micro_batch_size: 128 + # for hf rollout + do_sample: True + +critic: + strategy: fsdp + optim: + lr: 1e-5 + lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime + min_lr_ratio: null # only useful for warmup with cosine + warmup_style: constant # select from constant/cosine + total_training_steps: -1 # must be override by program + model: + path: ~/models/deepseek-llm-7b-chat + tokenizer_path: ${actor_rollout_ref.model.path} + override_config: {} + external_lib: ${actor_rollout_ref.model.external_lib} + enable_gradient_checkpointing: False + fsdp_config: + param_offload: False + grad_offload: False + optimizer_offload: False + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} + ppo_micro_batch_size: 64 + ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs} + shuffle: ${actor_rollout_ref.actor.shuffle} + grad_clip: 1.0 + cliprange_value: 0.5 + +reward_model: + enable: False + strategy: fsdp + model: + input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical + path: ~/models/FsfairX-LLaMA3-RM-v0.1 + external_lib: ${actor_rollout_ref.model.external_lib} + fsdp_config: + min_num_params: 0 + param_offload: False + micro_batch_size: 64 + max_length: null + +algorithm: + gamma: 1.0 + lam: 1.0 + adv_estimator: gae + kl_penalty: kl # how to estimate kl divergence + kl_ctrl: + type: fixed + kl_coef: 0.001 + +trainer: + total_epochs: 30 + project_name: verl_examples + experiment_name: gsm8k + logger: ['console', 'tracking'] + nnodes: 1 + n_gpus_per_node: 8 + save_freq: -1 + test_freq: 2 + critic_warmup: 0 + default_hdfs_dir: ~/experiments/gsm8k/ppo/${trainer.experiment_name} + default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} diff --git a/examples/split_placement/main_ppo_split.py b/examples/split_placement/main_ppo_split.py new file mode 100644 index 0000000..524f35e --- /dev/null +++ b/examples/split_placement/main_ppo_split.py @@ -0,0 +1,200 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Note that we don't combine the main with ray_trainer as ray_trainer is used by other main. +""" + +from verl import DataProto +import torch +from verl.utils.reward_score import gsm8k, math +from verl.trainer.ppo.ray_trainer import RayPPOTrainer + + +def _select_rm_score_fn(data_source): + if data_source == 'openai/gsm8k': + return gsm8k.compute_score + elif data_source == 'lighteval/MATH': + return math.compute_score + else: + raise NotImplementedError + + +class RewardManager(): + + def __init__(self, tokenizer, num_examine) -> None: + self.tokenizer = tokenizer + self.num_examine = num_examine # the number of batches of decoded responses to print to the console + + def __call__(self, data: DataProto): + """We will expand this function gradually based on the available datasets""" + + # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn + if 'rm_scores' in data.batch.keys(): + return data.batch['rm_scores'] + + reward_tensor = torch.zeros_like(data.batch['responses'], dtype=torch.float32) + + already_print_data_sources = {} + + for i in range(len(data)): + data_item = data[i] # DataProtoItem + + prompt_ids = data_item.batch['prompts'] + + prompt_length = prompt_ids.shape[-1] + + valid_prompt_length = data_item.batch['attention_mask'][:prompt_length].sum() + valid_prompt_ids = prompt_ids[-valid_prompt_length:] + + response_ids = data_item.batch['responses'] + valid_response_length = data_item.batch['attention_mask'][prompt_length:].sum() + valid_response_ids = response_ids[:valid_response_length] + + # decode + sequences = torch.cat((valid_prompt_ids, valid_response_ids)) + sequences_str = self.tokenizer.decode(sequences) + + ground_truth = data_item.non_tensor_batch['reward_model']['ground_truth'] + + # select rm_score + data_source = data_item.non_tensor_batch['data_source'] + compute_score_fn = _select_rm_score_fn(data_source) + + score = compute_score_fn(solution_str=sequences_str, ground_truth=ground_truth) + reward_tensor[i, valid_response_length - 1] = score + + if data_source not in already_print_data_sources: + already_print_data_sources[data_source] = 0 + + if already_print_data_sources[data_source] < self.num_examine: + already_print_data_sources[data_source] += 1 + print(sequences_str) + + return reward_tensor + + +import ray +import hydra +from split_monkey_patch import fit + + +@hydra.main(config_path='config', config_name='ppo_trainer_split', version_base=None) +def main(config): + if not ray.is_initialized(): + # this is for local ray cluster + ray.init(runtime_env={'env_vars': {'TOKENIZERS_PARALLELISM': 'true', 'NCCL_DEBUG': 'WARN'}}) + + ray.get(main_task.remote(config)) + + +@ray.remote +def main_task(config): + from verl.utils.fs import copy_local_path_from_hdfs + from transformers import AutoTokenizer + + # print initial config + from pprint import pprint + from omegaconf import OmegaConf + pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values + OmegaConf.resolve(config) + + # download the checkpoint from hdfs + local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path) + + # instantiate tokenizer + tokenizer = AutoTokenizer.from_pretrained(local_path) + from verl.utils import set_pad_token_id + set_pad_token_id(tokenizer) + + # define worker classes + if config.actor_rollout_ref.actor.strategy == 'fsdp': + assert config.actor_rollout_ref.actor.strategy == config.critic.strategy + from verl.trainer.ppo.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker + from single_controller.ray import RayWorkerGroup + ray_worker_group_cls = RayWorkerGroup + + elif config.actor_rollout_ref.actor.strategy == 'megatron': + assert config.actor_rollout_ref.actor.strategy == config.critic.strategy + from verl.trainer.ppo.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker + from single_controller.ray.megatron import NVMegatronRayWorkerGroup + ray_worker_group_cls = NVMegatronRayWorkerGroup + + else: + raise NotImplementedError + + from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role + + role_worker_mapping = { + Role.ActorRollout: ActorRolloutRefWorker, + Role.Critic: CriticWorker, + Role.RefPolicy: ActorRolloutRefWorker + } + + # NOTE: initialze two resource pool + actor_rollout_ref_pool_id = 'actor_rollout_ref_pool' + critic_pool_id = 'critic_pool' + if config.trainer.nnodes // 2 == 0 and config.trainer.n_gpus_per_node // 2 > 0: + resource_pool_spec = { + actor_rollout_ref_pool_id: [config.trainer.n_gpus_per_node // 2] * config.trainer.nnodes, + critic_pool_id: [config.trainer.n_gpus_per_node // 2] * config.trainer.nnodes, + } + else: + resource_pool_spec = { + actor_rollout_ref_pool_id: [config.trainer.n_gpus_per_node] * (config.trainer.nnodes // 2), + critic_pool_id: [config.trainer.n_gpus_per_node] * (config.trainer.nnodes // 2), + } + print(f'resource_pool_spec: {resource_pool_spec}') + mapping = { + Role.ActorRollout: actor_rollout_ref_pool_id, + Role.Critic: critic_pool_id, + Role.RefPolicy: actor_rollout_ref_pool_id, + } + + # we should adopt a multi-source reward function here + # - for rule-based rm, we directly call a reward score + # - for model-based rm, we call a model + # - for code related prompt, we send to a sandbox if there are test cases + # - finally, we combine all the rewards together + # - The reward type depends on the tag of the data + if config.reward_model.enable: + if config.reward_model.strategy == 'fsdp': + from verl.trainer.ppo.workers.fsdp_workers import RewardModelWorker + elif config.reward_model.strategy == 'megatron': + from verl.trainer.ppo.workers.megatron_workers import RewardModelWorker + else: + raise NotImplementedError + role_worker_mapping[Role.RewardModel] = RewardModelWorker + mapping[Role.RewardModel] = critic_pool_id + + reward_fn = RewardManager(tokenizer=tokenizer, num_examine=0) + + # Note that we always use function-based RM for validation + val_reward_fn = RewardManager(tokenizer=tokenizer, num_examine=1) + + resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) + + RayPPOTrainer.fit = fit + trainer = RayPPOTrainer(config=config, + tokenizer=tokenizer, + role_worker_mapping=role_worker_mapping, + resource_pool_manager=resource_pool_manager, + ray_worker_group_cls=ray_worker_group_cls, + reward_fn=reward_fn, + val_reward_fn=val_reward_fn) + trainer.init_workers() + trainer.fit() + + +if __name__ == '__main__': + main() diff --git a/examples/split_placement/run_deepseek7b_llm.sh b/examples/split_placement/run_deepseek7b_llm.sh new file mode 100644 index 0000000..6afd399 --- /dev/null +++ b/examples/split_placement/run_deepseek7b_llm.sh @@ -0,0 +1,38 @@ +set -x + +python3 main_ppo_split.py \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.train_batch_size=1024 \ + data.val_batch_size=1312 \ + data.max_prompt_length=512 \ + data.max_response_length=512 \ + actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size=16 \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.grad_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=32 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ + actor_rollout_ref.ref.log_prob_micro_batch_size=32 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + critic.optim.lr=1e-5 \ + critic.model.path=deepseek-ai/deepseek-llm-7b-chat \ + critic.model.enable_gradient_checkpointing=False \ + critic.ppo_micro_batch_size=16 \ + critic.model.fsdp_config.param_offload=False \ + critic.model.fsdp_config.grad_offload=False \ + critic.model.fsdp_config.optimizer_offload=False \ + algorithm.kl_ctrl.kl_coef=0.001 \ + trainer.critic_warmup=0 \ + trainer.logger=['console','tracking'] \ + trainer.project_name='verl_example_gsm8k' \ + trainer.experiment_name='deepseek_llm_7b_function_rm' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.total_epochs=15 $@ diff --git a/examples/split_placement/split_monkey_patch.py b/examples/split_placement/split_monkey_patch.py new file mode 100644 index 0000000..70ed267 --- /dev/null +++ b/examples/split_placement/split_monkey_patch.py @@ -0,0 +1,161 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +An naive implementation of split placment example +""" +import os +from pprint import pprint +from single_controller.ray import RayResourcePool, RayWorkerGroup, RayClassWithInitArgs +from verl import DataProto +from verl.trainer.ppo.ray_trainer import compute_advantage, apply_kl_penalty, reduce_metrics, compute_data_metrics, Role, create_colocated_worker_cls +from codetiming import Timer + + +def fit(self): + """ + The training loop of PPO. + The driver process only need to call the compute functions of the worker group through RPC to construct the PPO dataflow. + The light-weight advantage computation is done on the driver process. + """ + from verl.utils.tracking import Tracking + from omegaconf import OmegaConf + + logger = Tracking(project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True)) + + global_steps = 0 + + # perform validation before training + # currently, we only support validation using the reward_function. + if self.val_reward_fn is not None: + val_metrics = self._validate() + pprint(f'Initial validation metrics: {val_metrics}') + + for epoch in range(self.config.trainer.total_epochs): + for batch_dict in self.train_dataloader: + metrics = {} + + batch: DataProto = DataProto.from_single_dict(batch_dict) + # batch = batch.to('cuda') + + # pop those keys for generation + gen_batch = batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids']) + + # generate a batch + with Timer(name='gen', logger=None) as timer: + gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) + metrics['timing/gen'] = timer.last + + batch = batch.union(gen_batch_output) + + if self.use_reference_policy: + # compute reference log_prob + with Timer(name='ref', logger=None) as timer: + ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) + batch = batch.union(ref_log_prob) + metrics['timing/ref'] = timer.last + + # compute values + with Timer(name='values', logger=None) as timer: + values = self.critic_wg.compute_values(batch) + batch = batch.union(values) + metrics['timing/values'] = timer.last + + with Timer(name='adv', logger=None) as timer: + # compute scores. Support both model and function-based. + # We first compute the scores using reward model. Then, we call reward_fn to combine + # the results from reward model and rule-based results. + if self.use_rm: + # we first compute reward model score + reward_tensor = self.rm_wg.compute_rm_score(batch) + batch = batch.union(reward_tensor) + + # we combine with rule-based rm + reward_tensor = self.reward_fn(batch) + batch.batch['token_level_scores'] = reward_tensor + + # compute rewards. apply_kl_penalty if available + batch, kl_metrics = apply_kl_penalty(batch, + kl_ctrl=self.kl_ctrl, + kl_penalty=self.config.algorithm.kl_penalty) + metrics.update(kl_metrics) + + # compute advantages, executed on the driver process + batch = compute_advantage(batch, + self.config.algorithm.gamma, + self.config.algorithm.lam, + adv_estimator=self.config.algorithm.adv_estimator) + metrics['timing/adv'] = timer.last + + # update critic + if self.use_critic: + with Timer(name='update_critic_call', logger=None) as timer: + critic_output = self.critic_wg.update_critic(batch) + metrics['timing/update_critic_call'] = timer.last + + # implement critic warmup + if self.config.trainer.critic_warmup <= global_steps: + # update actor + with Timer(name='update_actor_call', logger=None) as timer: + actor_output = self.actor_rollout_wg.update_actor(batch) + metrics['timing/update_acto_call'] = timer.last + + # NOTE: make sure you set blocking=False in update_actor and update_crtic in the worker class + with Timer(name='update_actor_critic', logger=None) as timer: + # NOTE: get the DataProtoFuture + critic_output = critic_output.get() + critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics']) + metrics.update(critic_output_metrics) + + # NOTE: get the DataProtoFuture + actor_output = actor_output.get() + actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics']) + metrics.update(actor_output_metrics) + metrics['timing/update_actor_critic'] = timer.last + + # validate + if self.val_reward_fn is not None and (global_steps + 1) % self.config.trainer.test_freq == 0: + with Timer(name='testing', logger=None) as timer: + val_metrics: dict = self._validate() + val_metrics = {f'val/{key}': val for key, val in val_metrics.items()} + metrics['timing/testing'] = timer.last + metrics.update(val_metrics) + + # collect metrics + data_metrics = compute_data_metrics(batch=batch) + metrics.update(data_metrics) + + # TODO: make a canonical logger that supports various backend + logger.log(data=metrics, step=global_steps) + + if self.config.trainer.save_freq > 0 and (global_steps + 1) % self.config.trainer.save_freq == 0: + actor_local_path = os.path.join(self.config.trainer.default_local_dir, 'actor', + f'global_step_{global_steps}') + actor_remote_path = os.path.join(self.config.trainer.default_hdfs_dir, 'actor') + self.actor_rollout_wg.save_checkpoint(actor_local_path, actor_remote_path) + + if self.use_critic: + critic_local_path = os.path.join(self.config.trainer.default_local_dir, 'critic', + f'global_step_{global_steps}') + critic_remote_path = os.path.join(self.config.trainer.default_hdfs_dir, 'critic') + self.critic_wg.save_checkpoint(critic_local_path, critic_remote_path) + + global_steps += 1 + + # perform validation after training + if self.val_reward_fn is not None: + val_metrics = self._validate() + pprint(f'Final validation metrics: {val_metrics}') diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 1316e14..95814f5 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -63,7 +63,10 @@ def create_resource_pool(self): # Due to the Ray issue, we can only support max_colocate_count=1 for now. # This means that each GPU can only have one process. # We can support max_colocate > 1 when applying this pull request: https://github.com/ray-project/ray/pull/44385 - resource_pool = RayResourcePool(process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=1) + resource_pool = RayResourcePool(process_on_nodes=process_on_nodes, + use_gpu=True, + max_colocate_count=1, + name_prefix=resource_pool_name) self.resource_pool_dict[resource_pool_name] = resource_pool def get_resource_pool(self, role: Role) -> RayResourcePool: From 1b24a3a8847d0863efb477f561ab83a903546a9c Mon Sep 17 00:00:00 2001 From: Guangming Sheng Date: Wed, 11 Dec 2024 22:57:59 +0800 Subject: [PATCH 06/14] [doc] add a new quickstart section (#44) --- docs/index.rst | 12 +- docs/{preparation => start}/install.rst | 0 docs/start/quickstart.rst | 172 ++++++++++++++++++++++++ 3 files changed, 182 insertions(+), 2 deletions(-) rename docs/{preparation => start}/install.rst (100%) create mode 100644 docs/start/quickstart.rst diff --git a/docs/index.rst b/docs/index.rst index 0cadb74..c91a8f9 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -26,11 +26,19 @@ veRL is fast with: .. toctree:: :maxdepth: 5 - :caption: Preparation + :caption: Quickstart + :titlesonly: + :numbered: + + start/install + start/quickstart + +.. toctree:: + :maxdepth: 5 + :caption: Data Preparation :titlesonly: :numbered: - preparation/install preparation/prepare_data preparation/reward_function diff --git a/docs/preparation/install.rst b/docs/start/install.rst similarity index 100% rename from docs/preparation/install.rst rename to docs/start/install.rst diff --git a/docs/start/quickstart.rst b/docs/start/quickstart.rst new file mode 100644 index 0000000..2ac6845 --- /dev/null +++ b/docs/start/quickstart.rst @@ -0,0 +1,172 @@ +.. _quickstart: + +========== +Quickstart: Fintune a LLM using PPO with GSM8K dataset +========== + +Post-train a LLM using GSM8K dataset +==================== + +Introduction +------------ + +In this example, we train an LLM to tackle the GSM8k task. + +Paper: https://arxiv.org/pdf/2110.14168 + +Dataset: https://huggingface.co/datasets/gsm8k + +Note that the original paper mainly focuses on training a verifier (a +reward model) to solve math problems via Best-of-N sampling. In this +example, we train an RLHF agent using a rule-based reward model. + +Dataset Introduction +-------------------- + +GSM8k is a math problem dataset. The prompt is an elementary school +problem. The LLM model is required to answer the math problem. + +The training set contains 7473 samples and the test set contains 1319 +samples. + +**An example** + +Prompt + + Katy makes coffee using teaspoons of sugar and cups of water in the + ratio of 7:13. If she used a total of 120 teaspoons of sugar and cups + of water, calculate the number of teaspoonfuls of sugar she used. + +Solution + + The total ratio representing the ingredients she used to make the + coffee is 7+13 = <<7+13=20>>20 Since the fraction representing the + number of teaspoons she used is 7/20, she used 7/20\ *120 = + <<7/20*\ 120=42>>42 #### 42 + +Step 1: Prepare dataset +----------------------- + +.. code:: bash + + cd examples/data_preprocess + python3 gsm8k.py --local_dir ~/data/gsm8k + +Step 2: Download Model +---------------------- + +There’re three ways to prepare the model checkpoints for post-training: + +- Download the required models from huggingface + +.. code:: bash + + huggingface-cli download deepseek-ai/deepseek-math-7b-instruct --local-dir ~/models/deepseek-math-7b-instruct --local-dir-use-symlinks False + +- Already store your store model in the local directory or HDFS path. +- Also, you can directly use the model name in huggingface (e.g., + deepseek-ai/deepseek-math-7b-instruct) in + ``actor_rollout_ref.model.path`` and ``critic.model.path`` field in + the run script. + +Noted that users should prepare checkpoints for actor, critic and reward +model. + +[Optional] Step 3: SFT your Model +--------------------------------- + +We provide a SFT Trainer using PyTorch FSDP in +`fsdp_sft_trainer.py `_. +Users can customize their own SFT +script using our FSDP SFT Trainer. + +We also provide various training scripts for SFT on GSM8K dataset in `gsm8k sft directory `_. + +.. code:: shell + + set -x + + torchrun -m verl.trainer.fsdp_sft_trainer \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.prompt_key=question \ + data.response_key=answer \ + data.micro_batch_size=8 \ + model.partial_pretrain=deepseek-ai/deepseek-coder-6.7b-instruct \ + trainer.default_hdfs_dir=hdfs://user/verl/experiments/gsm8k/deepseek-coder-6.7b-instruct/ \ + trainer.project_name=gsm8k-sft \ + trainer.experiment_name=gsm8k-sft-deepseek-coder-6.7b-instruct \ + trainer.total_epochs=4 \ + trainer.logger=['console','tracking'] + +Step 4: Perform PPO training with your model on GSM8K Dataset +------------------------------------------------------------- + +- Prepare your own run.sh script. Here’s an example for GSM8k dataset + and deepseek-llm-7b-chat model. +- Users could replace the ``data.train_files`` ,\ ``data.val_files``, + ``actor_rollout_ref.model.path`` and ``critic.model.path`` based on + their environment. +- See :doc:`config` for detailed explaination of each config field. + +**Reward Model/Function** + +We use a rule-based reward model. We force the model to produce a final +answer following 4 “#” as shown in the solution. We extract the final +answer from both the solution and model’s output using regular +expression matching. We compare them and assign a reward of 1 to correct +answer, 0.1 to incorrect answer and 0 to no answer. + +**Training Script** + +The training script example for FSDP and Megatron-LM backend are stored in +`examples/ppo_trainer `_ directory. + +.. code:: bash + + cd ../ppo_trainer + bash run_deepseek7b_llm.sh + +The script of `run_deepseek7b_llm.sh` + +.. code:: bash + + set -x + + python3 -m verl.trainer.main_ppo \ + data.train_files=~/data/rlhf/gsm8k/train.parquet \ + data.val_files=~/data/rlhf/gsm8k/test.parquet \ + data.train_batch_size=1024 \ + data.val_batch_size=1312 \ + data.max_prompt_length=512 \ + data.max_response_length=512 \ + actor_rollout_ref.model.path=~/models/deepseek-llm-7b-chat \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size=64 \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.grad_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.micro_batch_size=256 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size=128 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ + actor_rollout_ref.ref.log_prob_micro_batch_size=128 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + critic.optim.lr=1e-5 \ + critic.model.path=~/models/deepseek-llm-7b-chat \ + critic.model.enable_gradient_checkpointing=False \ + critic.ppo_micro_batch_size=64 \ + critic.model.fsdp_config.param_offload=False \ + critic.model.fsdp_config.grad_offload=False \ + critic.model.fsdp_config.optimizer_offload=False \ + algorithm.kl_ctrl.kl_coef=0.001 \ + trainer.critic_warmup=0 \ + trainer.logger=['console','tracking'] \ + trainer.project_name='verl_example_gsm8k' \ + trainer.experiment_name='deepseek_llm_7b_function_rm' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.total_epochs=15 From cd6cef609e27490e2db9c3a8bea10558f3ab1547 Mon Sep 17 00:00:00 2001 From: Guangming Sheng Date: Wed, 11 Dec 2024 23:48:24 +0800 Subject: [PATCH 07/14] [BREAKING][core] move single_controller into verl directory (#45) * [BREAKING][core] move single_controller into verl directory * fix blocking flag in fsdp workers --- .github/workflows/yapf_format.yml | 2 +- README.md | 2 +- docs/advance/dpo_extension.rst | 14 +++++------ docs/examples/ppo_code_architecture.rst | 4 ++-- docs/workers/megatron_workers.rst | 2 +- examples/ray/tutorial.ipynb | 16 ++++++------- examples/split_placement/main_ppo_split.py | 4 ++-- .../split_placement/split_monkey_patch.py | 2 +- tests/ray/check_worker_alive/main.py | 6 ++--- tests/ray/detached_worker/client.py | 4 ++-- tests/ray/detached_worker/server.py | 8 +++---- tests/ray/test_colocated_workers.py | 6 ++--- tests/ray/test_data_transfer.py | 6 ++--- tests/ray/test_driverfunc_to_worker.py | 6 ++--- tests/ray/test_high_level_scheduling_api.py | 4 ++-- tests/ray/test_ray_local_envs.py | 8 +++---- tests/ray/test_remote_api.py | 6 ++--- tests/ray/test_worker_group_basics.py | 8 +++---- tests/ray/test_worker_group_torch.py | 4 ++-- .../single_controller}/__init__.py | 0 .../single_controller}/base/__init__.py | 0 .../single_controller}/base/decorator.py | 24 +++++++++---------- .../single_controller}/base/dp.py | 2 +- .../base/megatron/__init__.py | 0 .../base/megatron/worker.py | 2 +- .../base/megatron/worker_group.py | 2 +- .../base/register_center/__init__.py | 0 .../base/register_center/ray.py | 0 .../single_controller}/base/worker.py | 6 ++--- .../single_controller}/base/worker_group.py | 2 +- .../single_controller}/ray/__init__.py | 0 .../single_controller}/ray/base.py | 4 ++-- .../single_controller}/ray/decorator.py | 2 +- .../ray/dist_data_pass_protocol.py | 0 .../single_controller}/ray/dp.py | 2 +- .../single_controller}/ray/megatron.py | 4 ++-- .../single_controller}/version/version | 0 verl/trainer/main_generation.py | 2 +- verl/trainer/main_ppo.py | 4 ++-- verl/trainer/ppo/ray_trainer.py | 6 ++--- verl/trainer/ppo/workers/fsdp_workers.py | 4 ++-- verl/trainer/ppo/workers/megatron_workers.py | 4 ++-- 42 files changed, 91 insertions(+), 91 deletions(-) rename {single_controller => verl/single_controller}/__init__.py (100%) rename {single_controller => verl/single_controller}/base/__init__.py (100%) rename {single_controller => verl/single_controller}/base/decorator.py (93%) rename {single_controller => verl/single_controller}/base/dp.py (96%) rename {single_controller => verl/single_controller}/base/megatron/__init__.py (100%) rename {single_controller => verl/single_controller}/base/megatron/worker.py (94%) rename {single_controller => verl/single_controller}/base/megatron/worker_group.py (96%) rename {single_controller => verl/single_controller}/base/register_center/__init__.py (100%) rename {single_controller => verl/single_controller}/base/register_center/ray.py (100%) rename {single_controller => verl/single_controller}/base/worker.py (95%) rename {single_controller => verl/single_controller}/base/worker_group.py (98%) rename {single_controller => verl/single_controller}/ray/__init__.py (100%) rename {single_controller => verl/single_controller}/ray/base.py (99%) rename {single_controller => verl/single_controller}/ray/decorator.py (97%) rename {single_controller => verl/single_controller}/ray/dist_data_pass_protocol.py (100%) rename {single_controller => verl/single_controller}/ray/dp.py (97%) rename {single_controller => verl/single_controller}/ray/megatron.py (94%) rename {single_controller => verl/single_controller}/version/version (100%) diff --git a/.github/workflows/yapf_format.yml b/.github/workflows/yapf_format.yml index c6d27c3..548df4c 100644 --- a/.github/workflows/yapf_format.yml +++ b/.github/workflows/yapf_format.yml @@ -42,4 +42,4 @@ jobs: pip install toml==0.10.2 - name: Running yapf run: | - yapf -r -vv -d --style=./.style.yapf verl tests single_controller examples + yapf -r -vv -d --style=./.style.yapf verl tests examples diff --git a/README.md b/README.md index faa3d1f..2f96ba3 100644 --- a/README.md +++ b/README.md @@ -180,7 +180,7 @@ pip3 install yapf ``` Then, make sure you are at top level of verl repo and run ```bash -yapf -ir -vv --style ./.style.yapf verl single_controller examples +yapf -ir -vv --style ./.style.yapf verl examples ``` diff --git a/docs/advance/dpo_extension.rst b/docs/advance/dpo_extension.rst index fb7754c..592d971 100644 --- a/docs/advance/dpo_extension.rst +++ b/docs/advance/dpo_extension.rst @@ -47,8 +47,8 @@ Implementation details: .. code:: python - from single_controller.base import Worker - from single_controller.ray import RayWorkerGroup, RayClassWithInitArgs, RayResourcePool + from verl.single_controller.base import Worker + from verl.single_controller.ray import RayWorkerGroup, RayClassWithInitArgs, RayResourcePool import ray @ray.remote @@ -75,7 +75,7 @@ API: compute reference log probability .. code:: python - from single_controller.base import Worker + from verl.single_controller.base import Worker import ray @ray.remote @@ -93,7 +93,7 @@ API: Update actor model parameters .. code:: python - from single_controller.base import Worker + from verl.single_controller.base import Worker import ray @ray.remote @@ -184,7 +184,7 @@ registered into the worker_group** .. code:: python - from single_controller.base.decorator import register + from verl.single_controller.base.decorator import register def dispatch_data(worker_group, data): return data.chunk(worker_group.world_size) @@ -214,11 +214,11 @@ computation, and data collection. Furthermore, the model parallelism size of each model is usually fixed, including dp, tp, pp. So for these common distributed scenarios, we have -pre-implemented specific dispatch and collect methods,in `decorator.py `_, which can be directly used to wrap the computations. +pre-implemented specific dispatch and collect methods,in `decorator.py `_, which can be directly used to wrap the computations. .. code:: python - from single_controller.base.decorator import register, Dispatch + from verl.single_controller.base.decorator import register, Dispatch @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def generate_sequences(self, data: DataProto) -> DataProto: diff --git a/docs/examples/ppo_code_architecture.rst b/docs/examples/ppo_code_architecture.rst index bd247a2..ab1f66a 100644 --- a/docs/examples/ppo_code_architecture.rst +++ b/docs/examples/ppo_code_architecture.rst @@ -49,13 +49,13 @@ Define worker classes if config.actor_rollout_ref.actor.strategy == 'fsdp': # for FSDP backend assert config.actor_rollout_ref.actor.strategy == config.critic.strategy from verl.trainer.ppo.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker - from single_controller.ray import RayWorkerGroup + from verl.single_controller.ray import RayWorkerGroup ray_worker_group_cls = RayWorkerGroup elif config.actor_rollout_ref.actor.strategy == 'megatron': # for Megatron backend assert config.actor_rollout_ref.actor.strategy == config.critic.strategy from verl.trainer.ppo.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker - from single_controller.ray.megatron import NVMegatronRayWorkerGroup + from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup ray_worker_group_cls = NVMegatronRayWorkerGroup # Ray worker class for Megatron-LM else: diff --git a/docs/workers/megatron_workers.rst b/docs/workers/megatron_workers.rst index 5a8f5ed..d6f88c3 100644 --- a/docs/workers/megatron_workers.rst +++ b/docs/workers/megatron_workers.rst @@ -40,7 +40,7 @@ We implement various of APIs for each ``Worker`` class decorated by the ``@register(dispatch_mode=)`` . These APIs can be called by the ray driver process. The data can be correctly collect and dispatch following the ``dispatch_mode`` on each function. The supported dispatch_model -(i.e., transfer protocols) can be found in `decorator.py `_. +(i.e., transfer protocols) can be found in `decorator.py `_. ActorRolloutRefWorker ^^^^^^^^^^^^^^^^^^^^^ diff --git a/examples/ray/tutorial.ipynb b/examples/ray/tutorial.ipynb index 9b8591a..f270cd9 100644 --- a/examples/ray/tutorial.ipynb +++ b/examples/ray/tutorial.ipynb @@ -232,8 +232,8 @@ }, "outputs": [], "source": [ - "from single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup, merge_resource_pool\n", - "from single_controller.base import Worker" + "from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup, merge_resource_pool\n", + "from verl.single_controller.base import Worker" ] }, { @@ -437,7 +437,7 @@ }, "outputs": [], "source": [ - "from single_controller.ray.decorator import register, Dispatch, Execute" + "from verl.single_controller.ray.decorator import register, Dispatch, Execute" ] }, { @@ -518,7 +518,7 @@ }, "outputs": [], "source": [ - "from single_controller.ray.decorator import register, Dispatch, collect_all_to_all, Execute" + "from verl.single_controller.ray.decorator import register, Dispatch, collect_all_to_all, Execute" ] }, { @@ -723,10 +723,10 @@ }, "outputs": [], "source": [ - "from single_controller.ray.decorator import register, Dispatch, Execute\n", - "from single_controller.ray.megatron import NVMegatronRayWorkerGroup\n", - "from single_controller.base.megatron.worker import MegatronWorker\n", - "from single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup\n", + "from verl.single_controller.ray.decorator import register, Dispatch, Execute\n", + "from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup\n", + "from verl.single_controller.base.megatron.worker import MegatronWorker\n", + "from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup\n", "from omegaconf import OmegaConf\n", "from megatron.core import parallel_state as mpu" ] diff --git a/examples/split_placement/main_ppo_split.py b/examples/split_placement/main_ppo_split.py index 524f35e..5ae4b21 100644 --- a/examples/split_placement/main_ppo_split.py +++ b/examples/split_placement/main_ppo_split.py @@ -121,13 +121,13 @@ def main_task(config): if config.actor_rollout_ref.actor.strategy == 'fsdp': assert config.actor_rollout_ref.actor.strategy == config.critic.strategy from verl.trainer.ppo.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker - from single_controller.ray import RayWorkerGroup + from verl.single_controller.ray import RayWorkerGroup ray_worker_group_cls = RayWorkerGroup elif config.actor_rollout_ref.actor.strategy == 'megatron': assert config.actor_rollout_ref.actor.strategy == config.critic.strategy from verl.trainer.ppo.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker - from single_controller.ray.megatron import NVMegatronRayWorkerGroup + from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup ray_worker_group_cls = NVMegatronRayWorkerGroup else: diff --git a/examples/split_placement/split_monkey_patch.py b/examples/split_placement/split_monkey_patch.py index 70ed267..5e09377 100644 --- a/examples/split_placement/split_monkey_patch.py +++ b/examples/split_placement/split_monkey_patch.py @@ -16,7 +16,7 @@ """ import os from pprint import pprint -from single_controller.ray import RayResourcePool, RayWorkerGroup, RayClassWithInitArgs +from verl.single_controller.ray import RayResourcePool, RayWorkerGroup, RayClassWithInitArgs from verl import DataProto from verl.trainer.ppo.ray_trainer import compute_advantage, apply_kl_penalty, reduce_metrics, compute_data_metrics, Role, create_colocated_worker_cls from codetiming import Timer diff --git a/tests/ray/check_worker_alive/main.py b/tests/ray/check_worker_alive/main.py index 9526f1b..fcebbfe 100644 --- a/tests/ray/check_worker_alive/main.py +++ b/tests/ray/check_worker_alive/main.py @@ -18,9 +18,9 @@ import ray -from single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup -from single_controller.base.worker import Worker -from single_controller.base.decorator import register, Dispatch +from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup +from verl.single_controller.base.worker import Worker +from verl.single_controller.base.decorator import register, Dispatch @ray.remote diff --git a/tests/ray/detached_worker/client.py b/tests/ray/detached_worker/client.py index 0595bcf..1773fff 100644 --- a/tests/ray/detached_worker/client.py +++ b/tests/ray/detached_worker/client.py @@ -19,8 +19,8 @@ import torch from verl import DataProto -from single_controller.ray import RayClassWithInitArgs -from single_controller.ray.megatron import NVMegatronRayWorkerGroup +from verl.single_controller.ray import RayClassWithInitArgs +from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup from tensordict import TensorDict diff --git a/tests/ray/detached_worker/server.py b/tests/ray/detached_worker/server.py index 1842f0a..c8057e3 100644 --- a/tests/ray/detached_worker/server.py +++ b/tests/ray/detached_worker/server.py @@ -25,10 +25,10 @@ from torch import nn import ray -from single_controller.ray import RayClassWithInitArgs, RayResourcePool -from single_controller.ray.megatron import NVMegatronRayWorkerGroup -from single_controller.base.megatron.worker import MegatronWorker -from single_controller.ray.decorator import register, Dispatch +from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool +from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup +from verl.single_controller.base.megatron.worker import MegatronWorker +from verl.single_controller.ray.decorator import register, Dispatch from verl import DataProto from verl.models.llama.megatron import ParallelLlamaForCausalLMRmPadPP diff --git a/tests/ray/test_colocated_workers.py b/tests/ray/test_colocated_workers.py index 0515400..96b859b 100644 --- a/tests/ray/test_colocated_workers.py +++ b/tests/ray/test_colocated_workers.py @@ -14,9 +14,9 @@ import ray -from single_controller.base import Worker -from single_controller.base.decorator import register, Dispatch -from single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup, create_colocated_worker_cls +from verl.single_controller.base import Worker +from verl.single_controller.base.decorator import register, Dispatch +from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup, create_colocated_worker_cls from verl import DataProto diff --git a/tests/ray/test_data_transfer.py b/tests/ray/test_data_transfer.py index 480e576..46b962c 100644 --- a/tests/ray/test_data_transfer.py +++ b/tests/ray/test_data_transfer.py @@ -15,10 +15,10 @@ In this test, we instantiate a data parallel worker with 8 GPUs """ -from single_controller.base import Worker -from single_controller.ray import RayWorkerGroup, RayClassWithInitArgs, RayResourcePool +from verl.single_controller.base import Worker +from verl.single_controller.ray import RayWorkerGroup, RayClassWithInitArgs, RayResourcePool -from single_controller.base.decorator import Dispatch, register +from verl.single_controller.base.decorator import Dispatch, register import ray import torch diff --git a/tests/ray/test_driverfunc_to_worker.py b/tests/ray/test_driverfunc_to_worker.py index 2c7007b..ea253fd 100644 --- a/tests/ray/test_driverfunc_to_worker.py +++ b/tests/ray/test_driverfunc_to_worker.py @@ -18,9 +18,9 @@ from verl import DataProto from tensordict import TensorDict -from single_controller.base.worker import Worker -from single_controller.ray.base import RayResourcePool, RayClassWithInitArgs -from single_controller.ray import RayWorkerGroup +from verl.single_controller.base.worker import Worker +from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs +from verl.single_controller.ray import RayWorkerGroup os.environ['RAY_DEDUP_LOGS'] = '0' os.environ['NCCL_DEBUG'] = 'WARN' diff --git a/tests/ray/test_high_level_scheduling_api.py b/tests/ray/test_high_level_scheduling_api.py index 33d0d14..2d83206 100644 --- a/tests/ray/test_high_level_scheduling_api.py +++ b/tests/ray/test_high_level_scheduling_api.py @@ -16,8 +16,8 @@ import ray -from single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup, merge_resource_pool -from single_controller.base.worker import Worker +from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup, merge_resource_pool +from verl.single_controller.base.worker import Worker @ray.remote diff --git a/tests/ray/test_ray_local_envs.py b/tests/ray/test_ray_local_envs.py index 53bf850..542d536 100644 --- a/tests/ray/test_ray_local_envs.py +++ b/tests/ray/test_ray_local_envs.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -e2e test single_controller.ray +e2e test verl.single_controller.ray """ import os import ray -from single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup -from single_controller.base.worker import Worker -from single_controller.ray.decorator import register, Dispatch, collect_all_to_all, Execute +from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup +from verl.single_controller.base.worker import Worker +from verl.single_controller.ray.decorator import register, Dispatch, collect_all_to_all, Execute @ray.remote diff --git a/tests/ray/test_remote_api.py b/tests/ray/test_remote_api.py index aa4c1c1..b7a64b6 100644 --- a/tests/ray/test_remote_api.py +++ b/tests/ray/test_remote_api.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from single_controller.remote import remote, RemoteBackend, SharedResourcePool -from single_controller.base.decorator import register, Dispatch -from single_controller.base.worker import Worker +from verl.single_controller.remote import remote, RemoteBackend, SharedResourcePool +from verl.single_controller.base.decorator import register, Dispatch +from verl.single_controller.base.worker import Worker @remote(process_on_nodes=[3], use_gpu=True, name_prefix="actor", sharing=SharedResourcePool) diff --git a/tests/ray/test_worker_group_basics.py b/tests/ray/test_worker_group_basics.py index ee1ef10..fa18e9b 100644 --- a/tests/ray/test_worker_group_basics.py +++ b/tests/ray/test_worker_group_basics.py @@ -12,15 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -e2e test single_controller.ray +e2e test verl.single_controller.ray """ import torch import ray -from single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup -from single_controller.base.worker import Worker -from single_controller.ray.decorator import register, Dispatch, collect_all_to_all, Execute +from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup +from verl.single_controller.base.worker import Worker +from verl.single_controller.ray.decorator import register, Dispatch, collect_all_to_all, Execute def two_to_all_dispatch_fn(worker_group, *args, **kwargs): diff --git a/tests/ray/test_worker_group_torch.py b/tests/ray/test_worker_group_torch.py index c48ed1e..13508ed 100644 --- a/tests/ray/test_worker_group_torch.py +++ b/tests/ray/test_worker_group_torch.py @@ -21,8 +21,8 @@ import torch.distributed import ray -from single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup -from single_controller.base.worker import Worker +from verl.single_controller.ray.base import RayResourcePool, RayClassWithInitArgs, RayWorkerGroup +from verl.single_controller.base.worker import Worker @ray.remote diff --git a/single_controller/__init__.py b/verl/single_controller/__init__.py similarity index 100% rename from single_controller/__init__.py rename to verl/single_controller/__init__.py diff --git a/single_controller/base/__init__.py b/verl/single_controller/base/__init__.py similarity index 100% rename from single_controller/base/__init__.py rename to verl/single_controller/base/__init__.py diff --git a/single_controller/base/decorator.py b/verl/single_controller/base/decorator.py similarity index 93% rename from single_controller/base/decorator.py rename to verl/single_controller/base/decorator.py index 9544ac6..6fdacb6 100644 --- a/single_controller/base/decorator.py +++ b/verl/single_controller/base/decorator.py @@ -75,7 +75,7 @@ def dispatch_megatron_compute(worker_group, *args, **kwargs): """ User passes in dp data. The data is dispatched to all tp/pp ranks with the same dp """ - from single_controller.base.megatron.worker_group import MegatronWorkerGroup + from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup assert isinstance(worker_group, MegatronWorkerGroup), f'worker_group must be MegatronWorkerGroup, Got {type(worker_group)}' @@ -104,7 +104,7 @@ def collect_megatron_compute(worker_group, output): """ Only collect the data from the tp=0 and pp=last and every dp ranks """ - from single_controller.base.megatron.worker_group import MegatronWorkerGroup + from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup assert isinstance(worker_group, MegatronWorkerGroup) output_in_dp = [] pp_size = worker_group.get_megatron_global_info().pp_size @@ -119,7 +119,7 @@ def dispatch_megatron_compute_data_proto(worker_group, *args, **kwargs): """ All the args and kwargs must be DataProto. The batch will be chunked by dp_size and passed to each rank """ - from single_controller.base.megatron.worker_group import MegatronWorkerGroup + from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup assert isinstance(worker_group, MegatronWorkerGroup) splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.dp_size, *args, **kwargs) @@ -162,7 +162,7 @@ def dispatch_megatron_pp_as_dp(worker_group, *args, **kwargs): """ treat pp as dp. """ - from single_controller.base.megatron.worker_group import MegatronWorkerGroup + from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup assert isinstance(worker_group, MegatronWorkerGroup) pp_size = worker_group.pp_size @@ -210,7 +210,7 @@ def collect_megatron_pp_as_dp(worker_group, output): """ treat pp as dp. Only collect data on tp=0 """ - from single_controller.base.megatron.worker_group import MegatronWorkerGroup + from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup assert isinstance(worker_group, MegatronWorkerGroup) output_in_dp = [] for global_rank in range(worker_group.world_size): @@ -224,7 +224,7 @@ def collect_megatron_pp_only(worker_group, output): """ Only collect output of megatron pp. This is useful when examine weight names as they are identical in tp/dp """ - from single_controller.base.megatron.worker_group import MegatronWorkerGroup + from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup assert isinstance(worker_group, MegatronWorkerGroup) output_in_pp = [] for global_rank in range(worker_group.world_size): @@ -235,7 +235,7 @@ def collect_megatron_pp_only(worker_group, output): def dispatch_megatron_pp_as_dp_data_proto(worker_group, *args, **kwargs): - from single_controller.base.megatron.worker_group import MegatronWorkerGroup + from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup assert isinstance(worker_group, MegatronWorkerGroup) pp_dp_size = worker_group.dp_size * worker_group.pp_size @@ -245,7 +245,7 @@ def dispatch_megatron_pp_as_dp_data_proto(worker_group, *args, **kwargs): def collect_megatron_pp_as_dp_data_proto(worker_group, output): from verl.protocol import DataProto - from single_controller.base.megatron.worker_group import MegatronWorkerGroup + from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup assert isinstance(worker_group, MegatronWorkerGroup) output = collect_megatron_pp_as_dp(worker_group, output) @@ -253,7 +253,7 @@ def collect_megatron_pp_as_dp_data_proto(worker_group, output): def dispatch_dp_compute(worker_group, *args, **kwargs): - from single_controller.base.worker_group import WorkerGroup + from verl.single_controller.base.worker_group import WorkerGroup assert isinstance(worker_group, WorkerGroup) for arg in args: assert isinstance(arg, (Tuple, List)) and len(arg) == worker_group.world_size @@ -263,21 +263,21 @@ def dispatch_dp_compute(worker_group, *args, **kwargs): def collect_dp_compute(worker_group, output): - from single_controller.base.worker_group import WorkerGroup + from verl.single_controller.base.worker_group import WorkerGroup assert isinstance(worker_group, WorkerGroup) assert len(output) == worker_group.world_size return output def dispatch_dp_compute_data_proto(worker_group, *args, **kwargs): - from single_controller.base.worker_group import WorkerGroup + from verl.single_controller.base.worker_group import WorkerGroup assert isinstance(worker_group, WorkerGroup) splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(worker_group.world_size, *args, **kwargs) return splitted_args, splitted_kwargs def dispatch_dp_compute_data_proto_with_func(worker_group, *args, **kwargs): - from single_controller.base.worker_group import WorkerGroup + from verl.single_controller.base.worker_group import WorkerGroup assert isinstance(worker_group, WorkerGroup) assert type(args[0]) == FunctionType # NOTE: The first one args is a function! diff --git a/single_controller/base/dp.py b/verl/single_controller/base/dp.py similarity index 96% rename from single_controller/base/dp.py rename to verl/single_controller/base/dp.py index 5d534e8..2d19188 100644 --- a/single_controller/base/dp.py +++ b/verl/single_controller/base/dp.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from single_controller.base.worker import Worker +from verl.single_controller.base.worker import Worker class DPEngineWorker(Worker): diff --git a/single_controller/base/megatron/__init__.py b/verl/single_controller/base/megatron/__init__.py similarity index 100% rename from single_controller/base/megatron/__init__.py rename to verl/single_controller/base/megatron/__init__.py diff --git a/single_controller/base/megatron/worker.py b/verl/single_controller/base/megatron/worker.py similarity index 94% rename from single_controller/base/megatron/worker.py rename to verl/single_controller/base/megatron/worker.py index 46608bb..2d84d29 100644 --- a/single_controller/base/megatron/worker.py +++ b/verl/single_controller/base/megatron/worker.py @@ -14,7 +14,7 @@ import os from dataclasses import dataclass -from single_controller.base.worker import Worker, DistRankInfo, DistGlobalInfo +from verl.single_controller.base.worker import Worker, DistRankInfo, DistGlobalInfo class MegatronWorker(Worker): diff --git a/single_controller/base/megatron/worker_group.py b/verl/single_controller/base/megatron/worker_group.py similarity index 96% rename from single_controller/base/megatron/worker_group.py rename to verl/single_controller/base/megatron/worker_group.py index 59a78ff..67c21d3 100644 --- a/single_controller/base/megatron/worker_group.py +++ b/verl/single_controller/base/megatron/worker_group.py @@ -15,7 +15,7 @@ from typing import Dict from .worker import DistRankInfo, DistGlobalInfo -from single_controller.base import ResourcePool, WorkerGroup +from verl.single_controller.base import ResourcePool, WorkerGroup class MegatronWorkerGroup(WorkerGroup): diff --git a/single_controller/base/register_center/__init__.py b/verl/single_controller/base/register_center/__init__.py similarity index 100% rename from single_controller/base/register_center/__init__.py rename to verl/single_controller/base/register_center/__init__.py diff --git a/single_controller/base/register_center/ray.py b/verl/single_controller/base/register_center/ray.py similarity index 100% rename from single_controller/base/register_center/ray.py rename to verl/single_controller/base/register_center/ray.py diff --git a/single_controller/base/worker.py b/verl/single_controller/base/worker.py similarity index 95% rename from single_controller/base/worker.py rename to verl/single_controller/base/worker.py index efd23a8..2ca961d 100644 --- a/single_controller/base/worker.py +++ b/verl/single_controller/base/worker.py @@ -17,7 +17,7 @@ import os import socket from dataclasses import dataclass -from single_controller.base.decorator import register, Dispatch +from verl.single_controller.base.decorator import register, Dispatch @dataclass @@ -43,7 +43,7 @@ def get_node_ip_by_sdk(): import ray return ray._private.services.get_node_ip_address() elif os.getenv("WG_BACKEND", None) == "torch_rpc": - from single_controller.torchrpc.k8s_client import get_ip_addr + from verl.single_controller.torchrpc.k8s_client import get_ip_addr return get_ip_addr() return None @@ -110,7 +110,7 @@ def _configure_before_init(self, register_center_name: str, rank: int): } if os.getenv("WG_BACKEND", None) == "ray": - from single_controller.base.register_center.ray import create_worker_group_register_center + from verl.single_controller.base.register_center.ray import create_worker_group_register_center self.register_center = create_worker_group_register_center(name=register_center_name, info=rank_zero_info) diff --git a/single_controller/base/worker_group.py b/verl/single_controller/base/worker_group.py similarity index 98% rename from single_controller/base/worker_group.py rename to verl/single_controller/base/worker_group.py index a6bc927..bd58458 100644 --- a/single_controller/base/worker_group.py +++ b/verl/single_controller/base/worker_group.py @@ -20,7 +20,7 @@ import time from typing import List, Any, Callable, Dict -from single_controller.base.decorator import MAGIC_ATTR, Dispatch, get_predefined_dispatch_fn, get_predefined_execute_fn +from verl.single_controller.base.decorator import MAGIC_ATTR, Dispatch, get_predefined_dispatch_fn, get_predefined_execute_fn class ResourcePool: diff --git a/single_controller/ray/__init__.py b/verl/single_controller/ray/__init__.py similarity index 100% rename from single_controller/ray/__init__.py rename to verl/single_controller/ray/__init__.py diff --git a/single_controller/ray/base.py b/verl/single_controller/ray/base.py similarity index 99% rename from single_controller/ray/base.py rename to verl/single_controller/ray/base.py index 2cb8148..eaa1b00 100644 --- a/single_controller/ray/base.py +++ b/verl/single_controller/ray/base.py @@ -21,7 +21,7 @@ from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy, NodeAffinitySchedulingStrategy from ray.experimental.state.api import get_actor -from single_controller.base import WorkerGroup, ResourcePool, ClassWithInitArgs, Worker +from verl.single_controller.base import WorkerGroup, ResourcePool, ClassWithInitArgs, Worker __all__ = ['Worker'] @@ -373,7 +373,7 @@ def world_size(self): """ from unittest.mock import patch -from single_controller.base.decorator import MAGIC_ATTR +from verl.single_controller.base.decorator import MAGIC_ATTR import os diff --git a/single_controller/ray/decorator.py b/verl/single_controller/ray/decorator.py similarity index 97% rename from single_controller/ray/decorator.py rename to verl/single_controller/ray/decorator.py index 006a80c..1de452f 100644 --- a/single_controller/ray/decorator.py +++ b/verl/single_controller/ray/decorator.py @@ -19,7 +19,7 @@ import ray # compatiblity cern -from single_controller.base.decorator import * +from verl.single_controller.base.decorator import * def maybe_remote(main): diff --git a/single_controller/ray/dist_data_pass_protocol.py b/verl/single_controller/ray/dist_data_pass_protocol.py similarity index 100% rename from single_controller/ray/dist_data_pass_protocol.py rename to verl/single_controller/ray/dist_data_pass_protocol.py diff --git a/single_controller/ray/dp.py b/verl/single_controller/ray/dp.py similarity index 97% rename from single_controller/ray/dp.py rename to verl/single_controller/ray/dp.py index fab4da9..b53d4b9 100644 --- a/single_controller/ray/dp.py +++ b/verl/single_controller/ray/dp.py @@ -14,7 +14,7 @@ import ray -from single_controller.ray.base import RayWorkerGroup, RayResourcePool, RayClassWithInitArgs +from verl.single_controller.ray.base import RayWorkerGroup, RayResourcePool, RayClassWithInitArgs @ray.remote diff --git a/single_controller/ray/megatron.py b/verl/single_controller/ray/megatron.py similarity index 94% rename from single_controller/ray/megatron.py rename to verl/single_controller/ray/megatron.py index 3aad741..2cdb49f 100644 --- a/single_controller/ray/megatron.py +++ b/verl/single_controller/ray/megatron.py @@ -17,8 +17,8 @@ import ray from .base import RayWorkerGroup, RayResourcePool, RayClassWithInitArgs -from single_controller.base.megatron.worker import DistRankInfo, DistGlobalInfo -from single_controller.base.megatron.worker_group import MegatronWorkerGroup +from verl.single_controller.base.megatron.worker import DistRankInfo, DistGlobalInfo +from verl.single_controller.base.megatron.worker_group import MegatronWorkerGroup # NOTE(sgm): for opensource megatron-core diff --git a/single_controller/version/version b/verl/single_controller/version/version similarity index 100% rename from single_controller/version/version rename to verl/single_controller/version/version diff --git a/verl/trainer/main_generation.py b/verl/trainer/main_generation.py index 0d17073..42469b6 100644 --- a/verl/trainer/main_generation.py +++ b/verl/trainer/main_generation.py @@ -33,7 +33,7 @@ from verl.utils.fs import copy_local_path_from_hdfs from verl.trainer.ppo.workers.fsdp_workers import ActorRolloutRefWorker from verl.utils.hdfs_io import makedirs -from single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup +from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup @hydra.main(config_path='config', config_name='generation', version_base=None) diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py index cbb1c61..3c165b6 100644 --- a/verl/trainer/main_ppo.py +++ b/verl/trainer/main_ppo.py @@ -120,13 +120,13 @@ def main_task(config): if config.actor_rollout_ref.actor.strategy == 'fsdp': assert config.actor_rollout_ref.actor.strategy == config.critic.strategy from verl.trainer.ppo.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker - from single_controller.ray import RayWorkerGroup + from verl.single_controller.ray import RayWorkerGroup ray_worker_group_cls = RayWorkerGroup elif config.actor_rollout_ref.actor.strategy == 'megatron': assert config.actor_rollout_ref.actor.strategy == config.critic.strategy from verl.trainer.ppo.workers.megatron_workers import ActorRolloutRefWorker, CriticWorker - from single_controller.ray.megatron import NVMegatronRayWorkerGroup + from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup ray_worker_group_cls = NVMegatronRayWorkerGroup else: diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 95814f5..3a2b258 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -26,9 +26,9 @@ import numpy as np from codetiming import Timer -from single_controller.base import Worker -from single_controller.ray import RayResourcePool, RayWorkerGroup, RayClassWithInitArgs -from single_controller.ray.base import create_colocated_worker_cls +from verl.single_controller.base import Worker +from verl.single_controller.ray import RayResourcePool, RayWorkerGroup, RayClassWithInitArgs +from verl.single_controller.ray.base import create_colocated_worker_cls from verl import DataProto from verl.trainer.ppo import core_algos diff --git a/verl/trainer/ppo/workers/fsdp_workers.py b/verl/trainer/ppo/workers/fsdp_workers.py index 9d47bae..e7b34cd 100644 --- a/verl/trainer/ppo/workers/fsdp_workers.py +++ b/verl/trainer/ppo/workers/fsdp_workers.py @@ -23,8 +23,8 @@ import torch.distributed from omegaconf import DictConfig, open_dict -from single_controller.base import Worker -from single_controller.base.decorator import register, Dispatch +from verl.single_controller.base import Worker +from verl.single_controller.base.decorator import register, Dispatch import verl.utils.torch_functional as verl_F from verl import DataProto from verl.trainer.ppo.actor import DataParallelPPOActor diff --git a/verl/trainer/ppo/workers/megatron_workers.py b/verl/trainer/ppo/workers/megatron_workers.py index a481d46..b2ce989 100644 --- a/verl/trainer/ppo/workers/megatron_workers.py +++ b/verl/trainer/ppo/workers/megatron_workers.py @@ -22,13 +22,13 @@ import torch.distributed import torch.nn as nn from omegaconf import DictConfig -from single_controller.base.megatron.worker import MegatronWorker +from verl.single_controller.base.megatron.worker import MegatronWorker from verl.trainer.ppo.actor.megatron_actor import MegatronPPOActor from verl.trainer.ppo.critic.megatron_critic import MegatronPPOCritic from verl.trainer.ppo.hybrid_engine import AllGatherPPModel from verl.trainer.ppo.reward_model.megatron.reward_model import MegatronRewardModel -from single_controller.base.decorator import register, Dispatch +from verl.single_controller.base.decorator import register, Dispatch from verl import DataProto from verl.utils.fs import copy_local_path_from_hdfs from verl.utils.debug import log_gpu_memory_usage From 9fa2cfb4af1c7b469cd7ef5f9ef4aadbe0900e3c Mon Sep 17 00:00:00 2001 From: shengguangming Date: Thu, 12 Dec 2024 00:13:08 +0800 Subject: [PATCH 08/14] [v0.1][release] upload v0.1 release --- verl/version/version | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/verl/version/version b/verl/version/version index 23f7536..ceab6e1 100644 --- a/verl/version/version +++ b/verl/version/version @@ -1 +1 @@ -0.1.pre \ No newline at end of file +0.1 \ No newline at end of file From a45bf4af1c87b2b69c98c89a421df6ed4bddf8d3 Mon Sep 17 00:00:00 2001 From: HL Date: Wed, 11 Dec 2024 11:54:51 -0800 Subject: [PATCH 09/14] [news] neurips: add notebook link --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 2f96ba3..cdfe0a8 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,7 @@ veRL is fast with: ## News - [2024/12] The team will present Post-training LLMs: From Algorithms to Infrastructure at NeurIPS 2024. - - [Slides](https://github.com/eric-haibin-lin/verl-data/tree/neurips), notebooks, and video be available soon + - [Slides](https://github.com/eric-haibin-lin/verl-data/tree/neurips), [notebooks](https://lightning.ai/eric-haibin-lin/studios/verl-neurips~01je0d1benfjb9grmfjxqahvkn?view=public§ion=featured), and video be available soon - [2024/08] HybridFlow (verl) is accepted to EuroSys 2025. ## Installation Guide From 13e4b9dca73d6fd384a7424893e556c6c29e62ab Mon Sep 17 00:00:00 2001 From: HL Date: Wed, 11 Dec 2024 20:54:36 -0800 Subject: [PATCH 10/14] [docs] neurips: add video link --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index cdfe0a8..c888654 100644 --- a/README.md +++ b/README.md @@ -27,8 +27,8 @@ veRL is fast with: ## News -- [2024/12] The team will present Post-training LLMs: From Algorithms to Infrastructure at NeurIPS 2024. - - [Slides](https://github.com/eric-haibin-lin/verl-data/tree/neurips), [notebooks](https://lightning.ai/eric-haibin-lin/studios/verl-neurips~01je0d1benfjb9grmfjxqahvkn?view=public§ion=featured), and video be available soon +- [2024/12] The team presented Post-training LLMs: From Algorithms to Infrastructure at NeurIPS 2024. + - [Slides](https://github.com/eric-haibin-lin/verl-data/tree/neurips), [notebooks](https://lightning.ai/eric-haibin-lin/studios/verl-neurips~01je0d1benfjb9grmfjxqahvkn?view=public§ion=featured), and [video](https://neurips.cc/Expo/Conferences/2024/workshop/100677) available. - [2024/08] HybridFlow (verl) is accepted to EuroSys 2025. ## Installation Guide From 0deed9a7376bfe07484aa96799c233d8b558b0ad Mon Sep 17 00:00:00 2001 From: HL Date: Wed, 11 Dec 2024 21:14:24 -0800 Subject: [PATCH 11/14] [docs] community: add slack info --- README.md | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index c888654..ab07a2b 100644 --- a/README.md +++ b/README.md @@ -172,7 +172,12 @@ Visit our [documentation](https://verl.readthedocs.io/en/latest/index.html) to l - [Add models to Megatron-LM backend](https://verl.readthedocs.io/en/latest/advance/megatron_extension.html) -## Contribution +## Community and Contribution + +### Communication channel + +[Join us](https://join.slack.com/t/verlgroup/shared_invite/zt-2w5p9o4c3-yy0x2Q56s_VlGLsJ93A6vA) for discussions on slack! + ### Code formatting We use yapf (Google style) to enforce strict code formatting when reviewing MRs. To reformat you code locally, make sure you installed `yapf` ```bash From c2296e80dfdf77647a0ef6901cb8948abd466ee3 Mon Sep 17 00:00:00 2001 From: HL Date: Sat, 14 Dec 2024 21:57:55 -0800 Subject: [PATCH 12/14] api: rename tracking logger to wandb logger type (#47) --- docs/examples/config.rst | 5 ++--- docs/examples/gsm8k_example.rst | 4 ++-- docs/start/quickstart.rst | 4 ++-- examples/ppo_trainer/run_deepseek7b_llm.sh | 2 +- examples/ppo_trainer/run_deepseek_full_hh_rlhf.sh | 2 +- .../ppo_trainer/run_deepseek_math_gsm8k_megatron.sh | 2 +- examples/ppo_trainer/run_deepseek_megatron.sh | 2 +- examples/ppo_trainer/run_gemma.sh | 2 +- examples/ppo_trainer/run_qwen2-7b.sh | 2 +- examples/ppo_trainer/run_qwen2-7b_rm.sh | 2 +- examples/ppo_trainer/run_qwen2.5-32b.sh | 2 +- examples/sft/gsm8k/run_deepseek_6b7.sh | 2 +- examples/sft/gsm8k/run_gemma_2b.sh | 11 +++++++++-- examples/sft/gsm8k/run_gemma_7b.sh | 2 +- .../split_placement/config/ppo_trainer_split.yaml | 2 +- examples/split_placement/run_deepseek7b_llm.sh | 2 +- verl/trainer/config/ppo_megatron_trainer.yaml | 2 +- verl/trainer/config/ppo_trainer.yaml | 2 +- verl/utils/tracking.py | 12 ++++++++---- 19 files changed, 37 insertions(+), 27 deletions(-) diff --git a/docs/examples/config.rst b/docs/examples/config.rst index 07e1dce..d7d8fa4 100644 --- a/docs/examples/config.rst +++ b/docs/examples/config.rst @@ -307,7 +307,7 @@ Trainer total_epochs: 30 project_name: verl_examples experiment_name: gsm8k - logger: ['console', 'tracking'] + logger: ['console', 'wandb'] nnodes: 1 n_gpus_per_node: 8 save_freq: -1 @@ -319,8 +319,7 @@ Trainer - ``trainer.total_epochs``: Number of epochs in training. - ``trainer.project_name``: For wandb - ``trainer.experiment_name``: For wandb -- ``trainer.logger``: Support console and tracking. For tracking, we - will initialize a wandb +- ``trainer.logger``: Support console and wandb - ``trainer.nnodes``: Number of nodes used in the training. - ``trainer.n_gpus_per_node``: Number of GPUs per node. - ``trainer.save_freq``: The frequency (by iteration) to save checkpoint diff --git a/docs/examples/gsm8k_example.rst b/docs/examples/gsm8k_example.rst index 90b61e0..0d3c1f8 100644 --- a/docs/examples/gsm8k_example.rst +++ b/docs/examples/gsm8k_example.rst @@ -91,7 +91,7 @@ We also provide various training scripts for SFT on GSM8K dataset in `gsm8k sft trainer.project_name=gsm8k-sft \ trainer.experiment_name=gsm8k-sft-deepseek-coder-6.7b-instruct \ trainer.total_epochs=4 \ - trainer.logger=['console','tracking'] + trainer.logger=['console','wandb'] Step 4: Perform PPO training with your model on GSM8K Dataset ------------------------------------------------------------- @@ -156,7 +156,7 @@ The script of run_deepseek7b_llm.sh critic.model.fsdp_config.optimizer_offload=False \ algorithm.kl_ctrl.kl_coef=0.001 \ trainer.critic_warmup=0 \ - trainer.logger=['console','tracking'] \ + trainer.logger=['console','wandb'] \ trainer.project_name='verl_example_gsm8k' \ trainer.experiment_name='deepseek_llm_7b_function_rm' \ trainer.n_gpus_per_node=8 \ diff --git a/docs/start/quickstart.rst b/docs/start/quickstart.rst index 2ac6845..69888f3 100644 --- a/docs/start/quickstart.rst +++ b/docs/start/quickstart.rst @@ -97,7 +97,7 @@ We also provide various training scripts for SFT on GSM8K dataset in `gsm8k sft trainer.project_name=gsm8k-sft \ trainer.experiment_name=gsm8k-sft-deepseek-coder-6.7b-instruct \ trainer.total_epochs=4 \ - trainer.logger=['console','tracking'] + trainer.logger=['console','wandb'] Step 4: Perform PPO training with your model on GSM8K Dataset ------------------------------------------------------------- @@ -163,7 +163,7 @@ The script of `run_deepseek7b_llm.sh` critic.model.fsdp_config.optimizer_offload=False \ algorithm.kl_ctrl.kl_coef=0.001 \ trainer.critic_warmup=0 \ - trainer.logger=['console','tracking'] \ + trainer.logger=['console','wandb'] \ trainer.project_name='verl_example_gsm8k' \ trainer.experiment_name='deepseek_llm_7b_function_rm' \ trainer.n_gpus_per_node=8 \ diff --git a/examples/ppo_trainer/run_deepseek7b_llm.sh b/examples/ppo_trainer/run_deepseek7b_llm.sh index 1df0da0..108fba1 100644 --- a/examples/ppo_trainer/run_deepseek7b_llm.sh +++ b/examples/ppo_trainer/run_deepseek7b_llm.sh @@ -29,7 +29,7 @@ python3 -m verl.trainer.main_ppo \ critic.model.fsdp_config.optimizer_offload=False \ algorithm.kl_ctrl.kl_coef=0.001 \ trainer.critic_warmup=0 \ - trainer.logger=['console','tracking'] \ + trainer.logger=['console','wandb'] \ trainer.project_name='verl_example_gsm8k' \ trainer.experiment_name='deepseek_llm_7b_function_rm' \ trainer.n_gpus_per_node=8 \ diff --git a/examples/ppo_trainer/run_deepseek_full_hh_rlhf.sh b/examples/ppo_trainer/run_deepseek_full_hh_rlhf.sh index f4e2587..bd2c0bc 100644 --- a/examples/ppo_trainer/run_deepseek_full_hh_rlhf.sh +++ b/examples/ppo_trainer/run_deepseek_full_hh_rlhf.sh @@ -31,7 +31,7 @@ python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megat reward_model.param_offload=False \ algorithm.kl_ctrl.kl_coef=0.001 \ trainer.critic_warmup=0 \ - trainer.logger=['console','tracking'] \ + trainer.logger=['console','wandb'] \ trainer.project_name='verl_megatron_full_hh_rlhf_examples' \ trainer.experiment_name='deepseek_llm_7b_model_rm' \ trainer.n_gpus_per_node=8 \ diff --git a/examples/ppo_trainer/run_deepseek_math_gsm8k_megatron.sh b/examples/ppo_trainer/run_deepseek_math_gsm8k_megatron.sh index ed113b2..c342d52 100644 --- a/examples/ppo_trainer/run_deepseek_math_gsm8k_megatron.sh +++ b/examples/ppo_trainer/run_deepseek_math_gsm8k_megatron.sh @@ -30,7 +30,7 @@ python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megat critic.ppo_micro_batch_size=32 \ algorithm.kl_ctrl.kl_coef=0.001 \ trainer.critic_warmup=0 \ - trainer.logger=['console','tracking'] \ + trainer.logger=['console','wandb'] \ trainer.project_name='verl_megatron_math_gsm8k_examples' \ trainer.experiment_name='deepseek_llm_7b_function_rm' \ trainer.n_gpus_per_node=8 \ diff --git a/examples/ppo_trainer/run_deepseek_megatron.sh b/examples/ppo_trainer/run_deepseek_megatron.sh index 2d1cab2..c63285a 100644 --- a/examples/ppo_trainer/run_deepseek_megatron.sh +++ b/examples/ppo_trainer/run_deepseek_megatron.sh @@ -22,7 +22,7 @@ python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megat critic.ppo_micro_batch_size=64 \ algorithm.kl_ctrl.kl_coef=0.001 \ trainer.critic_warmup=0 \ - trainer.logger=['console','tracking'] \ + trainer.logger=['console','wandb'] \ trainer.project_name='verl_megatron_gsm8k_examples' \ trainer.experiment_name='deepseek_llm_7b_function_rm' \ trainer.n_gpus_per_node=8 \ diff --git a/examples/ppo_trainer/run_gemma.sh b/examples/ppo_trainer/run_gemma.sh index bcd5452..200ebdb 100644 --- a/examples/ppo_trainer/run_gemma.sh +++ b/examples/ppo_trainer/run_gemma.sh @@ -29,7 +29,7 @@ python3 -m verl.trainer.main_ppo \ critic.model.fsdp_config.optimizer_offload=False \ algorithm.kl_ctrl.kl_coef=0.001 \ trainer.critic_warmup=0 \ - trainer.logger=['console','tracking'] \ + trainer.logger=['console','wandb'] \ trainer.project_name='verl_example' \ trainer.experiment_name='gemma2b_function_rm' \ trainer.n_gpus_per_node=2 \ diff --git a/examples/ppo_trainer/run_qwen2-7b.sh b/examples/ppo_trainer/run_qwen2-7b.sh index 396eb63..c6ffc1b 100644 --- a/examples/ppo_trainer/run_qwen2-7b.sh +++ b/examples/ppo_trainer/run_qwen2-7b.sh @@ -37,7 +37,7 @@ python3 -m verl.trainer.main_ppo \ critic.model.fsdp_config.optimizer_offload=False \ algorithm.kl_ctrl.kl_coef=0.001 \ trainer.critic_warmup=0 \ - trainer.logger=['console','tracking'] \ + trainer.logger=['console','wandb'] \ trainer.project_name='verl_example' \ trainer.experiment_name='Qwen2-7B-Instruct_function_rm' \ trainer.n_gpus_per_node=8 \ diff --git a/examples/ppo_trainer/run_qwen2-7b_rm.sh b/examples/ppo_trainer/run_qwen2-7b_rm.sh index 2f77e87..3755b38 100644 --- a/examples/ppo_trainer/run_qwen2-7b_rm.sh +++ b/examples/ppo_trainer/run_qwen2-7b_rm.sh @@ -44,7 +44,7 @@ python3 -m verl.trainer.main_ppo \ reward_model.micro_batch_size=16 \ algorithm.kl_ctrl.kl_coef=0.001 \ trainer.critic_warmup=0 \ - trainer.logger=['console','tracking'] \ + trainer.logger=['console','wandb'] \ trainer.project_name='verl_example' \ trainer.experiment_name='Qwen2-7B-Instruct_hybrid_rm' \ trainer.n_gpus_per_node=8 \ diff --git a/examples/ppo_trainer/run_qwen2.5-32b.sh b/examples/ppo_trainer/run_qwen2.5-32b.sh index e7f93cc..1192f1e 100644 --- a/examples/ppo_trainer/run_qwen2.5-32b.sh +++ b/examples/ppo_trainer/run_qwen2.5-32b.sh @@ -38,7 +38,7 @@ python3 -m verl.trainer.main_ppo \ critic.model.fsdp_config.optimizer_offload=False \ algorithm.kl_ctrl.kl_coef=0.0001 \ trainer.critic_warmup=0 \ - trainer.logger=['console','tracking'] \ + trainer.logger=['console','wandb'] \ trainer.project_name='verl_example' \ trainer.experiment_name='Qwen2.5-32B-Instruct_function_rm' \ trainer.n_gpus_per_node=8 \ diff --git a/examples/sft/gsm8k/run_deepseek_6b7.sh b/examples/sft/gsm8k/run_deepseek_6b7.sh index f944a14..8e4d54c 100644 --- a/examples/sft/gsm8k/run_deepseek_6b7.sh +++ b/examples/sft/gsm8k/run_deepseek_6b7.sh @@ -16,4 +16,4 @@ torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ trainer.project_name=gsm8k-sft \ trainer.experiment_name=gsm8k-sft-deepseek-coder-6.7b-instruct \ trainer.total_epochs=4 \ - trainer.logger=['console','tracking'] \ No newline at end of file + trainer.logger=['console','wandb'] \ No newline at end of file diff --git a/examples/sft/gsm8k/run_gemma_2b.sh b/examples/sft/gsm8k/run_gemma_2b.sh index fb5773c..4eca025 100644 --- a/examples/sft/gsm8k/run_gemma_2b.sh +++ b/examples/sft/gsm8k/run_gemma_2b.sh @@ -2,9 +2,16 @@ set -x -hdfs_path=hdfs://user/verl/experiments/gsm8k/gemma-2b-it/ # replace to your own hdfs/local path +if [ "$#" -lt 2 ]; then + echo "Usage: run_gemma_2b.sh [other_configs...]" + exit 1 +fi nproc_per_node=$1 +hdfs_path=$2 + +# Shift the arguments so $@ refers to the rest +shift 2 torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ -m verl.trainer.fsdp_sft_trainer \ @@ -18,4 +25,4 @@ torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ trainer.project_name=gsm8k-sft \ trainer.experiment_name=gsm8k-sft-gemma-2b-it \ trainer.total_epochs=3 \ - trainer.logger=['console','tracking'] \ No newline at end of file + trainer.logger=['console','wandb'] $@ \ No newline at end of file diff --git a/examples/sft/gsm8k/run_gemma_7b.sh b/examples/sft/gsm8k/run_gemma_7b.sh index 8239136..9c35792 100644 --- a/examples/sft/gsm8k/run_gemma_7b.sh +++ b/examples/sft/gsm8k/run_gemma_7b.sh @@ -16,4 +16,4 @@ torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ trainer.project_name=gsm8k-sft \ trainer.experiment_name=gsm8k-sft-gemma-1.1-7b-it \ trainer.total_epochs=4 \ - trainer.logger=['console','tracking'] \ No newline at end of file + trainer.logger=['console','wandb'] \ No newline at end of file diff --git a/examples/split_placement/config/ppo_trainer_split.yaml b/examples/split_placement/config/ppo_trainer_split.yaml index bd6bcf2..22835cc 100644 --- a/examples/split_placement/config/ppo_trainer_split.yaml +++ b/examples/split_placement/config/ppo_trainer_split.yaml @@ -121,7 +121,7 @@ trainer: total_epochs: 30 project_name: verl_examples experiment_name: gsm8k - logger: ['console', 'tracking'] + logger: ['console', 'wandb'] nnodes: 1 n_gpus_per_node: 8 save_freq: -1 diff --git a/examples/split_placement/run_deepseek7b_llm.sh b/examples/split_placement/run_deepseek7b_llm.sh index 6afd399..a2db960 100644 --- a/examples/split_placement/run_deepseek7b_llm.sh +++ b/examples/split_placement/run_deepseek7b_llm.sh @@ -29,7 +29,7 @@ python3 main_ppo_split.py \ critic.model.fsdp_config.optimizer_offload=False \ algorithm.kl_ctrl.kl_coef=0.001 \ trainer.critic_warmup=0 \ - trainer.logger=['console','tracking'] \ + trainer.logger=['console','wandb'] \ trainer.project_name='verl_example_gsm8k' \ trainer.experiment_name='deepseek_llm_7b_function_rm' \ trainer.n_gpus_per_node=8 \ diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml index 2048490..364452a 100644 --- a/verl/trainer/config/ppo_megatron_trainer.yaml +++ b/verl/trainer/config/ppo_megatron_trainer.yaml @@ -135,7 +135,7 @@ trainer: total_epochs: 30 project_name: verl_examples experiment_name: gsm8k - logger: ['console', 'tracking'] + logger: ['console', 'wandb'] nnodes: 1 n_gpus_per_node: 8 save_freq: -1 diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index bd6bcf2..22835cc 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -121,7 +121,7 @@ trainer: total_epochs: 30 project_name: verl_examples experiment_name: gsm8k - logger: ['console', 'tracking'] + logger: ['console', 'wandb'] nnodes: 1 n_gpus_per_node: 8 save_freq: -1 diff --git a/verl/utils/tracking.py b/verl/utils/tracking.py index 19aab11..5a65f95 100644 --- a/verl/utils/tracking.py +++ b/verl/utils/tracking.py @@ -19,20 +19,24 @@ class Tracking(object): - supported_backend = ['tracking', 'console'] + supported_backend = ['wandb', 'console'] def __init__(self, project_name, experiment_name, default_backend: Union[str, List[str]] = 'console', config=None): if isinstance(default_backend, str): default_backend = [default_backend] for backend in default_backend: - assert backend in self.supported_backend, f'{backend} is not supported' + if backend == 'tracking': + import warnings + warnings.warn("`tracking` logger is deprecated. use `wandb` instead.", DeprecationWarning) + else: + assert backend in self.supported_backend, f'{backend} is not supported' self.logger = {} - if 'tracking' in default_backend: + if 'tracking' in default_backend or 'wandb' in default_backend: import wandb wandb.init(project=project_name, name=experiment_name, config=config) - self.logger['tracking'] = wandb + self.logger['wandb'] = wandb if 'console' in default_backend: from verl.utils.logger.aggregate_logger import LocalLogger From c7534db2d9ec8db4f1eb8470ce6bce473020930b Mon Sep 17 00:00:00 2001 From: PanAndy <1163962054@qq.com> Date: Mon, 16 Dec 2024 13:54:15 +0800 Subject: [PATCH 13/14] (fix): fix values response mask in dp critic. (#50) --- verl/trainer/ppo/critic/dp_critic.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/verl/trainer/ppo/critic/dp_critic.py b/verl/trainer/ppo/critic/dp_critic.py index 9efccc3..e078c62 100644 --- a/verl/trainer/ppo/critic/dp_critic.py +++ b/verl/trainer/ppo/critic/dp_critic.py @@ -84,6 +84,10 @@ def compute_values(self, data: DataProto) -> torch.Tensor: values = self._forward_micro_batch(micro_batch) values_lst.append(values) values = torch.concat(values_lst, dim=0) + responses = data.batch['responses'] + attention_mask = data.batch['attention_mask'] + response_length = responses.size(1) + values = values * attention_mask[:, -response_length - 1:-1] return values def update_critic(self, data: DataProto): From d60f843091980e678648900ae1288fe396639e58 Mon Sep 17 00:00:00 2001 From: HL Date: Mon, 16 Dec 2024 16:19:46 -0800 Subject: [PATCH 14/14] [sft] feat: fix sft dataset with latest preprocess code (#49) * api: rename tracking logger to wandb logger type * [sft] feat: add tests for sft dataset * refresh dataset * force refresh * use ds model for tokenizer * add option for trainer.val_only * fix path * fix lint * add sft test for cot and raw q&a * add hf_tokenizer api to patch gemma tokenizer * fix test --- examples/data_preprocess/gsm8k.py | 14 +++-- examples/sft/gsm8k/run_gemma_2b.sh | 19 ++++--- examples/split_placement/main_ppo_split.py | 5 +- tests/verl/utils/dataset/test_rl_dataset.py | 9 +-- tests/verl/utils/dataset/test_rm_dataset.py | 9 +-- tests/verl/utils/dataset/test_sft_dataset.py | 60 ++++++++++++++++++++ verl/trainer/fsdp_sft_trainer.py | 18 +++--- verl/trainer/main_generation.py | 5 +- verl/trainer/main_ppo.py | 7 ++- verl/trainer/ppo/ray_trainer.py | 4 ++ verl/trainer/ppo/workers/fsdp_workers.py | 20 ++----- verl/trainer/ppo/workers/megatron_workers.py | 18 +++--- verl/utils/dataset/rm_dataset.py | 5 +- verl/utils/dataset/sft_dataset.py | 56 +++++++++++------- verl/utils/hdfs_io.py | 8 ++- verl/utils/logger/aggregate_logger.py | 3 +- verl/utils/reward_score/gsm8k.py | 2 +- verl/utils/tokenizer.py | 31 +++++++++- 18 files changed, 192 insertions(+), 101 deletions(-) create mode 100644 tests/verl/utils/dataset/test_sft_dataset.py diff --git a/examples/data_preprocess/gsm8k.py b/examples/data_preprocess/gsm8k.py index b3f491b..d666845 100644 --- a/examples/data_preprocess/gsm8k.py +++ b/examples/data_preprocess/gsm8k.py @@ -52,17 +52,17 @@ def extract_solution(solution_str): def make_map_fn(split): def process_fn(example, idx): - question = example.pop('question') + question_raw = example.pop('question') - question = question + ' ' + instruction_following + question = question_raw + ' ' + instruction_following - answer = example.pop('answer') - solution = extract_solution(answer) + answer_raw = example.pop('answer') + solution = extract_solution(answer_raw) data = { "data_source": data_source, "prompt": [{ "role": "user", - "content": question + "content": question, }], "ability": "math", "reward_model": { @@ -71,7 +71,9 @@ def process_fn(example, idx): }, "extra_info": { 'split': split, - 'index': idx + 'index': idx, + 'answer': answer_raw, + "question": question_raw, } } return data diff --git a/examples/sft/gsm8k/run_gemma_2b.sh b/examples/sft/gsm8k/run_gemma_2b.sh index 4eca025..7ec85c0 100644 --- a/examples/sft/gsm8k/run_gemma_2b.sh +++ b/examples/sft/gsm8k/run_gemma_2b.sh @@ -1,4 +1,4 @@ -# Tested in 4 GPUs +# Tested with 2 & 4 GPUs set -x @@ -8,7 +8,7 @@ if [ "$#" -lt 2 ]; then fi nproc_per_node=$1 -hdfs_path=$2 +save_path=$2 # Shift the arguments so $@ refers to the rest shift 2 @@ -17,12 +17,15 @@ torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ -m verl.trainer.fsdp_sft_trainer \ data.train_files=$HOME/data/gsm8k/train.parquet \ data.val_files=$HOME/data/gsm8k/test.parquet \ - data.prompt_key=prompt \ - data.response_key=answer \ - data.micro_batch_size=32 \ + data.prompt_key=extra_info \ + data.response_key=extra_info \ + +data.prompt_dict_keys=['question'] \ + +data.response_dict_keys=['answer'] \ + data.micro_batch_size=8 \ model.partial_pretrain=google/gemma-2b-it \ - trainer.default_hdfs_dir=$hdfs_path \ + trainer.default_local_dir=$save_path \ trainer.project_name=gsm8k-sft \ trainer.experiment_name=gsm8k-sft-gemma-2b-it \ - trainer.total_epochs=3 \ - trainer.logger=['console','wandb'] $@ \ No newline at end of file + trainer.total_epochs=2 \ + trainer.logger=['console','wandb'] \ + trainer.default_hdfs_dir=null $@ \ No newline at end of file diff --git a/examples/split_placement/main_ppo_split.py b/examples/split_placement/main_ppo_split.py index 5ae4b21..1c608a7 100644 --- a/examples/split_placement/main_ppo_split.py +++ b/examples/split_placement/main_ppo_split.py @@ -113,9 +113,8 @@ def main_task(config): local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path) # instantiate tokenizer - tokenizer = AutoTokenizer.from_pretrained(local_path) - from verl.utils import set_pad_token_id - set_pad_token_id(tokenizer) + from verl.utils import hf_tokenizer + tokenizer = hf_tokenizer(local_path) # define worker classes if config.actor_rollout_ref.actor.strategy == 'fsdp': diff --git a/tests/verl/utils/dataset/test_rl_dataset.py b/tests/verl/utils/dataset/test_rl_dataset.py index 3ddeb75..9d3bba5 100644 --- a/tests/verl/utils/dataset/test_rl_dataset.py +++ b/tests/verl/utils/dataset/test_rl_dataset.py @@ -23,18 +23,13 @@ def get_gsm8k_data(): local_folder = os.path.expanduser('~/verl-data/gsm8k/') local_path = os.path.join(local_folder, 'train.parquet') os.makedirs(local_folder, exist_ok=True) - # import fsspec - # with fsspec.open(url, mode='rb') as fin, fsspec.open(local_path, mode='wb') as fout: - # content = fin.read() - # fout.write(content) return local_path def test_rl_dataset(): from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn - tokenizer = AutoTokenizer.from_pretrained('deepseek-ai/deepseek-coder-1.3b-instruct') - from verl.utils import set_pad_token_id - set_pad_token_id(tokenizer) + from verl.utils import hf_tokenizer + tokenizer = hf_tokenizer('deepseek-ai/deepseek-coder-1.3b-instruct') local_path = get_gsm8k_data() dataset = RLHFDataset(parquet_files=local_path, tokenizer=tokenizer, prompt_key='prompt', max_prompt_length=256) diff --git a/tests/verl/utils/dataset/test_rm_dataset.py b/tests/verl/utils/dataset/test_rm_dataset.py index c139134..f40d4ac 100644 --- a/tests/verl/utils/dataset/test_rm_dataset.py +++ b/tests/verl/utils/dataset/test_rm_dataset.py @@ -14,7 +14,7 @@ import os from transformers import AutoTokenizer -from verl.utils import set_pad_token_id +from verl.utils import hf_tokenizer from verl.utils.dataset.rm_dataset import RMDataset @@ -24,16 +24,11 @@ def get_rm_data(): local_folder = os.path.expanduser('~/verl-data/full_hh_rlhf/rm/') local_path = os.path.join(local_folder, 'test.parquet') os.makedirs(local_folder, exist_ok=True) - # import fsspec - # with fsspec.open(url, mode='rb') as fin, fsspec.open(local_path, mode='wb') as fout: - # content = fin.read() - # fout.write(content) return local_path def test_rm_dataset(): - tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b") - set_pad_token_id(tokenizer) + tokenizer = hf_tokenizer("facebook/opt-1.3b") local_path = get_rm_data() dataset = RMDataset(parquet_files=local_path, tokenizer=tokenizer, max_length=512) data = dataset[0]['input_ids'] diff --git a/tests/verl/utils/dataset/test_sft_dataset.py b/tests/verl/utils/dataset/test_sft_dataset.py new file mode 100644 index 0000000..8834225 --- /dev/null +++ b/tests/verl/utils/dataset/test_sft_dataset.py @@ -0,0 +1,60 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +from transformers import AutoTokenizer +from verl.utils import hf_tokenizer +from verl.utils.dataset.sft_dataset import SFTDataset + + +def get_gsm8k_data(): + # prepare test dataset + url = "https://github.com/eric-haibin-lin/verl-data/raw/refs/heads/main/gsm8k/train.parquet" + local_folder = os.path.expanduser('~/verl-data/gsm8k/') + local_path = os.path.join(local_folder, 'train.parquet') + return local_path + + +def test_sft_cot_dataset(): + tokenizer = hf_tokenizer('deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct') + local_path = get_gsm8k_data() + dataset = SFTDataset(parquet_files=local_path, + tokenizer=tokenizer, + prompt_key='prompt', + prompt_dict_keys=['content'], + response_key='extra_info', + response_dict_keys=['answer'], + max_length=512) + + data = dataset[0]['input_ids'] + output = tokenizer.batch_decode([data])[0] + assert len(output) > 1 + assert type(output) == str + + +def test_sft_dataset(): + tokenizer = hf_tokenizer('deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct') + local_path = get_gsm8k_data() + dataset = SFTDataset(parquet_files=local_path, + tokenizer=tokenizer, + prompt_key='extra_info', + prompt_dict_keys=['question'], + response_key='extra_info', + response_dict_keys=['answer'], + max_length=512) + + data = dataset[0]['input_ids'] + output = tokenizer.batch_decode([data])[0] + assert len(output) > 1 + assert type(output) == str \ No newline at end of file diff --git a/verl/trainer/fsdp_sft_trainer.py b/verl/trainer/fsdp_sft_trainer.py index 9af663e..292d2a4 100644 --- a/verl/trainer/fsdp_sft_trainer.py +++ b/verl/trainer/fsdp_sft_trainer.py @@ -62,10 +62,8 @@ def __init__(self, config, device_mesh: DeviceMesh): self.device_mesh = device_mesh # build tokenizer first local_model_path = copy_local_path_from_hdfs(src=self.config.model.partial_pretrain, verbose=True) - self.tokenizer = AutoTokenizer.from_pretrained(local_model_path, - trust_remote_code=self.config.model.trust_remote_code) - from verl.utils import set_pad_token_id - set_pad_token_id(self.tokenizer) + from verl.utils import hf_tokenizer + self.tokenizer = hf_tokenizer(local_model_path, trust_remote_code=self.config.model.trust_remote_code) if self.config.data.chat_template is not None: raise ValueError('Apply Chat template from config is not supported yet.') @@ -77,6 +75,8 @@ def __init__(self, config, device_mesh: DeviceMesh): self._build_model_optimizer() # TODO: add checkpoint manager + if self.device_mesh.get_rank() == 0: + print(self.config) def _normalize_config_bsz(self): dp_size = self.device_mesh.size() @@ -95,13 +95,17 @@ def _build_dataloader(self): self.train_dataset = SFTDataset(parquet_files=config.data.train_files, tokenizer=self.tokenizer, prompt_key=config.data.prompt_key, + prompt_dict_keys=config.data.get('prompt_dict_keys', None), response_key=config.data.response_key, + response_dict_keys=config.data.get('response_dict_keys', None), max_length=config.data.max_length, truncation=config.data.truncation) self.val_dataset = SFTDataset(parquet_files=config.data.val_files, tokenizer=self.tokenizer, prompt_key=config.data.prompt_key, + prompt_dict_keys=config.data.get('prompt_dict_keys', None), response_key=config.data.response_key, + response_dict_keys=config.data.get('response_dict_keys', None), max_length=config.data.max_length, truncation=config.data.truncation) @@ -292,10 +296,11 @@ def save_checkpoint(self, step): # save huggingface model if self.device_mesh.get_rank() == 0: os.makedirs(path, exist_ok=True) - hdfs_io.makedirs(self.config.trainer.default_hdfs_dir) self.model.save_pretrained(path, state_dict=state_dict) self.tokenizer.save_pretrained(path) - hdfs_io.copy(src=path, dst=self.config.trainer.default_hdfs_dir) + if self.config.trainer.default_hdfs_dir: + hdfs_io.makedirs(self.config.trainer.default_hdfs_dir, exist_ok=True) + hdfs_io.copy(src=path, dst=self.config.trainer.default_hdfs_dir, dirs_exist_ok=True) torch.distributed.barrier() def fit(self): @@ -349,7 +354,6 @@ def main(config): local_rank, rank, world_size = initialize_global_process_group() device_mesh = init_device_mesh(device_type='cuda', mesh_shape=(world_size,), mesh_dim_names=('dp',)) - trainer = FSDPSFTTrainer(config=config, device_mesh=device_mesh) trainer.fit() diff --git a/verl/trainer/main_generation.py b/verl/trainer/main_generation.py index 42469b6..47677d3 100644 --- a/verl/trainer/main_generation.py +++ b/verl/trainer/main_generation.py @@ -43,9 +43,8 @@ def main(config): pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values OmegaConf.resolve(config) local_path = copy_local_path_from_hdfs(config.model.path) - tokenizer = AutoTokenizer.from_pretrained(local_path) - from verl.utils import set_pad_token_id - set_pad_token_id(tokenizer) + from verl.utils import hf_tokenizer + tokenizer = hf_tokenizer(local_path) if config.rollout.temperature == 0.: assert config.data.n_samples == 1, 'When temperature=0, n_samples must be 1.' diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py index 3c165b6..2e664fe 100644 --- a/verl/trainer/main_ppo.py +++ b/verl/trainer/main_ppo.py @@ -31,6 +31,8 @@ def _select_rm_score_fn(data_source): class RewardManager(): + """The reward manager. + """ def __init__(self, tokenizer, num_examine) -> None: self.tokenizer = tokenizer @@ -112,9 +114,8 @@ def main_task(config): local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path) # instantiate tokenizer - tokenizer = AutoTokenizer.from_pretrained(local_path) - from verl.utils import set_pad_token_id - set_pad_token_id(tokenizer) + from verl.utils import hf_tokenizer + tokenizer = hf_tokenizer(local_path) # define worker classes if config.actor_rollout_ref.actor.strategy == 'fsdp': diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 3a2b258..a5f8879 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -420,6 +420,9 @@ def fit(self): if self.val_reward_fn is not None: val_metrics = self._validate() pprint(f'Initial validation metrics: {val_metrics}') + logger.log(data=val_metrics, step=global_steps) + if self.config.trainer.get('val_only', False): + return for epoch in range(self.config.trainer.total_epochs): for batch_dict in self.train_dataloader: @@ -527,3 +530,4 @@ def fit(self): if self.val_reward_fn is not None: val_metrics = self._validate() pprint(f'Final validation metrics: {val_metrics}') + logger.log(data=val_metrics, step=global_steps) diff --git a/verl/trainer/ppo/workers/fsdp_workers.py b/verl/trainer/ppo/workers/fsdp_workers.py index e7b34cd..439a36b 100644 --- a/verl/trainer/ppo/workers/fsdp_workers.py +++ b/verl/trainer/ppo/workers/fsdp_workers.py @@ -35,7 +35,7 @@ from verl.utils.import_utils import import_external_libs from verl.utils.debug import log_gpu_memory_usage import verl.utils.hdfs_io as hdfs_io -from verl.utils import set_pad_token_id +from verl.utils import hf_tokenizer logger = logging.getLogger(__file__) logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN')) @@ -107,8 +107,7 @@ def _build_model_optimizer(self, # note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect # TODO(zhangchi.usc1992): 1. support create from random initialized model. 2. Support init with FSDP directly - self.tokenizer = AutoTokenizer.from_pretrained(local_path, trust_remote_code=trust_remote_code) - set_pad_token_id(self.tokenizer) + self.tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) torch_dtype = fsdp_config.get('model_dtype', None) if torch_dtype is None: @@ -467,9 +466,7 @@ def _build_critic_model_optimizer(self, config): from transformers import AutoTokenizer tokenizer_path = copy_local_path_from_hdfs(config.model.tokenizer_path) - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, - trust_remote_code=config.model.get('trust_remote_code', False)) - set_pad_token_id(self.tokenizer) + self.tokenizer = hf_tokenizer(tokenizer_path, trust_remote_code=config.model.get('trust_remote_code', False)) from omegaconf import OmegaConf override_config = OmegaConf.to_container(self.config.model.get('override_config', OmegaConf.create())) @@ -673,14 +670,9 @@ def _build_model(self, config): else: self._do_switch_chat_template = True input_tokenizer_local_path = copy_local_path_from_hdfs(config.model.input_tokenizer) - self.input_tokenizer = AutoTokenizer.from_pretrained(input_tokenizer_local_path, - trust_remote_code=config.model.get( - 'trust_remote_code', False)) - self.tokenizer = AutoTokenizer.from_pretrained(local_path, - trust_remote_code=config.model.get( - 'trust_remote_code', False)) - set_pad_token_id(self.tokenizer) - set_pad_token_id(self.input_tokenizer) + self.input_tokenizer = hf_tokenizer(input_tokenizer_local_path, + trust_remote_code=config.model.get('trust_remote_code', False)) + self.tokenizer = hf_tokenizer(local_path, trust_remote_code=config.model.get('trust_remote_code', False)) trust_remote_code = config.model.get('trust_remote_code', False) model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code) diff --git a/verl/trainer/ppo/workers/megatron_workers.py b/verl/trainer/ppo/workers/megatron_workers.py index b2ce989..b826905 100644 --- a/verl/trainer/ppo/workers/megatron_workers.py +++ b/verl/trainer/ppo/workers/megatron_workers.py @@ -35,7 +35,7 @@ from verl.utils.model import load_megatron_model_weights from verl.utils.megatron_utils import init_model_parallel_config from verl.utils.megatron_utils import offload_megatron_param_and_grad, load_megatron_param_and_grad -from verl.utils import set_pad_token_id +from verl.utils import hf_tokenizer from megatron.core import parallel_state as mpu from megatron.core import ModelParallelConfig @@ -136,8 +136,7 @@ def _build_model_optimizer(self, # Step 1: initialize the tokenizer local_path = copy_local_path_from_hdfs(model_path) - self.tokenizer = AutoTokenizer.from_pretrained(local_path) - set_pad_token_id(self.tokenizer) + self.tokenizer = hf_tokenizer(local_path) # Step 2: get the actor_model_config actor_model_config = AutoConfig.from_pretrained(local_path) @@ -460,8 +459,7 @@ def _build_critic_model_optimizer(self, # Step 1: initialize the tokenizer local_path = copy_local_path_from_hdfs(model_path) - self.tokenizer = AutoTokenizer.from_pretrained(local_path) - set_pad_token_id(self.tokenizer) + self.tokenizer = hf_tokenizer(local_path) # Step 2: get the actor_model_config critic_model_config = AutoConfig.from_pretrained(local_path) @@ -624,8 +622,7 @@ def _build_rm_model(self, model_path, megatron_config: ModelParallelConfig, over # Step 1: initialize the tokenizer local_path = copy_local_path_from_hdfs(model_path) - self.tokenizer = AutoTokenizer.from_pretrained(local_path) - set_pad_token_id(self.tokenizer) + self.tokenizer = hf_tokenizer(local_path) # Step 2: get the actor_model_config rm_model_config = AutoConfig.from_pretrained(local_path) @@ -688,14 +685,13 @@ def init_model(self): override_model_config = OmegaConf.to_container(self.config.model.get('override_config', OmegaConf.create())) sft_tokenizer_local_path = copy_local_path_from_hdfs(self.config.model.input_tokenizer) - sft_tokenizer = AutoTokenizer.from_pretrained(sft_tokenizer_local_path) - set_pad_token_id(sft_tokenizer) + sft_tokenizer = hf_tokenizer(sft_tokenizer_local_path) rm_tokenizer_path = self.config.model.get('rm_tokenizer', None) rm_tokenizer = None if rm_tokenizer_path is not None: rm_tokenizer_local_path = copy_local_path_from_hdfs(rm_tokenizer_path) - rm_tokenizer = AutoTokenizer.from_pretrained(rm_tokenizer_local_path) - set_pad_token_id(rm_tokenizer) + rm_tokenizer = hf_tokenizer(rm_tokenizer_local_path) + torch_dtype = torch.bfloat16 megatron_config = OmegaConf.create({ diff --git a/verl/utils/dataset/rm_dataset.py b/verl/utils/dataset/rm_dataset.py index 29d3cd9..cba178d 100644 --- a/verl/utils/dataset/rm_dataset.py +++ b/verl/utils/dataset/rm_dataset.py @@ -21,7 +21,7 @@ from torch.utils.data import Dataset from transformers import AutoTokenizer -from verl.utils import set_pad_token_id +from verl.utils import hf_tokenizer def download_files_distributed(download_fn): @@ -54,8 +54,7 @@ def __init__(self, self.parquet_files = parquet_files self.cache_dir = os.path.expanduser(cache_dir) if isinstance(tokenizer, str): - tokenizer = AutoTokenizer.from_pretrained(tokenizer) - set_pad_token_id(tokenizer) + tokenizer = hf_tokenizer(tokenizer) self.tokenizer = tokenizer self.prompt_key = prompt_key diff --git a/verl/utils/dataset/sft_dataset.py b/verl/utils/dataset/sft_dataset.py index c39d7be..9c7a296 100644 --- a/verl/utils/dataset/sft_dataset.py +++ b/verl/utils/dataset/sft_dataset.py @@ -28,7 +28,7 @@ from verl.utils.fs import copy_local_path_from_hdfs from verl.utils.model import compute_position_id_with_mask -from verl.utils import set_pad_token_id +from verl.utils import hf_tokenizer class SFTDataset(Dataset): @@ -40,7 +40,9 @@ def __init__(self, parquet_files: Union[str, List[str]], tokenizer, prompt_key='prompt', + prompt_dict_keys=None, response_key='response', + response_dict_keys=None, max_length=1024, truncation='error'): assert truncation in ['error', 'left', 'right'] @@ -51,12 +53,13 @@ def __init__(self, self.parquet_files = parquet_files if isinstance(tokenizer, str): - tokenizer = AutoTokenizer.from_pretrained(tokenizer) - set_pad_token_id(tokenizer) + tokenizer = hf_tokenizer(tokenizer) self.tokenizer: PreTrainedTokenizer = tokenizer - self.prompt_key = prompt_key - self.response_key = response_key + self.prompt_key = prompt_key if isinstance(prompt_key, (tuple, list)) else [prompt_key] + self.response_key = response_key if isinstance(response_key, (tuple, list)) else [response_key] + self.prompt_dict_keys = [] if not prompt_dict_keys else prompt_dict_keys + self.response_dict_keys = [] if not response_dict_keys else response_dict_keys self.max_length = max_length @@ -68,14 +71,38 @@ def _download(self): self.parquet_files[i] = copy_local_path_from_hdfs(parquet_file, verbose=True) def _read_files_and_tokenize(self): + + def series_to_item(ls): + import pandas, numpy + while isinstance(ls, (pandas.core.series.Series, numpy.ndarray)) and len(ls) == 1: + ls = ls[0] + return ls + dataframes = [] for parquet_file in self.parquet_files: # read parquet files and cache dataframe = pd.read_parquet(parquet_file) dataframes.append(dataframe) self.dataframe = pd.concat(dataframes) - self.prompts = self.dataframe[self.prompt_key].tolist() - self.responses = self.dataframe[self.response_key].tolist() + self.prompts = self.dataframe[self.prompt_key] + for key in self.prompt_dict_keys: + # type(x): pandas.core.series.Series + # type(x[0]): numpy.ndarray + # type(x[0][0]): dict + try: + self.prompts = self.prompts.apply(lambda x: series_to_item(x)[key], axis=1) + except Exception: + print(f'self.prompts={self.prompts}') + raise + self.prompts = self.prompts.tolist() + self.responses = self.dataframe[self.response_key] + for key in self.response_dict_keys: + try: + self.responses = self.responses.apply(lambda x: series_to_item(x)[key], axis=1) + except Exception: + print(f'self.responses={self.responses}') + raise + self.responses = self.responses.tolist() def __len__(self): return len(self.prompts) @@ -145,18 +172,3 @@ def __getitem__(self, item): 'position_ids': position_ids, 'loss_mask': loss_mask } - - -if __name__ == '__main__': - local_model_path = copy_local_path_from_hdfs('~/models/gemma-2b-it') - tokenizer = AutoTokenizer.from_pretrained(local_model_path) - set_pad_token_id(tokenizer) - dataset = SFTDataset(parquet_files='~/data/gsm8k/train.parquet', - tokenizer=tokenizer, - prompt_key='question', - response_key='answer', - max_length=512) - - data = dataset[0]['input_ids'] - output = tokenizer.batch_decode([data])[0] - print(output) diff --git a/verl/utils/hdfs_io.py b/verl/utils/hdfs_io.py index f23b406..08c4ecb 100644 --- a/verl/utils/hdfs_io.py +++ b/verl/utils/hdfs_io.py @@ -14,7 +14,6 @@ import os import shutil -import subprocess import logging logger = logging.getLogger(__file__) @@ -83,7 +82,7 @@ def _mkdir(file_path: str) -> bool: def copy(src: str, dst: str, **kwargs) -> bool: - r"""Works like shutil.copy() but supports hdfs. + r"""Works like shutil.copy() for file, and shutil.copytree for dir, and supports hdfs. Copy data and mode bits ("cp src dst"). Return the file's destination. The destination may be a directory. @@ -105,7 +104,10 @@ def copy(src: str, dst: str, **kwargs) -> bool: # - return file destination for hdfs files return _copy(src, dst) else: - return shutil.copy(src, dst) + if os.path.isdir(src): + return shutil.copytree(src, dst, **kwargs) + else: + return shutil.copy(src, dst, **kwargs) def _copy(from_path: str, to_path: str, timeout: int = None) -> bool: diff --git a/verl/utils/logger/aggregate_logger.py b/verl/utils/logger/aggregate_logger.py index 453f414..ac57cf5 100644 --- a/verl/utils/logger/aggregate_logger.py +++ b/verl/utils/logger/aggregate_logger.py @@ -14,7 +14,6 @@ """ A Ray logger will receive logging info from different processes. """ - import numbers from typing import Dict @@ -40,4 +39,4 @@ def flush(self): def log(self, data, step): if self.print_to_console: - print(concat_dict_to_str(data, step=step)) + print(concat_dict_to_str(data, step=step), flush=True) \ No newline at end of file diff --git a/verl/utils/reward_score/gsm8k.py b/verl/utils/reward_score/gsm8k.py index 9e21d58..ab7eda4 100644 --- a/verl/utils/reward_score/gsm8k.py +++ b/verl/utils/reward_score/gsm8k.py @@ -49,4 +49,4 @@ def compute_score(solution_str, ground_truth, method='strict', format_score=0., if answer == ground_truth: return score else: - return format_score + return format_score \ No newline at end of file diff --git a/verl/utils/tokenizer.py b/verl/utils/tokenizer.py index 55de19f..b64b662 100644 --- a/verl/utils/tokenizer.py +++ b/verl/utils/tokenizer.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. """Utils for tokenization.""" +import warnings -__all__ = ['set_pad_token_id'] +__all__ = ['hf_tokenizer'] def set_pad_token_id(tokenizer): @@ -25,5 +26,33 @@ def set_pad_token_id(tokenizer): """ if tokenizer.pad_token_id is None: tokenizer.pad_token_id = tokenizer.eos_token_id + warnings.warn(f'tokenizer.pad_token_id is None. Now set to {tokenizer.eos_token_id}') if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token + warnings.warn(f'tokenizer.pad_token is None. Now set to {tokenizer.eos_token}') + + +def hf_tokenizer(name_or_path, correct_pad_token=True, correct_gemma2=True, **kwargs): + """Create a huggingface pretrained tokenizer. + + Args: + name (str): The name of the tokenizer. + correct_pad_token (bool): Whether to correct the pad token id. + correct_gemma2 (bool): Whether to correct the gemma2 tokenizer. + **kwargs: The keyword arguments for the tokenizer. + + Returns: + transformers.PreTrainedTokenizer: The pretrained tokenizer. + + """ + from transformers import AutoTokenizer + if correct_gemma2 and isinstance(name_or_path, str) and 'gemma-2-2b-it' in name_or_path: + # the EOS token in gemma2 is ambiguious, which may worsen RL performance. + # https://huggingface.co/google/gemma-2-2b-it/commit/17a01657f5c87135bcdd0ec7abb4b2dece04408a + warnings.warn('Found gemma-2-2b-it tokenizer. Set eos_token and eos_token_id to and 107.') + kwargs['eos_token'] = '' + kwargs['eos_token_id'] = 107 + tokenizer = AutoTokenizer.from_pretrained(name_or_path, **kwargs) + if correct_pad_token: + set_pad_token_id(tokenizer) + return tokenizer \ No newline at end of file