[misc] fix issue in hf_weight_loader and fix typo in doc (#30)
* [fix] fix some bugs related to hf_weight_loader

* [doc] fix doc typo

* [ci] fix github action

* [ci] lint

* [doc] fix typo
PeterSH6 authored Dec 1, 2024
1 parent c5a0964 · commit cfc976b
Showing 7 changed files with 32 additions and 25 deletions.
.github/workflows/yapf_format.yml (2 additions, 0 deletions)
@@ -24,6 +24,8 @@ jobs:
         python-version: ["3.12"]
     steps:
     - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+      with:
+        ref: ${{ github.head_ref }} # Checkout the branch associated with the pull request
     - name: Set up Python ${{ matrix.python-version }}
       uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
       with:

docs/preparation/prepare_data.rst (1 addition, 1 deletion)
@@ -30,7 +30,7 @@ into two parts:
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--local_dir', default='/opt/tiger/gsm8k')
-    parser.add_argument('--hdfs_dir', default='hdfs://haruna/home/byte_data_seed/lf_lq/user/zhangchi.usc1992/data/rlhf')
+    parser.add_argument('--hdfs_dir', default=None)

     args = parser.parse_args()

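The documentation change above replaces a hard-coded, cluster-specific HDFS path with a `None` default, so the example runs as-is for readers without HDFS access. A minimal sketch of how such an optional destination is typically handled; `copy_to_hdfs` below is a hypothetical placeholder, not a verl API:

    import argparse

    def copy_to_hdfs(local_dir: str, hdfs_dir: str) -> None:
        """Hypothetical helper: upload processed files to HDFS.

        A real implementation might shell out to `hdfs dfs -put`; this stub
        only marks where the optional upload would happen.
        """
        raise NotImplementedError

    if __name__ == '__main__':
        parser = argparse.ArgumentParser()
        parser.add_argument('--local_dir', default='/opt/tiger/gsm8k')
        parser.add_argument('--hdfs_dir', default=None)  # optional: no HDFS required by default
        args = parser.parse_args()

        # ... preprocess and write the dataset into args.local_dir ...

        if args.hdfs_dir is not None:  # only touch HDFS when explicitly requested
            copy_to_hdfs(args.local_dir, args.hdfs_dir)
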
tests/ray/detached_worker/client.py (2 additions, 1 deletion)
@@ -37,7 +37,8 @@ def compute_position_id_with_mask(mask):
     # get the worker group using names
     worker_names = ['trainerTrainer_0:0', 'trainerTrainer_0:1']
     cls_with_init_args = RayClassWithInitArgs(cls=Trainer)
-    worker_group = NVMegatronRayWorkerGroup.from_detached(worker_names=worker_names, ray_cls_with_init=cls_with_init_args)
+    worker_group = NVMegatronRayWorkerGroup.from_detached(worker_names=worker_names,
+                                                          ray_cls_with_init=cls_with_init_args)

     batch_size = 16
     sequence_length = 1024

tests/ray/detached_worker/server.py (21 additions, 21 deletions)
@@ -72,11 +72,11 @@ def __init__(self):
     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
     def init_model(self):
         actor_model_config = LlamaConfig(vocab_size=256,
-                                         hidden_size=2048,
-                                         intermediate_size=5504,
-                                         num_hidden_layers=24,
-                                         num_attention_heads=16,
-                                         num_key_value_heads=16)
+                                         hidden_size=2048,
+                                         intermediate_size=5504,
+                                         num_hidden_layers=24,
+                                         num_attention_heads=16,
+                                         num_key_value_heads=16)

         megatron_config = OmegaConf.create({
             'sequence_parallel_enabled': True,
@@ -96,21 +96,18 @@ def megatron_actor_model_provider(pre_process, post_process):
             # this_megatron_config = copy.deepcopy(megatron_config)
             # this_megatron_config.virtual_pipeline_model_parallel_rank = vpp_rank
             parallel_model = ParallelLlamaForCausalLMRmPadPP(config=actor_model_config,
-                                                             megatron_config=megatron_config,
-                                                             pre_process=pre_process,
-                                                             post_process=post_process)
+                                                             megatron_config=megatron_config,
+                                                             pre_process=pre_process,
+                                                             post_process=post_process)
             parallel_model.cuda()
             return parallel_model

-        actor_module = get_model(model_provider_func=megatron_actor_model_provider,
+        actor_module = get_model(model_provider_func=megatron_actor_model_provider,
                                  model_type=ModelType.encoder_or_decoder,
                                  wrap_with_ddp=True)
         actor_module = nn.ModuleList(actor_module)

-        optim_config = OmegaConf.create({
-            'lr': 1e-6,
-            'clip_grad': 1.0
-        })
+        optim_config = OmegaConf.create({'lr': 1e-6, 'clip_grad': 1.0})

         optim_config = init_megatron_optim_config(optim_config)
         self.optimizer_config = optim_config
@@ -126,13 +123,15 @@ def train_model(self, data: DataProto) -> DataProto:
         position_ids = data.batch['position_ids']

         self.optimizer.zero_grad()
-        self.model.zero_grad_buffer(zero_buffer=(not self.optimizer_config.use_distributed_optimizer)) # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm
+        self.model.zero_grad_buffer(
+            zero_buffer=(not self.optimizer_config.use_distributed_optimizer
+                        )) # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm
         # update for 1 iteration
         output = self.model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids).logits
         output.mean().backward()

-        update_successful, grad_norm, num_zeros_in_grad = self.optimizer.step(
-            self.megatron_config, self.megatron_config.timers)
+        update_successful, grad_norm, num_zeros_in_grad = self.optimizer.step(self.megatron_config,
+                                                                              self.megatron_config.timers)

         return DataProto(batch=TensorDict({'loss': output.detach()}, batch_size=output.shape[0]))

@@ -142,11 +141,12 @@ def train_model(self, data: DataProto) -> DataProto:

     resource_pool = RayResourcePool(process_on_nodes=[2], detached=True)
     cls_with_init_args = RayClassWithInitArgs(cls=Trainer)
-    worker_group = NVMegatronRayWorkerGroup(resource_pool=resource_pool,
-                                            ray_cls_with_init=cls_with_init_args,
-                                            name_prefix='trainer',
-                                            detached=True,
-                                            )
+    worker_group = NVMegatronRayWorkerGroup(
+        resource_pool=resource_pool,
+        ray_cls_with_init=cls_with_init_args,
+        name_prefix='trainer',
+        detached=True,
+    )

     worker_group.init_model()

verl/third_party/vllm/vllm_v_0_5_4/hf_weight_loader.py (2 additions, 0 deletions)
@@ -30,6 +30,8 @@ def update_hf_weight_loader():
 def load_hf_weights(actor_weights: Dict, vllm_model: nn.Module):
     assert isinstance(actor_weights, Dict)
     with set_default_torch_dtype(next(vllm_model.parameters()).dtype): # TODO
+        if vllm_model.config.tie_word_embeddings and "lm_head.weight" in actor_weights.keys():
+            del actor_weights["lm_head.weight"]
         vllm_model.load_weights(actor_weights.items())
         for _, module in vllm_model.named_modules():
             quant_method = getattr(module, "quant_method", None)

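This is the bug fix named in the commit title. When a HuggingFace config sets `tie_word_embeddings`, the output projection reuses the input embedding matrix, so a separate `lm_head.weight` entry in the source state dict is redundant and can trip up vLLM's weight loading; the added guard simply drops it before `load_weights` is called. A self-contained sketch of the same idea against a plain state-dict mapping; the function name and signature here are illustrative, not verl's API:

    from typing import Dict

    import torch

    def drop_tied_lm_head(actor_weights: Dict[str, torch.Tensor],
                          tie_word_embeddings: bool) -> Dict[str, torch.Tensor]:
        """Drop 'lm_head.weight' when the output projection is tied to the embeddings."""
        if tie_word_embeddings and "lm_head.weight" in actor_weights:
            actor_weights = dict(actor_weights)  # copy so the caller's dict is untouched
            del actor_weights["lm_head.weight"]
        return actor_weights

    # usage sketch:
    # weights = drop_tied_lm_head(hf_state_dict, hf_config.tie_word_embeddings)
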
verl/trainer/ppo/hybrid_engine/fsdp_vllm.py (2 additions, 2 deletions)
@@ -16,7 +16,7 @@
 import logging
 import torch
 from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
-from torch.distributed.fsdp.api import ShardingStrategy, ShardedStateDictConfig, StateDictType
+from torch.distributed.fsdp.api import ShardingStrategy, ShardedStateDictConfig, StateDictType, FullStateDictConfig

 from verl.third_party.vllm import LLM
 from verl.third_party.vllm import parallel_state as vllm_ps
@@ -42,7 +42,7 @@ def __init__(self, module: FSDP, inference_engine: LLM, model_config, full_param
         if full_params:
             FSDP.set_state_dict_type(self.module,
                                      state_dict_type=StateDictType.FULL_STATE_DICT,
-                                     state_dict_config=ShardedStateDictConfig())
+                                     state_dict_config=FullStateDictConfig())
         else:
             FSDP.set_state_dict_type(self.module,
                                      state_dict_type=StateDictType.SHARDED_STATE_DICT,

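The import and the config change belong together: PyTorch expects `StateDictType.FULL_STATE_DICT` to be paired with a `FullStateDictConfig`, while `ShardedStateDictConfig` is the companion of `SHARDED_STATE_DICT`, so the old code combined the full state-dict type with the wrong config class. A minimal sketch of the corrected pairing, assuming `module` is already FSDP-wrapped:

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.api import FullStateDictConfig, ShardedStateDictConfig, StateDictType

    def configure_state_dict(module: FSDP, full_params: bool) -> None:
        """Match each state-dict type with its corresponding config class."""
        if full_params:
            # Full (unsharded) state dict: parameters are gathered when requested.
            FSDP.set_state_dict_type(module,
                                     state_dict_type=StateDictType.FULL_STATE_DICT,
                                     state_dict_config=FullStateDictConfig())
        else:
            # Sharded state dict: each rank keeps only its own shard.
            FSDP.set_state_dict_type(module,
                                     state_dict_type=StateDictType.SHARDED_STATE_DICT,
                                     state_dict_config=ShardedStateDictConfig())
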
verl/trainer/ppo/rollout/vllm_rollout/vllm_rollout.py (2 additions, 0 deletions)
@@ -72,6 +72,8 @@ def __init__(self, actor_module: nn.Module, config: DictConfig, tokenizer, model
             "disable CUDA graph (enforce_eager = False) if free cache engine"

         tensor_parallel_size = self.config.get('tensor_model_parallel_size', 1)
+        assert tensor_parallel_size <= torch.distributed.get_world_size(), \
+            "tensor parallel size should be less than or equal to the world size"

         if kwargs.get('train_tp', None) is not None:
             # deployed with megatron

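The new assertion fails fast when the configured rollout tensor-parallel size exceeds the number of launched ranks, instead of surfacing later as a harder-to-diagnose error during parallel-group setup. A small sketch of the same sanity check, assuming `torch.distributed` is already initialized and `config` is an OmegaConf `DictConfig`:

    import torch.distributed
    from omegaconf import DictConfig

    def checked_tensor_parallel_size(config: DictConfig) -> int:
        """Return the rollout TP size, verifying it fits in the current world size."""
        tensor_parallel_size = config.get('tensor_model_parallel_size', 1)
        world_size = torch.distributed.get_world_size()
        assert tensor_parallel_size <= world_size, \
            f"tensor parallel size ({tensor_parallel_size}) should be <= world size ({world_size})"
        return tensor_parallel_size
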
