Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[CPU] Allow deepspeed.comm.inference_all_reduce in torch.compile graph (
#5604) This PR allows `deepspeed.comm.inference_all_reduce()` enters torch.compile graph even it is implemented as C++ kernel in DeepSpeed. Previous implementation register `inference_all_reduce()` C++ kernel as pybind function so it can be called inside PyThon code. However pybind function cannot be recognized by PyTorch so graph breaks when `inference_all_reduce` is called. We address issue by register `inference_all_reduce` as a PyTorch custom op `torch.ops.deepspeed.inference_all_reduce`, so it can be built into PyTorch graph The output trace code from torchinductor ``` class GraphModule(torch.nn.Module): def forward(self, primals_1: "f32[5, 4]", primals_2: "f32[5]", primals_3: "f32[4, 4]"): # File: /home/gma/DeepSpeed/deepspeed/comm/torch.py:161 in inference_all_reduce, code: return torch.ops.deepspeed.inference_all_reduce_(tensor) inference_all_reduce: "f32[4, 4]" = torch.ops.deepspeed.inference_all_reduce.default(primals_3) # File: /home/gma/allreduce_graph/test_allreduce.py:33 in forward, code: return self.linear(input) permute: "f32[4, 5]" = torch.ops.aten.permute.default(primals_1, [1, 0]); primals_1 = None addmm: "f32[4, 5]" = torch.ops.aten.addmm.default(primals_2, inference_all_reduce, permute); primals_2 = permute = None # No stacktrace found for following nodes copy_: "f32[4, 4]" = torch.ops.aten.copy_.default(primals_3, inference_all_reduce); primals_3 = None return [addmm, inference_all_reduce] ``` Note in this PR the inference_all_reduce op for CPU does not handle multinode and FP16 data type. For FP16 data type support, we will align with PyTorch CPU FP16 plan. For multinode, we are still looking at the possibility to upstream oneCCL integration into PyTorch, so we are able to get use of oneCCL for multinode tensor parallel inference with PyTorch. This PR is independent to #5571. They can work seperately or together without issue. --------- Co-authored-by: Olatunji Ruwase <[email protected]> Co-authored-by: Masahiro Tanaka <[email protected]>
- Loading branch information