[CPU] Allow deepspeed.comm.inference_all_reduce in torch.compile graph (#5604)

This PR allows `deepspeed.comm.inference_all_reduce()` to enter the
torch.compile graph even though it is implemented as a C++ kernel in DeepSpeed.

The previous implementation registered the `inference_all_reduce()` C++ kernel
as a pybind function so that it could be called from Python code. However,
pybind functions are not recognized by PyTorch's compiler, so the graph breaks
when `inference_all_reduce` is called.
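
As an illustration (not part of this PR), here is a minimal sketch of how such a break can be observed with `torch._dynamo.explain`; the call to `torch._dynamo.graph_break()` merely stands in for an untraceable pybind call:

```
import torch


def opaque_allreduce(x):
    # Stand-in for a pybind11 extension call that Dynamo cannot inspect.
    torch._dynamo.graph_break()
    return x


def fn(x):
    x = opaque_allreduce(x)
    return x + 1


# ExplainOutput reports at least one break caused by the opaque call.
explanation = torch._dynamo.explain(fn)(torch.randn(4, 4))
print(explanation.graph_break_count)
```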

We address this issue by registering `inference_all_reduce` as a PyTorch
custom op, `torch.ops.deepspeed.inference_all_reduce`, so that it can be built
into the PyTorch graph.
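
For comparison, here is a rough Python-level analogue of the C++ `TORCH_LIBRARY` registration shown in the diff below; the namespace `mylib`, the op name, and the dummy kernels are purely illustrative. A custom op with CPU and Meta implementations can be captured by torch.compile without a graph break:

```
import torch

lib = torch.library.Library("mylib", "DEF")
lib.define("fake_all_reduce(Tensor self) -> Tensor")


def fake_all_reduce_cpu(x):
    # Stand-in for the opaque C++ kernel (e.g. a shared-memory all-reduce).
    return x.clone()


def fake_all_reduce_meta(x):
    # Meta kernel: shape/dtype propagation only, so the op can be traced.
    return torch.empty_like(x)


lib.impl("fake_all_reduce", fake_all_reduce_cpu, "CPU")
lib.impl("fake_all_reduce", fake_all_reduce_meta, "Meta")


@torch.compile(fullgraph=True)  # fullgraph=True errors out on any graph break
def fn(x):
    return torch.ops.mylib.fake_all_reduce(x) + 1


print(fn(torch.randn(4, 4)))  # captured as a single graph, no break
```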

The trace code emitted by TorchInductor:
```
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[5, 4]", primals_2: "f32[5]", primals_3: "f32[4, 4]"):
        # File: /home/gma/DeepSpeed/deepspeed/comm/torch.py:161 in inference_all_reduce, code: return torch.ops.deepspeed.inference_all_reduce_(tensor)
        inference_all_reduce: "f32[4, 4]" = torch.ops.deepspeed.inference_all_reduce.default(primals_3)

        # File: /home/gma/allreduce_graph/test_allreduce.py:33 in forward, code: return self.linear(input)
        permute: "f32[4, 5]" = torch.ops.aten.permute.default(primals_1, [1, 0]);  primals_1 = None
        addmm: "f32[4, 5]" = torch.ops.aten.addmm.default(primals_2, inference_all_reduce, permute);  primals_2 = permute = None

        # No stacktrace found for following nodes
        copy_: "f32[4, 4]" = torch.ops.aten.copy_.default(primals_3, inference_all_reduce);  primals_3 = None
        return [addmm, inference_all_reduce]
```
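
The `test_allreduce.py` referenced in the trace is not included in this commit; the following is a minimal sketch of what such a script might look like, assuming a multi-rank CPU launch (e.g. via the `deepspeed` launcher) with the shared-memory comm backend built so that `inference_all_reduce` hits the custom op path:

```
import torch
import deepspeed
import deepspeed.comm as dist


class AllReduceLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, input):
        # All-reduce the activations in place across ranks, then apply the linear layer.
        dist.inference_all_reduce(input)
        return self.linear(input)


deepspeed.init_distributed()
model = torch.compile(AllReduceLinear())
out = model(torch.randn(4, 4))
```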

Note that in this PR the CPU `inference_all_reduce` op does not handle
multi-node setups or the FP16 data type. For FP16 support, we will align with
the PyTorch CPU FP16 plan. For multi-node support, we are still investigating
the possibility of upstreaming oneCCL integration into PyTorch, so that oneCCL
can be used for multi-node tensor-parallel inference with PyTorch.

This PR is independent of
#5571. They can work
separately or together without issue.

---------

Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Masahiro Tanaka <[email protected]>
3 people authored Jul 15, 2024
1 parent a07a3c5 commit ec6cbb3
Showing 2 changed files with 85 additions and 19 deletions.
99 changes: 82 additions & 17 deletions csrc/cpu/comm/shm_interface.cpp
@@ -46,15 +46,13 @@ void initialize(int size, int rank)
if (all_ranks_local_p) { shm_initialize(size, rank, addr_string, port_string); }
}

int get_rank(int group = 0) { return world_rank; }

int get_world_size(int group = 0) { return world_size; }
void inference_all_reduce_(torch::Tensor& data, int op);

// Success - return 0
// Fail (cannot honor the request and need to fall back) - return -1
int inference_all_reduce(torch::Tensor& data, py::object op)
void inference_all_reduce_(torch::Tensor& data, int op)
{
if (!all_ranks_local_p) return -1;
assert(op == 0);
#ifdef DO_PROFILE
static double total_time = 0.0;
static double total_time_sq = 0.0;
@@ -67,11 +65,6 @@ int inference_all_reduce(torch::Tensor& data, py::object op)
auto start = std::chrono::system_clock::now();
#endif

static py::object ReduceOp = py::module_::import("deepspeed.comm").attr("ReduceOp");
static auto ReduceOpSum = (int)py::int_(ReduceOp.attr("SUM").attr("value"));

assert(py::int_(op.attr("value")) == ReduceOpSum);

auto numel = data.numel();

int data_size = 0;
@@ -84,7 +77,7 @@ int inference_all_reduce(torch::Tensor& data, py::object op)
default: data_type_fallback = true;
}

if (data_type_fallback) return -1;
if (data_type_fallback) return;

all_reduce_outer_loop(data, numel, data_size);

@@ -109,13 +102,85 @@ int inference_all_reduce(torch::Tensor& data, py::object op)
}
}
#endif
return 0;
return;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("initialize", &initialize, "shm initialize"); }

TORCH_LIBRARY(deepspeed, m)
{
m.def("inference_all_reduce(Tensor self) -> Tensor");
m.def("inference_all_reduce_(Tensor(a!) self) -> Tensor(a!)");
}

torch::Tensor inference_all_reduce_meta(const torch::Tensor& self_)
{
torch::Tensor result_ = torch::empty_like(self_);
return result_;
}

torch::Tensor& inference_all_reduce__meta(torch::Tensor& self_) { return self_; }

torch::Tensor& inference_all_reduce__cpu(torch::Tensor& self_)
{
TORCH_INTERNAL_ASSERT(self_.device().type() == torch::DeviceType::CPU);
torch::Tensor self_tensor = self_.contiguous();
inference_all_reduce_(self_tensor, 0);
return self_;
}

torch::Tensor inference_all_reduce_cpu(const torch::Tensor& self_)
{
torch::Tensor result = self_.clone();
inference_all_reduce__cpu(result);
return result;
}

#include <ATen/FunctionalTensorWrapper.h>
// The boilerplate functionalization logic, that teaches functionalization
// how to map x_() calls into x() calls.
// Long term, we'd like to not require users to write this logic.
// HOWEVER, if you have a custom op that is mutable,
// You will still need to write an out-of-place version of that op!
at::Tensor& inference_all_reduce__functionalization_glue(at::Tensor& x)
{
// We expect all tensor inputs to our op to be "functional tensors"
TORCH_INTERNAL_ASSERT(at::functionalization::impl::isFunctionalTensor(x));
// First, sync and unwrap and functional tensors
at::functionalization::impl::sync(x);
auto x_ = at::functionalization::impl::from_functional_tensor(x);
// Grab the dispatcher entry corresponding to the out-of-place op, "x"
static auto op_handle = c10::Dispatcher::singleton()
// specify namespace::op_name, op_overload_name
.findSchemaOrThrow("deepspeed::inference_all_reduce", "")
// Specify the C++ schema of the out-of-place op.
.typed<at::Tensor(const at::Tensor&)>();
// Next, redispatch to the out-of-place op, x() (user called x_, we call x)
at::Tensor tmp_output;
{
at::AutoDispatchSkipFunctionalize guard;
tmp_output = op_handle.call(x_);
}
// Finally, tell functionalization about this mutation.
at::functionalization::impl::replace_(x, tmp_output);
at::functionalization::impl::commit_update(x);
at::functionalization::impl::sync(x);
return x;
}

TORCH_LIBRARY_IMPL(deepspeed, CPU, m)
{
m.impl("inference_all_reduce", inference_all_reduce_cpu);
m.impl("inference_all_reduce_", inference_all_reduce__cpu);
}

TORCH_LIBRARY_IMPL(deepspeed, Meta, m)
{
m.impl("inference_all_reduce", inference_all_reduce_meta);
m.impl("inference_all_reduce_", inference_all_reduce__meta);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
TORCH_LIBRARY_IMPL(deepspeed, Functionalize, m)
{
m.def("initialize", &initialize, "shm initialize");
m.def("get_rank", &get_rank, "get rank");
m.def("get_world_size", &get_world_size, "get world size");
m.def("inference_all_reduce", &inference_all_reduce, "low latency all_reduce implementation");
m.impl("inference_all_reduce_", inference_all_reduce__functionalization_glue);
}
5 changes: 3 additions & 2 deletions deepspeed/comm/torch.py
@@ -151,11 +151,12 @@ def all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, group=None, asyn
op = self._reduce_op(op)
return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op)

@compiler.disable
def inference_all_reduce(self, tensor, op, group=None):
if self.shm_comm_op == None or self.shm_comm_op.inference_all_reduce(tensor, op) == -1:
if not hasattr(torch.ops, 'deepspeed') or not hasattr(torch.ops.deepspeed, 'inference_all_reduce_'):
op = self._reduce_op(op)
return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=False)
else:
return torch.ops.deepspeed.inference_all_reduce_(tensor)

@compiler.disable
def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
