Zero2: avoid graph breaks in torch.compile by using param_idx (#6803)
Inside reduce_independent_p_g_buckets_and_remove_grads and reduce_ipg_grads,
which are executed during the BWD hook in ZeRO stage 2, the model param was
being stored inside params_in_ipg_bucket. torch.compile has a hard time tracing
parameters. By using the param's static index inside its group, the same logic
can be maintained with less complexity.
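
For illustration, a minimal self-contained sketch of the pattern (hypothetical
names such as add_to_bucket and drain_bucket; not the actual DeepSpeed classes):
the bucket holds plain integer indices instead of Parameter objects, and the
parameter is looked up from its group only when the bucket is drained.

import torch

# Hypothetical sketch of the bucketing pattern, not the real DeepSpeed code.
# bit16_groups mirrors DeepSpeed's list of per-group trainable parameters.
bit16_groups = [[torch.nn.Parameter(torch.randn(4)) for _ in range(3)]]

# Done once at init: record each param's static index inside its group.
for group in bit16_groups:
    for idx, p in enumerate(group):
        p.param_idx_in_group = idx

params_in_ipg_bucket = []

def add_to_bucket(group_idx, param, param_id):
    # Before: the bucket held (group_idx, param, param_id), i.e. the Parameter itself.
    # After: only plain integers go in, which torch.compile traces without breaking.
    params_in_ipg_bucket.append((group_idx, param.param_idx_in_group, param_id))

def drain_bucket():
    for group_idx, param_idx_in_group, param_id in params_in_ipg_bucket:
        # Recover the Parameter from its static indices only where it is needed.
        param = bit16_groups[group_idx][param_idx_in_group]
        # ... reduce / average param's gradient here ...
    params_in_ipg_bucket.clear()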

---------

Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
4 people authored Dec 20, 2024
1 parent 4fd7920 commit 00ea0c4
Showing 2 changed files with 8 additions and 4 deletions.
9 changes: 6 additions & 3 deletions deepspeed/runtime/zero/stage_1_and_2.py
@@ -310,6 +310,7 @@ def __init__(self,
             for param in param_group['params']:
                 if param.requires_grad:
                     param.grad_accum = None
+                    param.param_idx_in_group = len(trainable_parameters)
                     trainable_parameters.append(param)
             self.bit16_groups.append(trainable_parameters)

@@ -961,7 +962,7 @@ def reduce_independent_p_g_buckets_and_remove_grads(self, param, i):
         assert grad_reduc is not None, f"rank {dist.get_rank()} - Invalid to reduce Param {param_id} with None gradient"

         self.grads_in_ipg_bucket.append(grad_reduc)
-        self.params_in_ipg_bucket.append((i, param, param_id))
+        self.params_in_ipg_bucket.append((i, param.param_idx_in_group, param_id))

         #make sure the average tensor function knows how to average the gradients
         if is_moe_param(param):
@@ -1067,7 +1068,8 @@ def average_tensor(self, tensor):

             process_group = self.dp_process_group
             # count = 0
-            for i, param, param_id in self.params_in_ipg_bucket:
+            for i, param_idx_in_group, param_id in self.params_in_ipg_bucket:
+                param = self.bit16_groups[i][param_idx_in_group]

                 process_group = self.dp_process_group

@@ -1383,7 +1385,8 @@ def reduce_ipg_grads(self):
             stream = get_accelerator().current_stream()

         with get_accelerator().stream(stream):
-            for _, param, param_id in self.params_in_ipg_bucket:
+            for group_idx, param_idx_in_group, param_id in self.params_in_ipg_bucket:
+                param = self.bit16_groups[group_idx][param_idx_in_group]

                 assert self.params_already_reduced[param_id] == False, \
                     f"The parameter {param_id} has already been reduced. \
3 changes: 2 additions & 1 deletion tests/unit/moe/test_moe.py
@@ -93,7 +93,8 @@ def strict_average_tensor(tensor):
             process_group = optimizer.dp_process_group
             curr_size = 0
             pg_offsets = []
-            for i, param, param_id in optimizer.params_in_ipg_bucket:
+            for i, param_idx, param_id in optimizer.params_in_ipg_bucket:
+                param = optimizer.bit16_groups[i][param_idx]
                 process_group = optimizer.dp_process_group
                 if optimizer.ipg_bucket_has_moe_params:
                     process_group = optimizer.expert_dp_process_group[param.group_name] if is_moe_param(
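As a side note, a hedged sketch of how one might count graph breaks when
compiling a training step; this is not part of the commit, and it assumes a
recent PyTorch where torch._dynamo.explain and its ExplainOutput fields are
available. The toy module is only a placeholder for a real ZeRO-2 training step.

import torch
import torch._dynamo as dynamo

# Hypothetical, generic check (not part of this commit): count graph breaks
# for a callable that would be compiled. In a real setup the callable would
# be a ZeRO-2 training step; here a toy module stands in for it.
model = torch.nn.Linear(8, 8)
example_input = torch.randn(2, 8)

explanation = dynamo.explain(model)(example_input)
print(f"graphs: {explanation.graph_count}, breaks: {explanation.graph_break_count}")
for reason in explanation.break_reasons:
    print(reason)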
