Preserve HLO shardings on calls and non-entry functions. #113

Open · wants to merge 1 commit into `main`
49 changes: 49 additions & 0 deletions docs/sdy_export_passes.md
@@ -1,4 +1,53 @@
<!-- Autogenerated by mlir-tblgen; don't manually edit -->
### `-sdy-insert-explicit-reshards`

_Inserts explicit reshards to make all operations have compatible shardings._

A compatible sharding essentially means that the operation can accept the
sharded operands and produce a sharded result without requiring any reshard
communications (note that the operation might still require communication
such as all-reduce or halo-swaps). For example, a matmul whose operands are
both sharded along the contracting dimension by the same axis is compatible,
yet still needs an all-reduce to combine the partial results.

After propagation, some operations may still have incompatible shardings.

Note that an operation has a sharding conflict when an axis (or sub-axis) is
used to shard non-corresponding dimensions (e.g., non-contracting dimensions
in matmul) across multiple tensors, or when an axis shards a dimension in one
tensor but not the corresponding dimension in the other tensor. Hence, after
this pass, operations become conflict-free.

This pass injects reshard operations explicitly so that, for each operation,
corresponding dimensions become sharded in the same way across all operands
and results, and every axis (or sub-axis) can only be used to shard a single
dimension type.

A clarifying example:

Input:
```mlir
mesh = <"x"=4, "y"=2>
%lhs : tensor<8x32xf32> {sdy.sharding=<@mesh, [{"x"}, {"y"}]>}
%rhs : tensor<32x16xf32> {sdy.sharding=<@mesh, [{"y"}, {"x"}]>}
stablehlo.dot %lhs, %rhs {sdy.sharding_per_value=<[<@mesh, [{"x"}, {}]>]>}
  : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
```

Output:
```mlir
sdy.mesh = <"x"=4, "y"=2>
%lhs : tensor<8x32xf32> {sdy.sharding=<@mesh, [{"x"}, {"y"}]>}
%rhs : tensor<32x16xf32> {sdy.sharding=<@mesh, [{"y"}, {"x"}]>}
%0 = sdy.reshard %rhs <@mesh, [{"y"}, {}]> : tensor<32x16xf32>
stablehlo.dot %lhs, %0 {sdy.sharding_per_value=<[<@mesh, [{"x"}, {}]>]>}
  : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
```

In the example above, there is a conflict since the `lhs` and `rhs` tensors
are both sharded on axis "x" on their non-contracting dimensions. Here, the
`rhs` tensor is explicitly resharded, before the dot operation, to be
sharded only on its first dimension, along axis "y". This way, the dot
operation becomes compatible.
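
As a usage sketch, this pass would typically be appended to an export
pipeline after propagation. A minimal sketch in C++, assuming the
tblgen-conventional factory name `createSdyInsertExplicitReshardsPass`
(an assumption, not confirmed here; check shardy's pass headers):

```c++
#include "mlir/Pass/PassManager.h"

// Hypothetical wiring sketch: run reshard insertion after propagation so
// every op sees compatible operand/result shardings.
void addExportPasses(mlir::PassManager& pm) {
  // Factory name assumed from the `-sdy-insert-explicit-reshards` pass name.
  pm.addPass(mlir::sdy::createSdyInsertExplicitReshardsPass());
}
```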
### `-sdy-sharding-constraint-to-reshard`

_Converts ShardingConstraintOp into ReshardOp._
10 changes: 10 additions & 0 deletions shardy/dialect/sdy/ir/data_flow_utils.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Support/LLVM.h"
#include "shardy/dialect/sdy/ir/constants.h"
#include "shardy/dialect/sdy/ir/dialect.h"
#include "shardy/dialect/sdy/ir/utils.h"
#include "stablehlo/dialect/StablehloOps.h"
@@ -97,6 +98,15 @@ void setBlockArgumentEdgeOwnerShardings(
shardings);
}

void setOpResultEdgeOwnerShardings(Operation* op,
ArrayRef<TensorShardingAttr> shardings) {
if (auto shardableDataFlowOp = dyn_cast<ShardableDataFlowOpInterface>(op)) {
return shardableDataFlowOp.setOpResultEdgeOwnerShardings(shardings);
}
op->setAttr(kShardingAttr,
TensorShardingPerValueAttr::get(op->getContext(), shardings));
}
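
A minimal caller sketch for the dispatch above; `copyResultShardings` is a
hypothetical helper (not part of this change) that only uses names visible
in this file:

```c++
// Hypothetical sketch: copy the per-result shardings stored on `from`, if
// any, onto `to`. The call goes through `ShardableDataFlowOpInterface` when
// `to` implements it, and otherwise writes a `TensorShardingPerValueAttr`
// under `kShardingAttr` directly.
void copyResultShardings(Operation* from, Operation* to) {
  if (auto shardingPerResult =
          from->getAttrOfType<TensorShardingPerValueAttr>(kShardingAttr)) {
    setOpResultEdgeOwnerShardings(to, shardingPerResult.getShardings());
  }
}
```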

DataFlowEdgeOp getDataFlowEdge(Value target) {
return DataFlowEdgeOp::getDataFlowEdgeUser(getDataFlowEdgeOwner(target));
}
7 changes: 6 additions & 1 deletion shardy/dialect/sdy/ir/data_flow_utils.h
@@ -61,11 +61,16 @@ SmallVector<Value> getDataFlowSources(DataFlowEdgeOp dataFlowEdge);
// Returns all non-edge-owner targets of the given `dataFlowEdge`.
SmallVector<Value> getNonEdgeOwnerTargets(DataFlowEdgeOp dataFlowEdge);

// Sets the block argument edge owner shardings if the `op` is a
// Sets the block argument edge owner `shardings` if the `op` is a
// `ShardableDataFlowOpInterface`.
void setBlockArgumentEdgeOwnerShardings(Operation* op,
ArrayRef<TensorShardingAttr> shardings);

// Sets the op result edge owner `shardings`, dispatching to
// `ShardableDataFlowOpInterface` when `op` implements it, and otherwise
// storing them as a `TensorShardingPerValueAttr` on the op.
void setOpResultEdgeOwnerShardings(Operation* op,
ArrayRef<TensorShardingAttr> shardings);

} // namespace sdy
} // namespace mlir

151 changes: 149 additions & 2 deletions shardy/dialect/sdy/ir/dialect.cc
@@ -31,10 +31,12 @@ limitations under the License.
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
@@ -59,8 +61,8 @@ struct ShardyDialectInlinerInterface : public DialectInlinerInterface {
return true;
}

// ManualComputationOp is an op with a region, and it should be allowed to be
// inlined into another op.
// `ManualComputationOp` and `NamedComputationOp` are ops with a region, and
// they should be allowed to be inlined into another op.
bool isLegalToInline(Region*, Region*, bool, IRMapping&) const final {
return true;
}
@@ -90,6 +92,31 @@ void SdyDialect::initialize() {
>();
}

namespace details {
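// These are shared bodies for the `ShardableDataFlowOpInterface` methods
// declared in dialect.h: result shardings are stored on the op as a
// `TensorShardingPerValueAttr` under `kShardingAttr`.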

ArrayRef<TensorShardingAttr> getOpResultEdgeOwnerShardingsImpl(Operation* op) {
if (auto shardingPerResult =
op->getAttrOfType<TensorShardingPerValueAttr>(kShardingAttr)) {
return shardingPerResult.getShardings();
}
return {};
}

void setOpResultEdgeOwnerShardingImpl(Operation* op, unsigned index,
TensorShardingAttr sharding) {
op->setAttr(kShardingAttr,
getOrCreateShardingPerResult(op, sharding.getMeshName())
.replaceValueSharding(index, sharding));
}

void setOpResultEdgeOwnerShardingsImpl(Operation* op,
ArrayRef<TensorShardingAttr> shardings) {
op->setAttr(kShardingAttr,
TensorShardingPerValueAttr::get(op->getContext(), shardings));
}

} // namespace details

//===----------------------------------------------------------------------===//
// MeshAttr
//===----------------------------------------------------------------------===//
@@ -665,6 +692,126 @@ DataFlowEdgeOp DataFlowEdgeOp::getDataFlowEdgeUser(Value root) {
root && root.hasOneUse() ? *root.user_begin() : nullptr);
}

//===----------------------------------------------------------------------===//
// NamedComputationOp
//===----------------------------------------------------------------------===//

// NOTE: we are assuming that if there are no shardings, all result shardings
// will be on the same mesh. Needs to change when supporting multiple meshes.
void NamedComputationOp::setOpResultEdgeOwnerSharding(
unsigned resultIndex, TensorShardingAttr sharding) {
if (std::optional<TensorShardingPerValueAttr> outShardings =
getOutShardings()) {
setOutShardingsAttr(
outShardings->replaceValueSharding(resultIndex, sharding));
} else {
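// No `out_shardings` yet: assume the results are on `sharding`'s mesh,
// start from fully-open shardings, and overwrite the one at `resultIndex`.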
setOutShardingsAttr(
TensorShardingPerValueAttr::getFullyOpen(getContext(), getResultTypes(),
sharding.getMeshName())
.replaceValueSharding(resultIndex, sharding));
}
}

void NamedComputationOp::setOpResultEdgeOwnerShardings(
ArrayRef<TensorShardingAttr> shardings) {
setOutShardingsAttr(TensorShardingPerValueAttr::get(getContext(), shardings));
}
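
A short usage sketch for the setters above; `constrainResult` and its call
site are hypothetical, not part of this change:

```c++
// Hypothetical sketch: constrain a single result of a sdy.named_computation.
// When the op carries no `out_shardings` yet, the setter starts from
// fully-open shardings on `sharding`'s mesh and only pins `resultIndex`.
void constrainResult(NamedComputationOp op, unsigned resultIndex,
                     TensorShardingAttr sharding) {
  op.setOpResultEdgeOwnerSharding(resultIndex, sharding);
}
```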

ArrayRef<TensorShardingAttr>
NamedComputationOp::getBlockArgumentEdgeOwnerShardings() {
if (std::optional<TensorShardingPerValueAttr> inShardings =
getInShardings()) {
return inShardings->getShardings();
}
return {};
}

ArrayRef<TensorShardingAttr>
NamedComputationOp::getOpResultEdgeOwnerShardings() {
if (std::optional<TensorShardingPerValueAttr> outShardings =
getOutShardings()) {
return outShardings->getShardings();
}
return {};
}

// NOTE: we are assuming that if there are no shardings, all argument shardings
// will be on the same mesh. Needs to change when supporting multiple meshes.
void NamedComputationOp::setBlockArgumentEdgeOwnerSharding(
unsigned index, TensorShardingAttr sharding) {
if (std::optional<TensorShardingPerValueAttr> inShardings =
getInShardings()) {
setInShardingsAttr(inShardings->replaceValueSharding(index, sharding));
} else {
setInShardingsAttr(
TensorShardingPerValueAttr::getFullyOpen(
getContext(), getOperandTypes(), sharding.getMeshName())
.replaceValueSharding(index, sharding));
}
}

void NamedComputationOp::setBlockArgumentEdgeOwnerShardings(
ArrayRef<TensorShardingAttr> shardings) {
setInShardingsAttr(TensorShardingPerValueAttr::get(getContext(), shardings));
}

ArrayRef<BlockArgument> NamedComputationOp::getBlockArgumentEdgeOwners() {
return getBody().getArguments();
}

ResultRange NamedComputationOp::getOpResultEdgeOwners() { return getResults(); }

// Gets the sources given a target value.
//
// Note that the return value is a vector; for `NamedComputationOp`s there can
// only be one source per target, but sdy's interface expects a vector.
//
// For example, given the following:
// ```
// %r = sdy.named_computation<"my_tan">(%operand0) (%arg0) {
// %a = tanh(%arg0)
// sdy.return %a
// }
// ```
// If the target is a block argument (e.g., `%arg0`), return the corresponding
// operand (e.g., `%operand0`).
// If the target is a result (e.g., `%r`), return the corresponding terminator
// operand (e.g., `%a`).
SmallVector<Value> NamedComputationOp::getEdgeSources(Value target) {
assert(getOwningOp(target) == getOperation());
return mlir::TypeSwitch<Value, SmallVector<Value>>(target)
.Case<BlockArgument>(
[this](BlockArgument blockArg) -> SmallVector<Value> {
return {getOperand(blockArg.getArgNumber())};
})
.Case<mlir::OpResult>([this](
mlir::OpResult opResult) -> SmallVector<Value> {
return {getBodyTerminatorOperand(*this, opResult.getResultNumber())};
})
.Default([](Value _) -> SmallVector<Value> { return {}; });
}

// Returns the edge owner value given a `target`.
//
// For `NamedComputationOp`s, there is only one target per data flow edge which
// is also the edge owner.
Value NamedComputationOp::getEdgeOwnerFromTarget(Value target) {
assert(getOwningOp(target) == getOperation());
return target;
}

// Returns the edge owner given a `source`.
//
// If the `source` is an operand of the body terminator, return the
// corresponding result. Otherwise, it should be an operand of the
// `NamedComputationOp`; return the `BlockArgument` with the same index.
Value NamedComputationOp::getEdgeOwnerFromSource(OpOperand& source) {
Operation* sourceOwner = source.getOwner();
if (sourceOwner->hasTrait<mlir::OpTrait::IsTerminator>()) {
return getResult(source.getOperandNumber());
}
assert(sourceOwner == getOperation());
return getBody().getArgument(source.getOperandNumber());
}

} // namespace sdy
} // namespace mlir

28 changes: 28 additions & 0 deletions shardy/dialect/sdy/ir/dialect.h
@@ -34,9 +34,12 @@ limitations under the License.
#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "shardy/dialect/sdy/ir/constants.h"

// IWYU pragma: end_keep

@@ -49,6 +52,31 @@ limitations under the License.
#include "shardy/dialect/sdy/ir/attrs.h.inc"
// ODS-generated enum classes.
#include "shardy/dialect/sdy/ir/enums.h.inc"

// Below are the bodies of ODS-generated op-interface methods that cannot be
// inlined in the generated header due to cyclic dependencies on helper
// functions.
namespace mlir {
namespace sdy {
namespace details {

// Implementation of the `getOpResultEdgeOwnerShardings` method of
// `ShardableDataFlowOpInterface`.
ArrayRef<TensorShardingAttr> getOpResultEdgeOwnerShardingsImpl(Operation* op);

// Implementation of the `setOpResultEdgeOwnerSharding` method of
// `ShardableDataFlowOpInterface`.
void setOpResultEdgeOwnerShardingImpl(Operation* op, unsigned index,
TensorShardingAttr sharding);

// Implementation of the `setOpResultEdgeOwnerShardings` method of
// `ShardableDataFlowOpInterface`.
void setOpResultEdgeOwnerShardingsImpl(Operation* op,
ArrayRef<TensorShardingAttr> shardings);

} // namespace details
} // namespace sdy
} // namespace mlir

// ODS-generated op-interface classes.
#include "shardy/dialect/sdy/ir/op_interface.h.inc"
// ODS-generated op classes.
Expand Down