Skip to content

Commit

Permalink
Merge branch 'main' into add-docs
Browse files Browse the repository at this point in the history
  • Loading branch information
melissawm authored Nov 7, 2024
2 parents b3dc357 + 3890152 commit 477d319
Show file tree
Hide file tree
Showing 17 changed files with 3,261 additions and 97 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ agnostic and provide extensive user control and debuggability features. It
includes an axis-based sharding representation, a set of compiler APIs,
functionality for sharding propagation, and plans for an SPMD partitioner.

For more information see the docs directory.

## Status

Shardy is a work in progress. Currently the core dialect and c bindings are
Expand Down
3 changes: 1 addition & 2 deletions rfcs/2024-03-14-shardy-partitioner-rfc.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ Sharding representation can either be:

Design a new **axis-based** sharding representation that is general enough to handle all existing use cases of both GSPMD and PartIR.

See for more information on the requirements.

See for more information on the requirements.

### Overview

Expand Down
1 change: 1 addition & 0 deletions shardy/dialect/sdy/transforms/export/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ cc_library(
"//shardy/dialect/sdy/transforms/common:op_properties",
"//shardy/dialect/sdy/transforms/propagation:op_sharding_rule_registry",
"//shardy/dialect/sdy/transforms/propagation:sharding_projection",
"//shardy/dialect/sdy/transforms/propagation:utils",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
Expand Down
2 changes: 1 addition & 1 deletion shardy/dialect/sdy/transforms/export/export_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ namespace sdy {

void addExportPipeline(OpPassManager& pm, StringRef dumpDirectory) {
pm.addPass(createRemoveShardingGroupsPass());
pm.addNestedPass<func::FuncOp>(createSinkDataFlowEdgesPass());
pm.addNestedPass<func::FuncOp>(createShardingConstraintToReshardPass());
pm.addNestedPass<func::FuncOp>(createSinkDataFlowEdgesPass());
pm.addNestedPass<func::FuncOp>(
createUpdateNonDivisibleInputOutputShardingsPass());
pm.addPass(mlir::sdy::createSaveModuleOpPass(dumpDirectory,
Expand Down
62 changes: 35 additions & 27 deletions shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/

#include <cassert>
#include <cstdint>
#include <optional>

#include "llvm/ADT/STLExtras.h"
Expand All @@ -29,6 +30,7 @@ limitations under the License.
#include "shardy/dialect/sdy/ir/utils.h" // IWYU pragma: keep
#include "shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.h"
#include "shardy/dialect/sdy/transforms/propagation/sharding_projection.h"
#include "shardy/dialect/sdy/transforms/propagation/utils.h"
#include "stablehlo/dialect/StablehloOps.h" // IWYU pragma: keep

namespace mlir {
Expand Down Expand Up @@ -118,42 +120,43 @@ bool hasCompatibleFactorShardings(const ShardingProjection& projection) {
// Assumes factor shardings do not have overflow axes.
// TODO(enver): Handle the case when some factor shardings have overflow axes.
void insertExplicitReshards(Operation* op, const ShardingProjection& projection,
UpdateTensorShardings updateTensorShardings,
IRRewriter& rewriter,
OpShardingRuleAttr shardingRule, StringRef meshName,
MeshAttr mesh) {
rewriter.setInsertionPoint(op);
for (const auto& [index, operand] : llvm::enumerate(op->getOperands())) {
for (int operandIndex : updateTensorShardings.updateOperands.set_bits()) {
auto operand = op->getOperand(operandIndex);
auto newTensorSharding =
projection.getOperand(index).createTensorShardingAttr(
mesh.getContext(), shardingRule.getOperandMapping(index),
shardingRule.getFactorSizes(), meshName, mesh);
if (newTensorSharding == getSharding(operand)) {
continue;
}
projection.getOperand(operandIndex)
.createTensorShardingAttr(
mesh.getContext(), shardingRule.getOperandMapping(operandIndex),
shardingRule.getFactorSizes(), meshName, mesh);
auto reshardOp = rewriter.create<ReshardOp>(operand.getLoc(), operand,
newTensorSharding);
op->setOperand(index, reshardOp);
op->setOperand(operandIndex, reshardOp);
}

rewriter.setInsertionPointAfter(op);
for (const auto& [result, tensorFactorShardings, tensorMapping] :
llvm::zip_equal(op->getResults(), projection.getResults(),
shardingRule.getResultMappings())) {
// TODO(enver): The following logic is mostly shared between operands and
// results. Use a helper function, instead.
auto newTensorSharding = tensorFactorShardings.createTensorShardingAttr(
mesh.getContext(), tensorMapping, shardingRule.getFactorSizes(),
meshName, mesh);
if (newTensorSharding == getSharding(result)) {
continue;
}
for (int resultIndex : toSetBitsVector(updateTensorShardings.updateResults)) {
auto result = op->getResult(resultIndex);
auto newTensorSharding =
projection.getResult(resultIndex)
.createTensorShardingAttr(
mesh.getContext(), shardingRule.getResultMapping(resultIndex),
shardingRule.getFactorSizes(), meshName, mesh);
auto reshardOp = rewriter.create<ReshardOp>(result.getLoc(), result,
getSharding(result));
rewriter.replaceAllUsesExcept(result, reshardOp, reshardOp);
setSharding(result, newTensorSharding);
}
}

AxesPerFactor findCommonAxes(const ShardingProjection& projection,
int64_t numFactors) {
return projection.getGreatestCommonPrefixAxes(numFactors);
}

struct InsertExplicitReshardsPass
: public impl::InsertExplicitReshardsPassBase<InsertExplicitReshardsPass> {
using InsertExplicitReshardsPassBase::InsertExplicitReshardsPassBase;
Expand All @@ -163,6 +166,8 @@ struct InsertExplicitReshardsPass
IRRewriter rewriter(funcOp);
SymbolTable symbolTable(funcOp->getParentOfType<ModuleOp>());
// TODO(enver): Handle data flow ops.
// TODO(enver): Handle cases func op result sharding does not match the
// sharding of returned value.
funcOp.walk([&](Operation* op) {
// TODO(enver): Check if data flow ops, data flow edge op, manual
// computation op require extra check before creating sharding rule.
Expand Down Expand Up @@ -205,15 +210,18 @@ struct InsertExplicitReshardsPass
return;
}

// TODO(enver): Instead of building a new projection, update and use the
// existing one.
ShardingProjection projection = ShardingProjection::build(
shardingProjection.getGreatestCommonPrefixAxes(
shardingRule.getNumFactors()),
shardingRule);
UpdateTensorShardings updateTensorShardings(shardingRule.getNumOperands(),
shardingRule.getNumResults());
for (const auto& [index, factorAxes] : llvm::enumerate(findCommonAxes(
shardingProjection, shardingRule.getNumFactors()))) {
// TODO(enver): Add unit tests to test overflow axes are cleared after
// handling the case that some factors have overflow axes.
updateTensorShardings |= shardingProjection.updateSharding(
index, factorAxes, /*overflowAxes=*/{});
}

insertExplicitReshards(op, projection, rewriter, shardingRule, *meshName,
mesh);
insertExplicitReshards(op, shardingProjection, updateTensorShardings,
rewriter, shardingRule, *meshName, mesh);

// TODO(enver): Remove sharding rules from ops.
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ limitations under the License.
#include <tuple>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
Expand Down Expand Up @@ -108,59 +109,68 @@ UpdateTensorShardings AggressiveFactorPropagation::propagateFactorShardings(
return result;
}

// We sort the factors based on:
// 1. larger source tensor size first
// 2. smaller source tensor index first
// 3. smaller factor index first
// Unstable sort is fine because there is no equality in the candidates.
// TODO(b/376233527): reevaluate this conflict resolution heuristic.
SmallVector<int64_t> sortedFactorIndices =
llvm::to_vector(llvm::seq<int64_t>(0, factorSizes.size()));
SmallVector<TensorIndexSize> factorToSourceTensor =
getFactorToSourceTensor(projection, factorSizes, axesPerFactor);
llvm::sort(sortedFactorIndices, [&](int64_t i, int64_t j) {
return std::forward_as_tuple(-factorToSourceTensor[i].size,
factorToSourceTensor[i].index, i) <
std::forward_as_tuple(-factorToSourceTensor[j].size,
factorToSourceTensor[j].index, j);
});

// The propagation on each tensor is independent. This strategy can propagate
// different shardings to different tensors along the same factor. Examples
// are provided in the docstring of this class.
for (const auto& [tensorIndex, tensorFactorShardings] :
llvm::enumerate(llvm::concat<const TensorFactorShardings>(
projection.getOperands(), projection.getResults()))) {
// Propagate the axes got in Step 1, and resolve conflicts within a factor.
FactorIndexToSharding newSharding =
const FactorIndexToSharding& factorIndexToSharding =
tensorFactorShardings.factorIndexToSharding;
BitVector factorUpdated(factorSizes.size());
for (auto& [factorIndex, factorSharding] : newSharding) {

// Propagate the axes got in Step 1, resolving conflicts between factors by
// following the order of preference in `sortedFactorIndices`.
bool tensorUpdated = false;
for (int64_t factorIndex : sortedFactorIndices) {
auto factorShardingIt = factorIndexToSharding.find(factorIndex);
if (factorShardingIt == factorIndexToSharding.end()) {
continue;
}
const FactorSharding& factorSharding = factorShardingIt->second;
SmallVector<AxisRefAttr> newAxes = axesPerFactor[factorIndex];

// Resolve conflicts within a factor.
truncateAxesByRemovingConflicts(
newAxes,
[&, factorIndex = factorIndex, &factorSharding = factorSharding,
[&, factorIndex = factorIndex,
&tensorFactorShardings = tensorFactorShardings](
AxisRefAttr axisRef, int64_t prevShardedSize) {
return compatiblePrefixNoConflictsWithinFactor(
axisRef, tensorFactorShardings.replicatedAxes, factorSharding,
prevShardedSize, factorSizes[factorIndex], mesh);
},
mesh, conservativePropagation);
if (isStrictPrefix(factorSharding.axisRefs, newAxes)) {
factorSharding.axisRefs = newAxes;
factorUpdated.set(factorIndex);
if (!isStrictPrefix(factorSharding.axisRefs, newAxes)) {
continue;
}
}

SmallVector<int> sortedFactorIndices = toSetBitsVector(factorUpdated);
// We sort the factors based on:
// 1. larger source tensor size first
// 2. smaller source tensor index first
// 3. smaller factor index first
// Unstable sort is fine because there is no equality in the candidates.
llvm::sort(sortedFactorIndices, [&](int64_t i, int64_t j) {
return std::forward_as_tuple(-factorToSourceTensor[i].size,
factorToSourceTensor[i].index, i) <
std::forward_as_tuple(-factorToSourceTensor[j].size,
factorToSourceTensor[j].index, j);
});

// Resolve conflicts (overlapping sharding axes) between factors.
bool tensorUpdated = false;
for (const int64_t factorIndex : sortedFactorIndices) {
SmallVector<AxisRefAttr> newAxes = newSharding[factorIndex].axisRefs;
// Resolve conflicts (overlapping sharding axes) between factors.
//
// Note that we pass `factorIndexToSharding`, which might have been
// updated for a previous factor (previous iteration), thus we are
// checking for conflicts w.r.t. the updated state of this tensor.
truncateAxesByRemovingConflicts(
newAxes,
[&, factorIndex = factorIndex](AxisRefAttr axisRef, int64_t) {
return compatiblePrefixNoConflictsAcrossFactors(
axisRef, newSharding, factorIndex);
axisRef, factorIndexToSharding, factorIndex);
},
mesh, conservativePropagation);
tensorUpdated |=
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,28 +44,35 @@ namespace sdy {
// `BasicFactorPropagation` is conservative in terms of conflicts across
// factors. The overlapped axis between factors cannot be propagated. This
// strategy is more aggressive by allowing the overlapped axis being propagated
// along different factors if there is no overlapped axis in the result
// along different factors if there is no overlapped axis in the current
// shardings.
//
// To resolve conflicts across factors, when there are multiple choices (that
// cannot co-exist), we prefer the factor with the larger source tensor (the
// tensor from which the factor sharding is propagated), as it's normally
// beneficial to reshard the smaller tensor. If two factors have the same source
// tensor size, we sort based on the source tensor index and finally factor
// index.
//
// Let us take C = dot(A, B) as an example. F0 is the factor corresponding to a
// non-contracting dimension of A. F1 corresponds to a non-contracting dimension
// of B. F2 corresponds to a contracting dimension. "-" means that the tensor
// of B. F2 corresponds to a contracting dimension. '-' means that the tensor
// does not contain the factor.
//
// F0 F1 F2
// A "a" -
// B -
// C "a" -
// Case 1. Fake conflict. `BasicFactorPropagation` propagates nothing, while
// this strategy propagates "a" to B/F1.
// Case 1. Conflict with a single choice. `BasicFactorPropagation` propagates
// nothing, while this strategy propagates "a" to B/F1.
//
// F0 F1 F2
// A "a" -
// B - "a"
// C -
// Case 2. Real conflict. Both `BasicFactorPropagation` and this strategy
// propagate nothing. We can propagate "a" to C/F0 or C/F1, which is illegal
// since "a" cannot be used twice in C.
// Case 2. Conflict with multiple choices. `BasicFactorPropagation` propagates
// nothing, while this strategy propagates "a" to C/F1, since F1 is preferred
// over F0 (tensor B is larger than A).
class AggressiveFactorPropagation : public BasicFactorPropagation {
public:
UpdateTensorShardings propagateFactorShardings(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ TEST_F(AggressiveFactorPropagationTest, RealAndFakeConflicts) {
/*results=*/{
{.factorIndexToSharding =
{
{0, {.axisRefs = {}}},
{0, {.axisRefs = {createAxis("a")}}},
{1, {.axisRefs = {}}},
{2, {.axisRefs = {}, .overflowAxes = {createAxis("d")}}},
{3, {.axisRefs = {createAxis("c")}}},
Expand All @@ -108,7 +108,8 @@ TEST_F(AggressiveFactorPropagationTest, RealAndFakeConflicts) {
});

// Axis "a" may be propagated to the result along factors 0 or 1, which forms
// a real conflict. Thus, we do not apply either of propagation choices.
// a real conflict. We prefer factor 0 because its source is the first operand
// (all tensors have the same size).
//
// Other conflicts are fake. We can propagate other axes as much as possible.
// Axes "c", "b", "e", "f", "g" can be propagated to the result along factors
Expand Down Expand Up @@ -320,7 +321,7 @@ TEST_F(AggressiveFactorPropagationTest, NewAxesConflict) {
/*results=*/{
{.factorIndexToSharding =
{
{0, {.axisRefs = {}}},
{0, {.axisRefs = {createAxis("a"), createAxis("b")}}},
{1, {.axisRefs = {}, .isClosed = true}},
{2, {.axisRefs = {createAxis("c")}}},
{3, {.axisRefs = {createAxis("d")}}},
Expand All @@ -337,8 +338,8 @@ TEST_F(AggressiveFactorPropagationTest, NewAxesConflict) {
});

// “a” can be propagated to the Result 0 along either Factor 0 or Factor 2.
// This strategy truncate “a” for both F0 and F2 in Result 0. Namely, this
// strategy does not resolve real conflicts across factors.
// This strategy prefers factor 0 because its source is the first operand
// (all tensors have the same size).
auto [updateOperands, updateResults] =
propagateFactorShardings(projection, 4);
EXPECT_THAT(toSetBitsVector(updateOperands), ElementsAre(1, 2));
Expand Down
Loading

0 comments on commit 477d319

Please sign in to comment.