Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve aggressive factor propagation strategy in Shardy. There are two main differences from BasicFactorPropagation. #28

Merged
merged 1 commit into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion shardy/dialect/sdy/transforms/propagation/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -250,9 +250,10 @@ cc_test(
srcs = ["aggressive_factor_propagation_test.cc"],
deps = [
":aggressive_factor_propagation",
":basic_factor_propagation",
":factor_propagation",
":sharding_projection",
":testing_utils",
":utils",
"//shardy/dialect/sdy/ir:dialect",
"@com_google_googletest//:gtest_main",
"@llvm-project//llvm:Support",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,146 +23,99 @@ limitations under the License.
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "shardy/dialect/sdy/ir/dialect.h"
#include "shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.h"
#include "shardy/dialect/sdy/transforms/propagation/factor_propagation.h"
#include "shardy/dialect/sdy/transforms/propagation/sharding_projection.h"

namespace mlir {
namespace sdy {

AxesPerFactor
AggressiveFactorPropagation::getCompatibleMajorShardingAxesForAllFactors(
const ShardingProjection& projection, PropagationDirection direction,
namespace {

bool updateTensorSharding(ShardingProjection& projection, int64_t tensorIndex,
int64_t factorIndex, ArrayRef<AxisRefAttr> newAxes) {
if (tensorIndex < projection.getNumOperands()) {
return projection.updateOperandSharding(tensorIndex, factorIndex, newAxes);
}
return projection.updateResultSharding(
tensorIndex - projection.getNumOperands(), factorIndex, newAxes);
}

} // namespace

UpdateTensorShardings AggressiveFactorPropagation::propagateFactorShardings(
ShardingProjection& projection, PropagationDirection direction,
ArrayRef<int64_t> factorSizes, MeshAttr mesh, Operation* op,
bool conservativePropagation) const {
UpdateTensorShardings result{
.updateOperands = BitVector(projection.getNumOperands()),
.updateResults = BitVector(projection.getNumResults())};
if (direction == PropagationDirection::NONE) {
return AxesPerFactor(factorSizes.size());
return result;
}

// Finds the compatible major axes ignoring conflicts.
AxesPerFactor result;
result.reserve(factorSizes.size());
// Find the compatible major axes ignoring conflicts.
SmallVector<SmallVector<AxisRefAttr>> axesPerFactor;
axesPerFactor.reserve(factorSizes.size());
bool allElementsAreEmpty = true;
for (int64_t i = 0; i < factorSizes.size(); ++i) {
result.push_back(getCompatibleMajorAxes(projection, i, direction, op));
SmallVector<AxisRefAttr>& axes = axesPerFactor.emplace_back(
getCompatibleMajorAxes(projection, i, direction, op));
if (!axes.empty()) {
allElementsAreEmpty = false;
}
}
if (allElementsAreEmpty) {
return result;
}

// Removes the conflicts within every single factor. This strategy and
// `BasicFactorPropagation` handles conflicts within a factor in the same way.
for (const TensorFactorShardings& tensorFactorShardings :
llvm::concat<const TensorFactorShardings>(projection.getOperands(),
projection.getResults())) {
for (const auto& [factorIndex, factorSharding] :
tensorFactorShardings.factorIndexToSharding) {
// The propagation on each tensor is independent. This strategy can propagate
// different shardings to different tensors along the same factor. Examples
// are provided in the docstring of this class.
for (const auto& [tensorIndex, tensorFactorShardings] :
llvm::enumerate(llvm::concat<const TensorFactorShardings>(
projection.getOperands(), projection.getResults()))) {
// Propagate the axes got in Step 1, and resolve conflicts within a factor.
FactorIndexToSharding newSharding =
tensorFactorShardings.factorIndexToSharding;
BitVector factorUpdated(factorSizes.size());
for (auto& [factorIndex, factorSharding] : newSharding) {
SmallVector<AxisRefAttr> newAxes = axesPerFactor[factorIndex];
truncateAxesByRemovingConflicts(
result[factorIndex],
newAxes,
[&, factorIndex = factorIndex, &factorSharding = factorSharding](
AxisRefAttr axisRef, int64_t shardedSize) {
return compatiblePrefixNoConflictsWithinFactor(
axisRef, tensorFactorShardings.replicatedAxes, factorSharding,
shardedSize, factorSizes[factorIndex]);
},
mesh, conservativePropagation);
if (shouldUpdate(factorSharding.axisRefs, newAxes)) {
factorSharding.axisRefs = newAxes;
factorUpdated.set(factorIndex);
}
}
}

// Removes the conflicts across factors, where this strategy and
// `BasicFactorPropagation` diverge.
//
// With `BasicFactorPropagation`, the compatible axes of a factor Fi cannot
// overlap with the existing sharding axes or the overflow axes related to all
// other factors. This criterion is considered for all tensors, no matter if
// Fi is mapped to the tensor or not. The table below shows the criterion:
//
// existing sharding axes & overflow axes new sharding axes
// factor in tensor remove overlap -
// factor not in tensor remove overlap -
//
// On the contrary, `AggressiveFactorPropagation` has the following criterion:
//
// existing sharding axes & overflow axes new sharding axes
// factor in tensor remove overlap remove overlap
// factor not in tensor - -
//
// There are two differences:
//
// 1. `BasicFactorPropagation` removes the overlap between the compatible axes
// of a factor Fi with the existing sharding axes and overflow axes in a
// tensor Tj even if Fi is not in Tj. `AggressiveFactorPropagation` does not
// remove this overlap if Fi is not in Tj. `BasicFactorPropagation` is too
// strict, since we cannot propagate sharding axes to Tj along Fi.
//
// `AggressiveFactorPropagation` cannot handle the following case if we only
// have difference #1. `-` means that the factor is not mapped to the tensor.
// After removing conflicts within factors, we will propagate "x" to T2 along
// F0 and F1 at the same time, which induces a conflict. To resolve this
// conflict, we have difference #2.
//
// F0 F1
// T0 "x" -
// T1 - "x"
// T2 ? ?
//
// 2. `AggressiveFactorPropagation` removes the overlap between compatible
// axes of a factor Fi with the potential new sharding axes of other factors
// in Tj if Fi is in Tj. Thus, it is safe to propagate the axes to Tj along Fi
// without conflicts with other factors. In the example, we will not propagate
// "x" along F0 or F1 since their potential new sharding axes overlap.
//
// The potential new sharding axes are saved in `resultSnapshot`. It is a hard
// copy since we need to handle the following case.
//
// F0 F1 F2
// T0 "x" - -
// T1 - "x" -
// T2 - - "x"
// T3 ? ? ?
//
// The `result` and `resultSnapshot` is [["x"], ["x"], ["x"]] before removing
// conflicts across factors. After removing conflicts between F0/F1 and other
// factors, `result` is [[], [], ["x"]]. When we remove conflicts between F2
// and other factors, if we use `result` as the potential new sharding axes,
// we will not remove "x" for F2 because it is no longer present in 'result'
// for F0 and F1. We have to use `resultSnapshot` to save the potential new
// sharding axes and remove "x" for F2.
const AxesPerFactor resultSnapshot = result;
for (const TensorFactorShardings& tensorFactorSharding :
llvm::concat<const TensorFactorShardings>(projection.getOperands(),
projection.getResults())) {
for (const auto& [factorIndex, factorSharding] :
tensorFactorSharding.factorIndexToSharding) {
// Resolve conflicts (overlapping sharding axes) between factors.
bool tensorUpdated = false;
for (const int64_t factorIndex : factorUpdated.set_bits()) {
SmallVector<AxisRefAttr> newAxes = newSharding[factorIndex].axisRefs;
truncateAxesByRemovingConflicts(
result[factorIndex],
newAxes,
[&, factorIndex = factorIndex](AxisRefAttr axisRef, int64_t) {
return compatiblePrefixNoConflictsAcrossFactors(
axisRef, tensorFactorSharding.factorIndexToSharding,
factorIndex, resultSnapshot);
axisRef, newSharding, factorIndex);
},
mesh, conservativePropagation);
tensorUpdated |=
updateTensorSharding(projection, tensorIndex, factorIndex, newAxes);
}
}

return result;
}

UpdateTensorShardings AggressiveFactorPropagation::propagateFactorShardings(
ShardingProjection& projection, PropagationDirection direction,
ArrayRef<int64_t> factorSizes, MeshAttr mesh, Operation* op,
bool conservativePropagation) const {
UpdateTensorShardings result{
.updateOperands = BitVector(projection.getNumOperands()),
.updateResults = BitVector(projection.getNumResults())};

// We get the compatible major sharding axes for all factors.
AxesPerFactor axesPerFactor = getCompatibleMajorShardingAxesForAllFactors(
projection, direction, factorSizes, mesh, op, conservativePropagation);

for (auto [factorIndex, axesToPropagate] : llvm::enumerate(axesPerFactor)) {
// Update all shardings along this factor if possible.
auto [updateOperandForFactor, updateResultForFactor] =
projection.updateSharding(factorIndex, axesToPropagate);

result.updateOperands |= updateOperandForFactor;
result.updateResults |= updateResultForFactor;
if (tensorIndex < projection.getNumOperands()) {
result.updateOperands[tensorIndex] = tensorUpdated;
} else {
result.updateResults[tensorIndex - projection.getNumOperands()] =
tensorUpdated;
}
}

return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,40 +22,53 @@ limitations under the License.
#include "mlir/Support/LLVM.h"
#include "shardy/dialect/sdy/ir/dialect.h"
#include "shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.h"
#include "shardy/dialect/sdy/transforms/propagation/factor_propagation.h"
#include "shardy/dialect/sdy/transforms/propagation/sharding_projection.h"

namespace mlir {
namespace sdy {

// An aggressive strategy of propagating sharding axes along factors.
// An aggressive strategy of propagating sharding axes along factors. There are
// two main differences from `BasicFactorPropagation`.
//
// This strategy is the same as `BasicFactorPropagation` on the conflicts within
// a factor. They are different on the conflicts across factors.
// `BasicFactorPropagation` propagates the same sharding axes to all tensors
// along a factor. This strategy can propagate different sharding axes to
// different tensors along the same factor. For example, Tensors T0, T1, T2
// contain Factor F0. T0/F0 is already sharded along ["a", "b"], and "b" is
// already used by T2 ("b" can be explicitly replicated, or it is used to shard
// another factor). `BasicFactorPropagation` propagates ["a"] to both T1/F0 and
// T2/F0, while this strategy propagates ["a", "b"] to T1/F0 and ["a"] to T2/F0,
// respectively. If T2/F0 is closed, `BasicFactorPropagation` propagates
// nothing, while this strategy propagates nothing to T2/F0 and still propagates
// ["a", "b"] to T1/F0.
//
// `BasicFactorPropagation` considers the conflicts across factors with a strict
// criterion. The result cannot overlap with the sharded axes or overflow axes
// related to all other factors. This aggressive strategy ignores "fake
// conflicts", which are propagation choices that can co-exist. This aggressive
// strategy ensures that the resultant axes can be propagated to all tensors
// containing the factor. Several examples of fake conflicts:
// `BasicFactorPropagation` is conservative in terms of conflicts across
// factors. The overlapped axis between factors cannot be propagated. This
// strategy is more aggressive by allowing the overlapped axis being propagated
// along different factors if there is no overlapped axis in the result
// shardings.
//
// 1. An axis is in factors Fi and Fj. If it is infeasible to propagate that
// axis along factor Fi, we may propagate that axis along factor Fj if all the
// destination tensors have not used that axis.
// Let us take C = dot(A, B) as an example. F0 is the factor corresponding to a
// non-contracting dimension of A. F1 corresponds to a non-contracting dimension
// of B. F2 corresponds to a contracting dimension. "-" means that the tensor
// does not contain the factor.
//
// 2. Two factors Fi and Fj do not co-exist in any tensor, so they never
// interfere with each other. If Fi and Fj are sharded along the same axis, we
// can propagate that axis along both factors.
// F0 F1 F2
// A "a" -
// B -
// C "a" -
// Case 1. Fake conflict. `BasicFactorPropagation` propagates nothing, while
// this strategy propagates "a" to B/F1.
//
// Although fake conflicts can co-exist without inference, we may still need to
// all-gather some tensors.
// F0 F1 F2
// A "a" -
// B - "a"
// C -
// Case 2. Real conflict. Both `BasicFactorPropagation` and this strategy
// propagate nothing. We can propagate "a" to C/F0 or C/F1, which is illegal
// since "a" cannot be used twice in C.
class AggressiveFactorPropagation : public BasicFactorPropagation {
public:
AxesPerFactor getCompatibleMajorShardingAxesForAllFactors(
const ShardingProjection& projection, PropagationDirection direction,
ArrayRef<int64_t> factorSizes, MeshAttr mesh, Operation* op,
bool conservativePropagation) const override;

UpdateTensorShardings propagateFactorShardings(
ShardingProjection& projection, PropagationDirection direction,
ArrayRef<int64_t> factorSizes, MeshAttr mesh, Operation* op,
Expand Down
Loading