Dgrad ReduceScatter overlap fix (#1088)
* DGRAD-RS overlap bug fix

This PR fixes a bug in enabling DGRAD-RS overlap by adding the
layer to the correct method list. Previously, the RS-DGRAD overlap
layer was incorrectly added to the pipeline method list even when the
ring_exchange method was specified in the config (see the sketch after
the base.py diff below).

Signed-off-by: Vasudevan Rengasamy <[email protected]>

* Bug fix for ring_exchange ReduceScatter

ring_exchange RS runs the last GEMM chunk on the main stream, but the
send/recv streams only wait on stream_compute, so they are not ordered
after that final GEMM (see the sketch after the comm_gemm_overlap.h
diff below).

Signed-off-by: Vasudevan Rengasamy <[email protected]>

---------

Signed-off-by: Vasudevan Rengasamy <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
3 people authored Aug 13, 2024
1 parent b484038 commit ec49a52
Showing 2 changed files with 11 additions and 12 deletions.
19 changes: 8 additions & 11 deletions transformer_engine/pytorch/csrc/comm_gemm_overlap.h
@@ -1205,11 +1205,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
torch::Tensor workspace_chunk =
torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk,
{workspace_size_chunk}, workspace.options());
-      if (i == _tp_size - 1) {
-        at::cuda::setCurrentCUDAStream(stream_main);
-      } else {
-        at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]);
-      }
+      at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]);
te_gemm(A, A_scale_inverse, A_type, transa, input_b_chunk, B_scale_inverse, B_type, transb,
_ubufs[i], D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad,
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, _math_sms);
@@ -1230,6 +1226,13 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
recv_rank, (cudaStream_t)_stream_recv);
}
}
+    at::cuda::setCurrentCUDAStream(stream_main);
+    for (size_t i = 0; i < _stream_compute.size(); i++) {
+      NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i]));
+      NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
+    }
+    NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
+    NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0));

@@ -1248,12 +1251,6 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options());
torch::sum_out(rs_output, reduce_buf, 0);
}
-    for (size_t i = 0; i < _stream_compute.size(); i++) {
-      NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i]));
-      NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
-    }
-    NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
-    NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
_ub_comm->sms = ori_sms;
}

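The change above relies on the standard CUDA cross-stream ordering idiom: record an event on the producer stream and make the consumer stream wait on it before touching the producer's output. Below is a minimal, hypothetical PyTorch sketch of that idiom; the stream, tensor, and loop names are illustrative, not the actual Transformer Engine objects.

import torch

# Hypothetical sketch of the event-based ordering this commit restores.
main_stream = torch.cuda.current_stream()
compute_streams = [torch.cuda.Stream() for _ in range(4)]

x = torch.randn(1024, 1024, device="cuda")
outs = []
for s in compute_streams:
    # Consumer-side ordering: a compute stream must not read x before the
    # main stream has finished producing it.
    s.wait_stream(main_stream)
    with torch.cuda.stream(s):
        outs.append(x @ x)  # stand-in for one GEMM chunk

# Producer-side ordering: before the main stream reduces the chunk outputs,
# it must wait on every side stream. This mirrors the
# cudaEventRecord + cudaStreamWaitEvent pairs that the diff moves ahead of
# the reduction.
for s in compute_streams:
    done = torch.cuda.Event()
    done.record(s)
    main_stream.wait_event(done)

rs_out = torch.stack(outs).sum(dim=0)  # safe: every chunk is ordered before this

Waiting on the wrong producer, for example synchronizing send/recv against stream_compute while the last chunk actually ran on stream_main, leaves exactly the gap the commit message describes; the fix both removes the stream_main special case and adds the waits before the reduction.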
4 changes: 3 additions & 1 deletion transformer_engine/pytorch/module/base.py
@@ -340,7 +340,9 @@ def add_ub(
layers_reduce_scatter_overlap.remove(wgrad_name)
layers_all_gather_overlap.remove(name)
layers_reduce_scatter_overlap.append(name)
methods["pipeline"].append(name)
methods["bulk"].remove(name)
new_method = ub_cfgs[name]["method"]
methods[new_method].append(name)

for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]:
ub_cfg = get_default_config(name)
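The point of the Python change is that a layer's overlap method should come from the user's userbuffers config rather than being hard-coded to pipeline. A simplified, hypothetical sketch of that bookkeeping follows; the dicts mimic the shape of the real ones in base.py but are not the Transformer Engine API, and the layer name is made up.

# Hypothetical, simplified routing; names mirror the diff, not the real API.
methods = {"ring_exchange": [], "pipeline": [], "bulk": ["proj_dgrad"]}
ub_cfgs = {"proj_dgrad": {"method": "ring_exchange"}}  # user-provided config

name = "proj_dgrad"

# Old behaviour: always appended to the pipeline list, ignoring the config:
#     methods["pipeline"].append(name)

# Fixed behaviour: drop the default bulk entry and honour the configured
# method, so a ring_exchange request lands in the ring_exchange list.
methods["bulk"].remove(name)
new_method = ub_cfgs[name]["method"]
methods[new_method].append(name)

assert name in methods["ring_exchange"]

The later loop over methods["ring_exchange"] + methods["pipeline"] + methods["bulk"] then picks the layer up with the overlap method the user actually requested.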
