From 9bceebfbe2c1178f95764a56102c4d6bf3541d7d Mon Sep 17 00:00:00 2001
From: Alex Vasile <48962821+Alex-Vasile@users.noreply.github.com>
Date: Tue, 11 Feb 2025 07:47:05 -0800
Subject: [PATCH] Reduce buffer copying by using one device to reduce and
 distribute

---
 sharktank/sharktank/ops/sharded_impls.py | 28 +++++++++++++++-------------
 1 file changed, 15 insertions(+), 13 deletions(-)

diff --git a/sharktank/sharktank/ops/sharded_impls.py b/sharktank/sharktank/ops/sharded_impls.py
index dc66b39cd..e3f3b0d64 100644
--- a/sharktank/sharktank/ops/sharded_impls.py
+++ b/sharktank/sharktank/ops/sharded_impls.py
@@ -60,20 +60,22 @@ def all_gather_split(
 def all_reduce_split_or_unreduced(
     input: Union[SplitPrimitiveTensor, UnreducedTensor],
 ) -> ReplicatedTensor:
-    # For each device move the shards to it and do a reduction.
-    # If we don't move first, common sub-expression elimination is free to collapse all
-    # reductions into one and then copy to all devices, which is not what we want.
+    reduced = functools.reduce(
+        lambda x, y: elementwise(torch.add, x, y),
+        [
+            (
+                transfer_to_logical_device(shard, 0)
+                if i != 0
+                else barrier_on_logical_device(shard, 0)
+            )
+            for i, shard in enumerate(input.shards)
+        ],
+    )
     shards = [
-        functools.reduce(
-            lambda x, y: elementwise(torch.add, x, y),
-            [
-                (
-                    barrier_on_logical_device(shard, i)
-                    if i == j
-                    else transfer_to_logical_device(shard, i)
-                )
-                for j, shard in enumerate(input.shards)
-            ],
+        (
+            transfer_to_logical_device(reduced, i)
+            if i != 0
+            else barrier_on_logical_device(reduced, 0)
         )
         for i in range(input.shard_count)
     ]
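
Note (not part of the patch): the saving is in inter-device copies. The old code moved every shard to every device and reduced once per device, roughly N*(N-1) transfers for N shards, while the new code gathers all shards to device 0, reduces once, and redistributes the single result, 2*(N-1) transfers. Below is a minimal sketch of that count; it uses plain Python instead of the sharktank tensor types, the helper names are illustrative, and it only models the number of calls to transfer_to_logical_device, not the actual data movement.

# Illustrative only: counts logical inter-device transfers for the two
# all-reduce strategies; it does not call any sharktank APIs.

def transfers_reduce_on_every_device(shard_count: int) -> int:
    # Old pattern: each device i pulls every other device's shard and then
    # reduces locally, so every ordered pair (i, j) with i != j costs one copy.
    return sum(1 for i in range(shard_count)
                 for j in range(shard_count) if i != j)

def transfers_reduce_then_distribute(shard_count: int) -> int:
    # New pattern: device 0 pulls the other shard_count - 1 shards, reduces
    # once, then sends the reduced result to the other shard_count - 1 devices.
    gather = shard_count - 1
    scatter = shard_count - 1
    return gather + scatter

if __name__ == "__main__":
    for n in (2, 4, 8):
        print(n, transfers_reduce_on_every_device(n), transfers_reduce_then_distribute(n))

For 8 shards this is 56 copies versus 14, at the cost of serializing the reduction on one device rather than replicating it on all of them.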