Commit

[ci] lint
PeterSH6 committed Dec 1, 2024
1 parent 464dcb0 commit c610ad1
Showing 2 changed files with 23 additions and 22 deletions.
3 changes: 2 additions & 1 deletion tests/ray/detached_worker/client.py
@@ -37,7 +37,8 @@ def compute_position_id_with_mask(mask):
     # get the worker group using names
     worker_names = ['trainerTrainer_0:0', 'trainerTrainer_0:1']
     cls_with_init_args = RayClassWithInitArgs(cls=Trainer)
-    worker_group = NVMegatronRayWorkerGroup.from_detached(worker_names=worker_names, ray_cls_with_init=cls_with_init_args)
+    worker_group = NVMegatronRayWorkerGroup.from_detached(worker_names=worker_names,
+                                                          ray_cls_with_init=cls_with_init_args)
 
     batch_size = 16
     sequence_length = 1024
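For orientation, here is a hedged sketch of how a client attaches to these detached workers and pushes one batch through them. It is assembled only from the calls visible in this commit; the import paths, the `from server import Trainer` line, and the `train_model` dispatch are assumptions rather than a verbatim copy of client.py, while the batch keys and sizes mirror the diffs above and in server.py below.

import torch
from tensordict import TensorDict

from verl import DataProto  # assumed import path
from verl.single_controller.ray import RayClassWithInitArgs  # assumed import path
from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup  # assumed import path

from server import Trainer  # worker class defined in server.py (assumed module name)

# Names of the already-running detached workers, as used in client.py above.
worker_names = ['trainerTrainer_0:0', 'trainerTrainer_0:1']
cls_with_init_args = RayClassWithInitArgs(cls=Trainer)

# Re-attach to the detached worker group by name instead of spawning new actors.
worker_group = NVMegatronRayWorkerGroup.from_detached(worker_names=worker_names,
                                                      ray_cls_with_init=cls_with_init_args)

# Build a dummy batch; batch_size and sequence_length match the values in this test,
# and vocab_size=256 matches the LlamaConfig in server.py.
batch_size, sequence_length = 16, 1024
input_ids = torch.randint(0, 256, (batch_size, sequence_length))
attention_mask = torch.ones_like(input_ids)
position_ids = torch.arange(sequence_length).unsqueeze(0).expand(batch_size, -1)

data = DataProto(batch=TensorDict({
    'input_ids': input_ids,
    'attention_mask': attention_mask,
    'position_ids': position_ids,
}, batch_size=batch_size))

# train_model is assumed to be registered on Trainer, so the worker group
# dispatches the call to every detached worker and returns the collected losses.
result = worker_group.train_model(data)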
42 changes: 21 additions & 21 deletions tests/ray/detached_worker/server.py
@@ -72,11 +72,11 @@ def __init__(self):
     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
     def init_model(self):
         actor_model_config = LlamaConfig(vocab_size=256,
-                                         hidden_size=2048,
-                                         intermediate_size=5504,
-                                         num_hidden_layers=24,
-                                         num_attention_heads=16,
-                                         num_key_value_heads=16)
+                                         hidden_size=2048,
+                                         intermediate_size=5504,
+                                         num_hidden_layers=24,
+                                         num_attention_heads=16,
+                                         num_key_value_heads=16)
 
         megatron_config = OmegaConf.create({
             'sequence_parallel_enabled': True,
@@ -96,21 +96,18 @@ def megatron_actor_model_provider(pre_process, post_process):
             # this_megatron_config = copy.deepcopy(megatron_config)
             # this_megatron_config.virtual_pipeline_model_parallel_rank = vpp_rank
             parallel_model = ParallelLlamaForCausalLMRmPadPP(config=actor_model_config,
-                                                             megatron_config=megatron_config,
-                                                             pre_process=pre_process,
-                                                             post_process=post_process)
+                                                             megatron_config=megatron_config,
+                                                             pre_process=pre_process,
+                                                             post_process=post_process)
             parallel_model.cuda()
             return parallel_model
 
-        actor_module = get_model(model_provider_func=megatron_actor_model_provider,
+        actor_module = get_model(model_provider_func=megatron_actor_model_provider,
                                  model_type=ModelType.encoder_or_decoder,
                                  wrap_with_ddp=True)
         actor_module = nn.ModuleList(actor_module)
 
-        optim_config = OmegaConf.create({
-            'lr': 1e-6,
-            'clip_grad': 1.0
-        })
+        optim_config = OmegaConf.create({'lr': 1e-6, 'clip_grad': 1.0})
 
         optim_config = init_megatron_optim_config(optim_config)
         self.optimizer_config = optim_config
@@ -126,13 +123,15 @@ def train_model(self, data: DataProto) -> DataProto:
         position_ids = data.batch['position_ids']
 
         self.optimizer.zero_grad()
-        self.model.zero_grad_buffer(zero_buffer=(not self.optimizer_config.use_distributed_optimizer)) # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm
+        self.model.zero_grad_buffer(
+            zero_buffer=(not self.optimizer_config.use_distributed_optimizer
+                        )) # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm
         # update for 1 iteration
         output = self.model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids).logits
         output.mean().backward()
 
-        update_successful, grad_norm, num_zeros_in_grad = self.optimizer.step(
-            self.megatron_config, self.megatron_config.timers)
+        update_successful, grad_norm, num_zeros_in_grad = self.optimizer.step(self.megatron_config,
+                                                                              self.megatron_config.timers)
 
         return DataProto(batch=TensorDict({'loss': output.detach()}, batch_size=output.shape[0]))
 
@@ -142,11 +141,12 @@ def train_model(self, data: DataProto) -> DataProto:
 
     resource_pool = RayResourcePool(process_on_nodes=[2], detached=True)
     cls_with_init_args = RayClassWithInitArgs(cls=Trainer)
-    worker_group = NVMegatronRayWorkerGroup(resource_pool=resource_pool,
-                                            ray_cls_with_init=cls_with_init_args,
-                                            name_prefix='trainer',
-                                            detached=True,
-                                            )
+    worker_group = NVMegatronRayWorkerGroup(
+        resource_pool=resource_pool,
+        ray_cls_with_init=cls_with_init_args,
+        name_prefix='trainer',
+        detached=True,
+    )
 
     worker_group.init_model()
 
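For reference, the server side of this test reduces to the flow below: start Ray, create a detached resource pool, spawn the named worker group, and initialize the model on every worker so a later client can re-attach by name. This is a condensed sketch under stated assumptions, not server.py itself: the `ray.init()` call, the main guard, and the import paths are assumptions, and `Trainer` refers to the worker class defined in the diff above.

import ray

from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool  # assumed import path
from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup  # assumed import path

if __name__ == '__main__':
    ray.init()  # assumed; the detached pool keeps the workers alive after this script exits

    # One node with 2 worker processes, marked detached so they outlive the driver.
    resource_pool = RayResourcePool(process_on_nodes=[2], detached=True)
    cls_with_init_args = RayClassWithInitArgs(cls=Trainer)  # Trainer: worker class from server.py above

    worker_group = NVMegatronRayWorkerGroup(
        resource_pool=resource_pool,
        ray_cls_with_init=cls_with_init_args,
        name_prefix='trainer',  # produces detached actor names like 'trainerTrainer_0:0'
        detached=True,
    )

    # Dispatched ONE_TO_ALL: every worker builds its Megatron model shard and optimizer.
    worker_group.init_model()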
