Do not aggregate the losses since last log step #779
base: main
Conversation
utils.dist_mean(loss, world_mesh["dp_cp"]),
utils.dist_max(loss, world_mesh["dp_cp"]),
These use functional collectives under the hood, so there shouldn't be any issue with passing in the same tensor reference to both calls.
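For context, here is a minimal sketch of what such helpers might look like on top of PyTorch's functional collectives (an approximation, not the exact torchtitan code): the collective is out-of-place and returns a fresh tensor, so reducing the same `loss` reference twice is safe.

```python
# Sketch only: approximates helpers like utils.dist_mean / utils.dist_max.
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
from torch.distributed.device_mesh import DeviceMesh


def dist_mean(x: torch.Tensor, mesh: DeviceMesh) -> float:
    # funcol.all_reduce is out-of-place: it returns a new tensor and leaves `x` untouched.
    return funcol.all_reduce(x, reduceOp=dist.ReduceOp.AVG.name, group=mesh).item()


def dist_max(x: torch.Tensor, mesh: DeviceMesh) -> float:
    # Because the input is not mutated, the same tensor can be passed to both helpers.
    return funcol.all_reduce(x, reduceOp=dist.ReduceOp.MAX.name, group=mesh).item()
```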
@@ -34,12 +34,12 @@ def get_device_info():
 device_type, device_module = get_device_info()


-def dist_max(x: Union[int, float], mesh: DeviceMesh) -> float:
+def dist_max(x: Union[int, float, torch.Tensor], mesh: DeviceMesh) -> float:
     tensor = torch.tensor(x).to(device_type)
AFAIK these will result in no-ops if a tensor with the same device type is passed.
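A quick runnable illustration of the `.to()` part of that claim (CPU-only here for simplicity; the helper in the PR targets `device_type`): `Tensor.to` hands back the original object when the tensor already has the requested device and dtype.

```python
import torch

x = torch.randn(4)      # already a CPU tensor in this toy example
y = x.to("cpu")         # same device and dtype, so no copy is made
assert y is x           # .to() returned the original tensor object
```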
utils.dist_mean(loss, world_mesh["dp_cp"]),
utils.dist_max(loss, world_mesh["dp_cp"]),
It seems that with Tensor Parallel the `loss` is a DTensor, which doesn't support functional collectives. Also, we should not require gradients on this all-reduce. Maybe it's still fine to do `.item()` outside as before? Or use detach and `full_tensor()`?
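A hedged sketch of that suggestion (the helper name `_loss_for_logging` is made up here, and the DTensor import path varies by PyTorch version): detach the loss so the logging all-reduce carries no gradients, and gather a DTensor into a plain tensor before handing it to the functional collectives.

```python
import torch
from torch.distributed._tensor import DTensor  # torch.distributed.tensor in newer releases


def _loss_for_logging(loss: torch.Tensor) -> torch.Tensor:
    """Prepare a loss value for the logging all-reduce (sketch, not the PR's code)."""
    # The logging reduction should not require gradients.
    loss = loss.detach()
    # Under Tensor Parallel the loss can be a DTensor, which the functional
    # collectives used by dist_mean/dist_max don't accept; gather it first.
    if isinstance(loss, DTensor):
        loss = loss.full_tensor()
    return loss
```

At the call site above this would look like `utils.dist_mean(_loss_for_logging(loss), world_mesh["dp_cp"])`, or the `.item()` could simply stay outside the helpers as before.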
Fixes #763