
[NVIDIA GPU] Use memcpy for intra-node all-to-all #15144

Closed

Conversation

terryysun
Contributor

All-to-all communication relies on NCCL even when it is intra-node. By using memcpy for intra-node communication, all-to-all can achieve better performance while reducing the SM consumption currently incurred by NCCL.
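For intuition only, here is a minimal sketch of the idea, assuming all participating ranks sit on GPUs of the same node and can reach each other via peer-to-peer copies. The rank count, chunk layout, and helper names are illustrative and are not the code in this PR: each rank copies its per-peer chunks straight into the peers' receive buffers with device-to-device memcpy, so no SM-resident NCCL kernel is needed.

```cpp
// Minimal sketch (not the PR's code): intra-node all-to-all driven by
// device-to-device copies instead of NCCL. Assumes one GPU per rank in a
// single process; kRanks and kChunkBytes are illustrative.
#include <cuda_runtime.h>

#include <cstdio>
#include <cstdlib>
#include <vector>

#define CHECK_CUDA(expr)                                           \
  do {                                                             \
    cudaError_t err = (expr);                                      \
    if (err != cudaSuccess) {                                      \
      std::fprintf(stderr, "CUDA error %s at %s:%d\n",             \
                   cudaGetErrorString(err), __FILE__, __LINE__);   \
      std::abort();                                                \
    }                                                              \
  } while (0)

constexpr int kRanks = 4;                // GPUs on this node (assumption).
constexpr size_t kChunkBytes = 1 << 20;  // Bytes each rank sends to each peer.

// send[r] and recv[r] are device buffers of kRanks * kChunkBytes on GPU r.
// In the real thunk each rank is a separate process, so receive-buffer
// addresses first have to be exchanged across ranks (the recv_ptr_map_ in
// the diff); here a single process owns every GPU, so all pointers are known.
void MemcpyAllToAll(const std::vector<char*>& send,
                    const std::vector<char*>& recv,
                    const std::vector<cudaStream_t>& streams) {
  for (int src = 0; src < kRanks; ++src) {
    CHECK_CUDA(cudaSetDevice(src));
    for (int dst = 0; dst < kRanks; ++dst) {
      // Chunk `dst` of src's send buffer lands in slot `src` of dst's
      // receive buffer; cudaMemcpyPeerAsync typically uses the copy engines
      // rather than compute SMs.
      CHECK_CUDA(cudaMemcpyPeerAsync(recv[dst] + src * kChunkBytes, dst,
                                     send[src] + dst * kChunkBytes, src,
                                     kChunkBytes, streams[src]));
    }
  }
  // Callers still need a cross-rank synchronization before reading recv.
}
```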

@terryysun
Contributor Author

terryysun commented Jul 19, 2024

Some screenshots showing the performance difference between the NCCL call and memcpy for an example all-to-all:
Screenshot 2024-07-19 at 2 24 44 PM
Screenshot 2024-07-19 at 2 24 05 PM

@NaiyerRizz NaiyerRizz self-assigned this Jul 22, 2024
@thomasjoerg thomasjoerg requested review from frgossen and removed request for thomasjoerg July 22, 2024 08:29
Member

@frgossen frgossen left a comment


Very nice!

You are adding this for all-gather, all-reduce, all-to-all, collective-broadcast, and collective-permute, but I only see a test case for all-to-all.
Can we add test cases for the remaining collectives? I would also suggest splitting this up into one PR per collective, to keep reviews simpler and to make it easier to roll something back.

@terryysun
Contributor Author

Very nice!

You are adding this for all-gather, all-reduce, all-to-all, collective-broadcast, and collective-permute, but I only see a test case for all-to-all. Can we add test cases for the remaining collectives? I would also suggest splitting this up into one PR per collective, to keep reviews simpler and to make it easier to roll something back.

Thanks! Right now it's only added for all-to-all; the API changes for the other collectives are just to avoid interim intricacies. We do plan to add similar support for them in the near future, but the flag is not consumed for them in this PR.

@terryysun terryysun requested a review from frgossen July 23, 2024 21:31
Member

@frgossen frgossen left a comment


Thanks for clarifying. All makes sense to me, just a few minor comments.

@@ -71,6 +71,21 @@ struct NcclCollectiveConfig {
bool IsDegenerate(int64_t replica_count, int64_t partition_count) const;
};

template <typename T>
absl::StatusOr<const int64_t> GetCurrentId(
Member


Rather than templating, would it make sense to implement this based on the NcclCollectiveConfig?

Contributor Author


After reorganizing the code, it turns out we don't need this util; removed.


absl::Status PutRecvPtr(int64_t send_id, int64_t recv_id, void* ptr) {
if (!IsInitialized(send_id, recv_id)) {
return absl::InternalError(absl::StrCat("Send-receive pair ", send_id, ", ", recv_id,
Member


nit: recv?

Contributor Author


Updated to receive.

   return xla::gpu::RunAllToAll(nccl_api(), config_.has_split_dimension,
                                device_buffers, stream,
-                               comm_wrapper.comm_handle);
+                               comm_wrapper.comm_handle, current_id, use_memcpy, recv_ptr_map_);
}

absl::Status RunAllToAll(NcclApi* nccl_api, bool has_split_dimension,
Member


This is implemented in two modes which do not share much code. Can we outline that into two functions and dispatch here?

Contributor Author


Broke it into two functions.
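For readers skimming the thread, a rough sketch of the resulting shape. The buffer type is a simplified stand-in and both helpers have placeholder bodies; only the function names mirror the ones discussed here, and none of the signatures are XLA's actual declarations.

```cpp
// Hedged sketch of the two-mode dispatch; DeviceBufferPair is a simplified
// stand-in and both helpers are placeholders, not the PR's implementation.
#include <cstddef>
#include <vector>

#include "absl/status/status.h"

struct DeviceBufferPair {  // stand-in for the real buffer description
  void* send_ptr;
  void* recv_ptr;
  size_t bytes;
};

// Memcpy path: device-to-device copies between intra-node peers.
absl::Status RunMemCpyAllToAll(const std::vector<DeviceBufferPair>& buffers) {
  // ... issue peer copies as in the sketch near the top of the thread ...
  return absl::OkStatus();
}

// NCCL path: the existing grouped ncclSend/ncclRecv implementation.
absl::Status RunNcclAllToAll(const std::vector<DeviceBufferPair>& buffers) {
  // ... existing NCCL-based implementation ...
  return absl::OkStatus();
}

// Single entry point; the two modes stay in separate functions because they
// share almost no code, and only the dispatch lives here.
absl::Status RunAllToAll(bool use_memcpy,
                         const std::vector<DeviceBufferPair>& buffers) {
  return use_memcpy ? RunMemCpyAllToAll(buffers) : RunNcclAllToAll(buffers);
}
```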

@sgerrard sgerrard requested a review from frgossen August 12, 2024 22:37
@terryysun terryysun marked this pull request as draft August 13, 2024 17:25
@frgossen
Member

frgossen commented Sep 9, 2024

I see you pushed changes and requested a review. Can you reply to the comments and explain how the changes address them? Ty.

@terryysun terryysun marked this pull request as ready for review September 18, 2024 01:06
@terryysun
Contributor Author

I see you pushed changes and requested a review. Can you reply to the comments and explain how the changes address them? Ty.

Hey @frgossen, sorry for the delayed reply. We were verifying the changes and fixed multiple issues we saw when running realistic models. I've updated the code and replied to the comments accordingly. Could you take another look? Thanks!

Member

@frgossen frgossen left a comment


Thanks!

copybara-service bot pushed a commit that referenced this pull request Sep 24, 2024
Imported from GitHub PR #15144

All-to-all communication relies on NCCL even when it is intra-node. By using memcpy for intra-node communication, all-to-all can achieve better performance while reducing the SM consumption currently incurred by NCCL.
Copybara import of the project:

--
38720c7 by Terry Sun <[email protected]>:

memcpyp2p for local a2a

--
90018f4 by Terry Sun <[email protected]>:

use nccl to pass recv ptrs

--
f9b75b0 by Terry Sun <[email protected]>:

refactor and cleanup

Merging this change closes #15144

FUTURE_COPYBARA_INTEGRATE_REVIEW=#15144 from terryysun:terryysun/all2all_memcpyp2p f9b75b0
PiperOrigin-RevId: 678378925
copybara-service bot pushed a commit that referenced this pull request Oct 1, 2024
…dress sharing

Imported from GitHub PR #17636

This is a follow-up PR to #15144. A distributed cache is maintained when device addresses are shared across ranks. There are two issues with the existing implementation:

1. The cache is not guarded by a mutex;
2. The cache initialization process has redundant accesses.

These issues can cause race conditions or deadlocks when the progress on different ranks is very close. Consequently, we introduce the following enhancements:

1. Guard the cache with a mutex;
2. Shard the initialization process by rank, so that each rank only handles a piece of the cache and, in theory, there is no overlapping access.

Copybara import of the project:

--
a6472fc by Terry Sun <[email protected]>:

enhance concurrency handling

--
356ab82 by Terry Sun <[email protected]>:

lock mutex

--
29ebb2d by Terry Sun <[email protected]>:

bring back test

--
91b911f by Terry Sun <[email protected]>:

better lock granularity

--
cc1d93a by Terry Sun <[email protected]>:

guard all accesses

Merging this change closes #17636

COPYBARA_INTEGRATE_REVIEW=#17636 from terryysun:terryysun/sync_fix cc1d93a
PiperOrigin-RevId: 681120763
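To make the fix described in the commit message above concrete, here is a minimal sketch of a mutex-guarded receive-pointer cache whose initialization is sharded by rank, so no two ranks create or fill the same entries. The class layout, key type, and Abseil usage are assumptions for illustration; this is not the code merged in #17636.

```cpp
// Hedged sketch: a recv-pointer cache guarded by a mutex, with per-rank
// sharded initialization. Names and key layout are illustrative.
#include <cstdint>
#include <utility>

#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/synchronization/mutex.h"

class RecvPtrMap {
 public:
  // Sharded initialization: each rank pre-creates only the entries it will
  // later fill (keys whose send_id equals its own rank), so ranks never
  // touch each other's slice of the cache while initializing.
  void InitializeForRank(int64_t my_rank, int64_t num_ranks) {
    absl::MutexLock lock(&mu_);
    for (int64_t recv_id = 0; recv_id < num_ranks; ++recv_id) {
      ptrs_.try_emplace(std::make_pair(my_rank, recv_id), nullptr);
    }
  }

  absl::Status PutRecvPtr(int64_t send_id, int64_t recv_id, void* ptr) {
    absl::MutexLock lock(&mu_);
    auto it = ptrs_.find(std::make_pair(send_id, recv_id));
    if (it == ptrs_.end()) {
      return absl::InternalError(absl::StrCat("Send-receive pair ", send_id,
                                              ", ", recv_id,
                                              " is not initialized."));
    }
    it->second = ptr;
    return absl::OkStatus();
  }

  absl::StatusOr<void*> GetRecvPtr(int64_t send_id, int64_t recv_id) {
    absl::MutexLock lock(&mu_);
    auto it = ptrs_.find(std::make_pair(send_id, recv_id));
    if (it == ptrs_.end() || it->second == nullptr) {
      return absl::InternalError(absl::StrCat("Send-receive pair ", send_id,
                                              ", ", recv_id,
                                              " is not available."));
    }
    return it->second;
  }

 private:
  absl::Mutex mu_;
  absl::flat_hash_map<std::pair<int64_t, int64_t>, void*> ptrs_
      ABSL_GUARDED_BY(mu_);
};
```

In practice a consumer may still have to wait until the producer rank has published its pointer; how that wait is implemented (condition variable, retry loop, or stream-side synchronization) is exactly where the deadlock risk described in the commit message comes from.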
@loislo
Member

loislo commented Oct 2, 2024

From time to time, the test xla/tests/collective_ops_e2e_test.cc:486 xla::(anonymous namespace)::CollectiveOpsTestE2E_AsyncAllToAllMemCpy_Test::TestBody() fails because a thread gets stuck.

This is the stacktrace for this thread.

E1002 05:16:35.729531 5068 rendezvous.cc:43] Stack for thread tf_replicas/5068:
@ clock_gettime
@ third_party_gpus_cuda_compat_cuda_compat_data_o_tmpdir_filewrapper_s0_start
@ third_party_gpus_cuda_compat_cuda_compat_data_o_tmpdir_filewrapper_s0_start
@ third_party_gpus_cuda_compat_cuda_compat_data_o_tmpdir_filewrapper_s0_start
@ third_party_gpus_cuda_compat_cuda_compat_data_o_tmpdir_filewrapper_s0_start
@ third_party_gpus_cuda_compat_cuda_compat_data_o_tmpdir_filewrapper_s0_start
@ xla/stream_executor/cuda/cuda_driver.cc:1358 stream_executor::gpu::GpuDriver::SynchronizeStream()
@ xla/stream_executor/cuda/cuda_executor.cc:528 stream_executor::gpu::CudaExecutor::BlockHostUntilDone()
@ xla/stream_executor/stream_common.cc:148 stream_executor::StreamCommon::BlockHostUntilDone()
@ xla/service/gpu/runtime/nccl_all_to_all_thunk.cc:313 xla::gpu::RunMemCpyAllToAll()
@ xla/service/gpu/runtime/nccl_all_to_all_thunk.cc:189 xla::gpu::NcclAllToAllStartThunk::RunNcclCollective()
@ xla/service/gpu/runtime/nccl_collective_thunk.cc:447 xla::gpu::NcclCollectiveThunk::ExecuteOnStream()
@ xla/service/gpu/runtime/sequential_thunk.cc:81 xla::gpu::SequentialThunk::ExecuteOnStream()
@ xla/service/gpu/gpu_executable.cc:481 xla::gpu::(anonymous namespace)::ExecuteThunks()
@ xla/service/gpu/gpu_executable.cc:1011 xla::gpu::GpuExecutable::ExecuteAsyncOnStreamImpl()
@ xla/service/gpu/gpu_executable.cc:798 xla::gpu::GpuExecutable::ExecuteAsyncOnStream()
@ xla/service/executable.cc:84 xla::Executable::ExecuteOnStream()
@ xla/service/hlo_runner.cc:607 xla::HloRunner::ExecuteReplicated()::$_0::operator()()::{lambda()#2}::operator()()
@ Thread::ThreadBody()

We are waiting for this thread from the other one, which triggers the watchdog code.
@ xla/service/rendezvous.cc:83 xla::internal::AwaitAndLogIfStuck()
@ ./xla/service/rendezvous.h:307 xla::RendezvousSingle<>()
@ ./xla/service/rendezvous.h:336 xla::RendezvousSingle<>()
@ ./xla/service/rendezvous.h:361 xla::RendezvousSingle<>()
@ xla/service/gpu/runtime/nccl_collective_thunk.cc:486 xla::gpu::NcclCollectiveThunk::ExecuteOnStream()
@ xla/service/gpu/runtime/sequential_thunk.cc:81 xla::gpu::SequentialThunk::ExecuteOnStream()
@ xla/service/gpu/gpu_executable.cc:481 xla::gpu::(anonymous namespace)::ExecuteThunks()
@ xla/service/gpu/gpu_executable.cc:1011 xla::gpu::GpuExecutable::ExecuteAsyncOnStreamImpl()
@ xla/service/gpu/gpu_executable.cc:798 xla::gpu::GpuExecutable::ExecuteAsyncOnStream()
@ xla/service/executable.cc:84 xla::Executable::ExecuteOnStream()
@ xla/service/hlo_runner.cc:607 xla::HloRunner::ExecuteReplicated()::$_0::operator()()::{lambda()#2}::operator()()

@loislo
Member

loislo commented Oct 2, 2024

@terryysun it is failing on a machine with P100 GPUs.

@loislo loislo reopened this Oct 2, 2024
@terryysun
Contributor Author

terryysun commented Oct 2, 2024

@terryysun it is failing on a machine with P100 GPUs.

Thanks! I will file a follow-up PR to resolve the P100 failures and make the flag default. Right now it's an optional flag, so it should not affect anything; it's verified on H100s. We are working on reproducing the P100 failure on our side.
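The root cause of the P100 hang was not established in this thread. Purely as an illustration of how one could gate the fast path on hardware capability before defaulting the flag (an assumption, not the actual follow-up fix), a check like the following could keep the NCCL path whenever any local device pair lacks peer-to-peer access:

```cpp
// Hedged illustration only: check whether every pair of visible devices
// supports peer-to-peer access; if not, prefer the NCCL all-to-all path.
#include <cuda_runtime.h>

bool AllLocalPairsSupportP2P() {
  int num_devices = 0;
  if (cudaGetDeviceCount(&num_devices) != cudaSuccess) return false;
  for (int a = 0; a < num_devices; ++a) {
    for (int b = 0; b < num_devices; ++b) {
      if (a == b) continue;
      int can_access = 0;
      if (cudaDeviceCanAccessPeer(&can_access, a, b) != cudaSuccess ||
          can_access == 0) {
        return false;  // At least one pair lacks P2P; fall back to NCCL.
      }
    }
  }
  return true;
}
```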

@terryysun terryysun closed this Oct 10, 2024