#sdy import CallOps with backend_configs to NamedComputationOps. #116

Open · wants to merge 1 commit into base: main
57 changes: 57 additions & 0 deletions docs/sdy_dialect.md
@@ -215,6 +215,63 @@ Interfaces: `Symbol`
</table>


### `sdy.named_computation` (sdy::NamedComputationOp)

_Named computation operation_


Syntax:

```
operation ::= `sdy.named_computation` `<`$name`>` `` `(` $operands `)`
custom<SingleBlockRegionNoBlockId>($body)
attr-dict
`:` functional-type($operands, results)
```

Groups a computation, i.e. a block of operations, and gives it a name.
Propagation will flow in/out of the region as if everything was inlined.

This can be used to handle propagation through call instructions to other
functions. Users of Shardy should write an import/export pass that converts
their call ops to `sdy.named_computation` ops, duplicating/copying the body
of the called function into the body of the `named_computation`.
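
As a rough sketch, such an import pass could rewrite a call as shown below
(the callee `@foo`, its body, and the tensor types are hypothetical, and the
`func.call` form is standard MLIR rather than anything Shardy-specific):

```mlir
// Before import: a plain call to the (hypothetical) callee @foo.
%1 = func.call @foo(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32>

// After import: the body of @foo is copied into a named computation so that
// propagation flows through it as if it were inlined.
%1 = sdy.named_computation<"foo">(%0) (%arg0: tensor<16x32xf32>) {
  %2 = stablehlo.add %arg0, %arg0 : tensor<16x32xf32>
  sdy.return %2 : tensor<16x32xf32>
} : (tensor<16x32xf32>) -> tensor<16x32xf32>
```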

The type of each block argument and returned value in the region must be the
same as the type of the corresponding operand and result of the op.

Example:

```mlir
%1 = sdy.named_computation<"foo">(%0) (%arg1: tensor<16x32xf32>) {
sdy.return %arg1 : tensor<16x32xf32>
} : (tensor<16x32xf32>) -> tensor<16x32xf32>
```

Traits: `IsolatedFromAbove`, `RecursiveMemoryEffects`, `RecursivelySpeculatableImplTrait`, `SingleBlockImplicitTerminator<ReturnOp>`, `SingleBlock`

Interfaces: `ConditionallySpeculatable`

#### Attributes:

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>name</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
</table>

#### Operands:

| Operand | Description |
| :-----: | ----------- |
| `operands` | variadic of any type |

#### Results:

| Result | Description |
| :----: | ----------- |
| &laquo;unnamed&raquo; | variadic of any type |


### `sdy.propagation_barrier` (sdy::PropagationBarrierOp)

_Propagation barrier operation_
49 changes: 49 additions & 0 deletions docs/sdy_export_passes.md
@@ -1,4 +1,53 @@
<!-- Autogenerated by mlir-tblgen; don't manually edit -->
### `-sdy-insert-explicit-reshards`

_Inserts explicit reshards to make all operations have compatible shardings._

A compatible sharding essentially means that the operation can accept the
sharded operands and produce a sharded result without requiring any reshard
communications (note that the operation might still require communication
such as all-reduce or halo-swaps).
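
For example, here is a minimal sketch of a compatible case, written in the
same shorthand as the example further below (the mesh, values, and types are
illustrative only):

```mlir
mesh = <"x"=2>
%arg0 : tensor<8x8xf32> {sdy.sharding=<@mesh, [{"x"}, {}]>}
%arg1 : tensor<8x8xf32> {sdy.sharding=<@mesh, [{"x"}, {}]>}
// Operands and result all shard dimension 0 on "x", so the elementwise add
// runs directly on the local shards without any reshard communication.
stablehlo.add %arg0, %arg1 {sdy.sharding_per_value=<[<@mesh, [{"x"}, {}]>]>}
  : tensor<8x8xf32>
```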

After propagation, some operations may still have incompatible shardings.

Note that when an axis (or sub-axis) is used to shard non-corresponding
dimensions (e.g. non-contracting dimensions in matmul) across multiple
tensors, or when an axis shards a dimension in one tensor but not the
corresponding dimension in the other tensor, the operation is said to have a
sharding conflict. Hence, after this pass, the operations become
conflict-free.

This pass injects reshard operations explicitly so that, for each operation,
corresponding dimensions become sharded in the same way across all operands
and results, and every axis (or sub-axis) can only be used to shard a single
dimension type.

A clarifying example:

Input:
```mlir
mesh = <"x"=4, "y"=2>
%lhs : tensor<8x32xf32> {sdy.sharding=<@mesh, [{"y"},{"x"}]>}
%rhs : tensor<32x16xf32> {sdy.sharding=<@mesh, [{"y"}, {"x"}]>}
stablehlo.dot %lhs, %rhs {sdy.sharding_per_value=<[<@mesh, [{"x"}, {}]>]>}
: (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
```

Output:
```mlir
sdy.mesh = <"x"=4, "y"=2>
%lhs : tensor<8x32xf32> {sdy.sharding=<@mesh, [{"x"}, {"y"}]>}
%rhs : tensor<32x16xf32> {sdy.sharding=<@mesh, [{"y"}, {"x"}]>}
%0 = sdy.reshard %rhs <@mesh, [{"y"}, {}]> : tensor<32x16xf32>
stablehlo.dot %lhs, %0 {sdy.sharding_per_value=<[<@mesh, [{"x"}, {}]>]>}
: (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
```

In the example above, there is a conflict since the `lhs` and `rhs` tensors
are both sharded on axis "x" on their non-contracting dimensions. Here, the
`rhs` tensor is explicitly resharded before the dot operation so that it is
sharded only on its first dimension and on axis "x". This way, the dot
operation becomes compatible.
### `-sdy-sharding-constraint-to-reshard`

_Converts ShardingConstraintOp into ReshardOp._
10 changes: 10 additions & 0 deletions shardy/dialect/sdy/ir/data_flow_utils.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Support/LLVM.h"
#include "shardy/dialect/sdy/ir/constants.h"
#include "shardy/dialect/sdy/ir/dialect.h"
#include "shardy/dialect/sdy/ir/utils.h"
#include "stablehlo/dialect/StablehloOps.h"
@@ -97,6 +98,15 @@ void setBlockArgumentEdgeOwnerShardings(
shardings);
}

void setOpResultEdgeOwnerShardings(Operation* op,
ArrayRef<TensorShardingAttr> shardings) {
if (auto shardableDataFlowOp = dyn_cast<ShardableDataFlowOpInterface>(op)) {
return shardableDataFlowOp.setOpResultEdgeOwnerShardings(shardings);
}
op->setAttr(kShardingAttr,
TensorShardingPerValueAttr::get(op->getContext(), shardings));
}

DataFlowEdgeOp getDataFlowEdge(Value target) {
return DataFlowEdgeOp::getDataFlowEdgeUser(getDataFlowEdgeOwner(target));
}
7 changes: 6 additions & 1 deletion shardy/dialect/sdy/ir/data_flow_utils.h
@@ -61,11 +61,16 @@ SmallVector<Value> getDataFlowSources(DataFlowEdgeOp dataFlowEdge);
// Returns all non-edge-owner targets of the given `dataFlowEdge`.
SmallVector<Value> getNonEdgeOwnerTargets(DataFlowEdgeOp dataFlowEdge);

// Sets the block argument edge owner shardings if the `op` is a
// Sets the block argument edge owner `shardings` if the `op` is a
// `ShardableDataFlowOpInterface`.
void setBlockArgumentEdgeOwnerShardings(Operation* op,
ArrayRef<TensorShardingAttr> shardings);

// Sets the op result edge owner `shardings` if the `op` is a
// `ShardableDataFlowOpInterface`.
void setOpResultEdgeOwnerShardings(Operation* op,
ArrayRef<TensorShardingAttr> shardings);

} // namespace sdy
} // namespace mlir

177 changes: 175 additions & 2 deletions shardy/dialect/sdy/ir/dialect.cc
@@ -31,10 +31,12 @@ limitations under the License.
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
@@ -59,8 +61,8 @@ struct ShardyDialectInlinerInterface : public DialectInlinerInterface {
return true;
}

// ManualComputationOp is an op with a region, and it should be allowed to be
// inlined into another op.
// `ManualComputationOp` and `NamedComputationOp` are ops with a region, and
// they should be allowed to be inlined into another op.
bool isLegalToInline(Region*, Region*, bool, IRMapping&) const final {
return true;
}
@@ -90,6 +92,72 @@ void SdyDialect::initialize() {
>();
}

namespace details {

ArrayRef<TensorShardingAttr> getOpResultEdgeOwnerShardingsImpl(Operation* op) {
if (auto shardingPerResult =
op->getAttrOfType<TensorShardingPerValueAttr>(kShardingAttr)) {
return shardingPerResult.getShardings();
}
return {};
}

void setOpResultEdgeOwnerShardingImpl(Operation* op, unsigned index,
TensorShardingAttr sharding) {
op->setAttr(kShardingAttr, replaceShardingPerValue(op, index, sharding));
}

void setOpResultEdgeOwnerShardingsImpl(Operation* op,
ArrayRef<TensorShardingAttr> shardings) {
op->setAttr(kShardingAttr,
TensorShardingPerValueAttr::get(op->getContext(), shardings));
}

} // namespace details

//===----------------------------------------------------------------------===//
// ShardableDataFlowOpInterface
//===----------------------------------------------------------------------===//

mlir::sdy::TensorShardingAttr
ShardableDataFlowOpInterface::getBlockArgumentEdgeOwnerSharding(
unsigned index) {
if (mlir::ArrayRef<mlir::sdy::TensorShardingAttr> argSharding =
getBlockArgumentEdgeOwnerShardings();
!argSharding.empty()) {
return argSharding[index];
}
return nullptr;
}

mlir::sdy::TensorShardingAttr
ShardableDataFlowOpInterface::getOpResultEdgeOwnerSharding(unsigned index) {
if (mlir::ArrayRef<mlir::sdy::TensorShardingAttr> resultSharding =
getOpResultEdgeOwnerShardings();
!resultSharding.empty()) {
return resultSharding[index];
}
return nullptr;
}

mlir::sdy::TensorShardingAttr
ShardableDataFlowOpInterface::getEdgeOwnerSharding(Value value) {
if (auto blockArg = dyn_cast<BlockArgument>(value)) {
return getBlockArgumentEdgeOwnerSharding(blockArg.getArgNumber());
}
return getOpResultEdgeOwnerSharding(cast<OpResult>(value).getResultNumber());
}

void ShardableDataFlowOpInterface::setEdgeOwnerSharding(
Value value, mlir::sdy::TensorShardingAttr sharding) {
if (auto blockArg = dyn_cast<BlockArgument>(value)) {
setBlockArgumentEdgeOwnerSharding(blockArg.getArgNumber(), sharding);
} else {
setOpResultEdgeOwnerSharding(cast<OpResult>(value).getResultNumber(),
sharding);
}
}

//===----------------------------------------------------------------------===//
// MeshAttr
//===----------------------------------------------------------------------===//
@@ -665,6 +733,111 @@ DataFlowEdgeOp DataFlowEdgeOp::getDataFlowEdgeUser(Value root) {
root && root.hasOneUse() ? *root.user_begin() : nullptr);
}

//===----------------------------------------------------------------------===//
// NamedComputationOp
//===----------------------------------------------------------------------===//

void NamedComputationOp::setOpResultEdgeOwnerSharding(
unsigned resultIndex, TensorShardingAttr sharding) {
TensorShardingPerValueAttr outShardings =
getOutShardings().value_or(TensorShardingPerValueAttr::getFullyOpen(
getContext(), getResultTypes(), sharding.getMeshName()));
setOutShardingsAttr(outShardings.replaceValueSharding(resultIndex, sharding));
}

void NamedComputationOp::setOpResultEdgeOwnerShardings(
ArrayRef<TensorShardingAttr> shardings) {
setOutShardingsAttr(TensorShardingPerValueAttr::get(getContext(), shardings));
}

ArrayRef<TensorShardingAttr>
NamedComputationOp::getBlockArgumentEdgeOwnerShardings() {
if (std::optional<TensorShardingPerValueAttr> inShardings =
getInShardings()) {
return inShardings->getShardings();
}
return {};
}

ArrayRef<TensorShardingAttr>
NamedComputationOp::getOpResultEdgeOwnerShardings() {
if (std::optional<TensorShardingPerValueAttr> outShardings =
getOutShardings()) {
return outShardings->getShardings();
}
return {};
}

void NamedComputationOp::setBlockArgumentEdgeOwnerSharding(
unsigned index, TensorShardingAttr sharding) {
TensorShardingPerValueAttr inShardings =
getInShardings().value_or(TensorShardingPerValueAttr::getFullyOpen(
getContext(), getOperandTypes(), sharding.getMeshName()));
setInShardingsAttr(inShardings.replaceValueSharding(index, sharding));
}

void NamedComputationOp::setBlockArgumentEdgeOwnerShardings(
ArrayRef<TensorShardingAttr> shardings) {
setInShardingsAttr(TensorShardingPerValueAttr::get(getContext(), shardings));
}

ArrayRef<BlockArgument> NamedComputationOp::getBlockArgumentEdgeOwners() {
return getBody().getArguments();
}

ResultRange NamedComputationOp::getOpResultEdgeOwners() { return getResults(); }

// Gets the sources given a target value.
//
// Note that the return value is a vector; for `NamedComputationOp`s there can
// only be one source value, but sdy's interface expects a vector.
//
// For example, given the following:
// ```
// %r = sdy.named_computation<"my_tan">(%operand0) (%arg0) {
// %a = tanh(%arg0)
// sdy.return %a
// }
// ```
// If the target is a block argument (e.g., `%arg0`), return the corresponding
// operand (e.g., `%operand0`).
// If the target is a result (e.g., `%r`), return `%a`.
SmallVector<Value> NamedComputationOp::getEdgeSources(Value target) {
assert(getOwningOp(target) == getOperation());
return mlir::TypeSwitch<Value, SmallVector<Value>>(target)
.Case<BlockArgument>(
[this](BlockArgument blockArg) -> SmallVector<Value> {
return {getOperand(blockArg.getArgNumber())};
})
.Case<mlir::OpResult>([this](
mlir::OpResult opResult) -> SmallVector<Value> {
return {getBodyTerminatorOperand(*this, opResult.getResultNumber())};
})
.Default([](Value _) -> SmallVector<Value> { return {}; });
}

// Returns the edge owner value given a `target`.
//
// For `NamedComputationOp`s, there is only one target per data flow edge which
// is also the edge owner.
Value NamedComputationOp::getEdgeOwnerFromTarget(Value target) {
assert(getOwningOp(target) == getOperation());
return target;
}

// Returns the edge owner given a `source`.
//
// If the `source` is an operand of a terminator, return the corresponding
// result. Otherwise, it should be an operand of the `NamedComputationOp`;
// return the `BlockArgument` with the same index.
Value NamedComputationOp::getEdgeOwnerFromSource(OpOperand& source) {
Operation* sourceOwner = source.getOwner();
if (sourceOwner->hasTrait<mlir::OpTrait::IsTerminator>()) {
return getResult(source.getOperandNumber());
}
assert(sourceOwner == getOperation());
return getBody().getArgument(source.getOperandNumber());
}

} // namespace sdy
} // namespace mlir
