Skip to content

Commit

Permalink
#sdy Add support for specified an inlined MeshAttr in a `TensorShar…
Browse files Browse the repository at this point in the history
…dingAttr` instead of referencing a symbol `MeshOp` by name.

PiperOrigin-RevId: 681801498
  • Loading branch information
tomnatan30 authored and copybara-github committed Oct 3, 2024
1 parent c464210 commit d6f0f01
Show file tree
Hide file tree
Showing 57 changed files with 3,045 additions and 4,293 deletions.
57 changes: 57 additions & 0 deletions docs/sdy_dialect.md
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,63 @@ Interfaces: `Symbol`
</table>


### `sdy.named_computation` (sdy::NamedComputationOp)

_Named computation operation_


Syntax:

```
operation ::= `sdy.named_computation` `<`$name`>` `` `(` $operands `)`
custom<SingleBlockRegionNoBlockId>($body)
attr-dict
`:` functional-type($operands, results)
```

Groups a computation, i.e. a block of operations, and gives it a name.
Propagation will flow in/out of the region as if everything was inlined.

This can be used to handle propagating through call instructions to other
functions. Any users of Shardy should write an import/export pass that
converts their call ops to `sdy.named_computation` ops, duplicating/copying
the body of the called function into the body of the `named_computation`.

The type of each block arguments and returned values in the region must be
the same as the type of the operands and results type of the op.

Example:

```mlir
%1 = sdy.named_computation<"foo">(%0) (%arg1: tensor<16x32xf32>) {
sdy.return %arg1 : tensor<16x32xf32>
} : (tensor<16x32xf32>) -> tensor<16x32xf32>
```

Traits: `IsolatedFromAbove`, `RecursiveMemoryEffects`, `RecursivelySpeculatableImplTrait`, `SingleBlockImplicitTerminator<ReturnOp>`, `SingleBlock`

Interfaces: `ConditionallySpeculatable`

#### Attributes:

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>name</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
</table>

#### Operands:

| Operand | Description |
| :-----: | ----------- |
| `operands` | variadic of any type

#### Results:

| Result | Description |
| :----: | ----------- |
&laquo;unnamed&raquo; | variadic of any type


### `sdy.propagation_barrier` (sdy::PropagationBarrierOp)

_Propagation barrier operation_
Expand Down
49 changes: 49 additions & 0 deletions docs/sdy_export_passes.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,53 @@
<!-- Autogenerated by mlir-tblgen; don't manually edit -->
### `-sdy-insert-explicit-reshards`

_Inserts explicit reshards to make all operations have compatible shardings._

A compatible sharding essentially means that the operation can accept the
sharded operands and produce a sharded result without requiring any reshard
communications (note that the operation might still require communication
such as all-reduce or halo-swaps).

After propagation, some opeartions may still have incompatible shardings.

Please note, when an axis (or sub-axis) is used to shard non-corresponding
dimensions (e.g. non-contracting dimensions in matmul) across multiple
tensors, or when an axis shards a dimension in one tensor but not the
corresponding dimension in the other tensor, it is said that the operation
has a sharding conflict. Hence, after this pass, the opeartions become
conflict-free.

This pass injects reshard operations explicitly so that, for each operation,
corresponding dimensions become sharded in the same way across all operands
and results, and every axis (or sub-axis) can only be used to shard a single
dimension type.

A clarifying example:

Input:
```mlir
mesh = <"x"=4, "y"=2>
%lhs : tensor<8x32xf32> {sdy.sharding=<@mesh, \[{"y"},{"x"}\]>}
%rhs : tensor<32x16xf32> {sdy.sharding=<@mesh, \[{"y"}, {"x"}\]>}
stablehlo.dot %lhs, %rhs {sdy.sharding_per_value=<[<@mesh, \[{"x"}, {}\]>]>}
: (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
```

Output:
```mlir
sdy.mesh = <"x"=4, "y"=2>
%lhs : tensor<8x32xf32> {sdy.sharding=<@mesh, \[{"x"}, {"y"}\]>}
%rhs : tensor<32x16xf32> {sdy.sharding=<@mesh, \[{"y"}, {"x"}\]>}
%0 = sdy.reshard %rhs <@mesh, \[{"y"}, {}\]> : tensor<32x16xf32>
stablehlo.dot %lhs, %0 {sdy.sharding_per_value=<[<@mesh, \[{"x"}, {}\]>]>}
: (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32>
```

In the example above, there is a conflict since `lhs` and `rhs` tensors
are both sharded on axis "x" on their non-contracting dimensions. Here,
`rhs` tensor is resharded, before the dot operation, explicitly to be
sharded only on its first dimension and on axis "x". This way, the dot
opearation becomes compatible.
### `-sdy-sharding-constraint-to-reshard`

_Converts ShardingConstraintOp into ReshardOp._
Expand Down
60 changes: 50 additions & 10 deletions shardy/dialect/sdy/ir/attrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -361,28 +361,43 @@ def Sdy_DimensionSharding : AttrDef<Sdy_Dialect, "DimensionSharding"> {
}];
}

// Either a `MeshAttr` or a symbol name, referencing a corresponding `MeshOp`
// symbol.
def Sdy_MeshOrRef : AnyAttrOf<[Sdy_Mesh, FlatSymbolRefAttr]> {
string cppType = "::mlir::Attribute";
}

def Sdy_TensorSharding : AttrDef<Sdy_Dialect, "TensorSharding"> {
let mnemonic = "sharding";
let summary = "Tensor sharding";
let description = [{
A tensor sharding is bound to a specific mesh by its name, and can only
reference axis names from that mesh. The dimension shardings tell us for
each dimension of the tensor, along which axes (or sub-axes) it is sharded
from major to minor. All other axes that don’t shard a dimension are either
implicitly or explicitly (if they appear in the list of replicated axes)
replicated.
A tensor sharding is bound to a specific mesh, and can only reference axis
names from that mesh. The dimension shardings tell us for each dimension of
the tensor, along which axes (or sub-axes) it is sharded from major to
minor. All other axes that don’t shard a dimension are either implicitly or
explicitly (if they appear in the list of replicated axes) replicated.

The mesh this sharding is bound to can either be specified by a symbol
name, referencing a corresponding `MeshOp` symbol, or by an inlined
`MeshAttr`.
}];
let parameters = (ins
"FlatSymbolRefAttr":$mesh_sym_name,
Sdy_MeshOrRef:$mesh_or_ref,
OptionalArrayRefParameter<"DimensionShardingAttr">:$dim_shardings,
Sdy_AxisRefs:$replicated_axes
);
let assemblyFormat = [{
`<` $mesh_sym_name `,` `[` (`]`):($dim_shardings^ `]`)? ``
`<` custom<MeshOrRef>($mesh_or_ref) `,` `[` (`]`):($dim_shardings^ `]`)? ``
(`,` `replicated` `` `=` `` `{` $replicated_axes^ `}`)? `>`
}];

let builders = [
AttrBuilder<(ins "StringAttr":$mesh_name,
"ArrayRef<DimensionShardingAttr>":$dim_shardings,
"ArrayRef<AxisRefAttr>":$replicated_axes), [{
return $_get($_ctxt, FlatSymbolRefAttr::get(mesh_name),
dim_shardings, replicated_axes);
}]>,
AttrBuilder<(ins "StringRef":$mesh_name,
"ArrayRef<DimensionShardingAttr>":$dim_shardings,
"ArrayRef<AxisRefAttr>":$replicated_axes), [{
Expand Down Expand Up @@ -418,10 +433,29 @@ def Sdy_TensorSharding : AttrDef<Sdy_Dialect, "TensorSharding"> {
});
}

// Returns the mesh `FlatSymbolRefAttr` this sharding references, assuming
// is doesn't have an inlined `MeshAttr`.
FlatSymbolRefAttr getMeshSymName() const {
return mlir::cast<FlatSymbolRefAttr>(getMeshOrRef());
}

// Returns the mesh name this sharding references, assuming is doesn't have
// an inlined `MeshAttr`.
StringRef getMeshName() const {
return getMeshSymName().getValue();
}

// If this sharding has an inlined `MeshAttr`, returns it, otherwise looks
// up the mesh symbol with the referenced name in `symbolTable`, and returns
// its `MeshAttr` if it exists in the table, or nullptr otherwise.
MeshAttr getMesh(const SymbolTable& symbolTable) const;

// If this sharding has an inlined `MeshAttr`, returns it, otherwise looks
// up the mesh symbol with the referenced name in the symbol table of the
// enclosing module of `op`, and returns its `MeshAttr` if it exists in the
// table, or nullptr otherwise.
MeshAttr getMesh(Operation* op) const;

// Returns true if all dimension shardings are empty and there are no
// replicated axes.
bool emptyAxes() const;
Expand Down Expand Up @@ -518,11 +552,17 @@ def Sdy_TensorShardingPerValue : AttrDef<Sdy_Dialect, "TensorShardingPerValue">
let assemblyFormat = "`<` `[` (`]`):($shardings^ `]`)? `>`";

let extraClassDeclaration = [{
// Builds a `TensorShardingPerValue` for each type in `types`, with all
// dimension shardings marked open (can be further replicated/sharded).
// Builds a `TensorSharding` for each type in `types`, with all dimension
// shardings marked open (can be further replicated/sharded).
static TensorShardingPerValueAttr getFullyOpen(
MLIRContext* context, TypeRange types, StringRef meshName);

// Builds an open `TensorSharding` for each type in `types`, but
// with the sharding at `index` replaced with `sharding`.
static TensorShardingPerValueAttr getOpenWithShardingAtIndex(
MLIRContext* context, TypeRange types, int64_t index,
TensorShardingAttr sharding);

// Returns whether there are no values.
bool empty() const { return getShardings().empty(); }

Expand Down
20 changes: 12 additions & 8 deletions shardy/dialect/sdy/ir/canonicalization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ limitations under the License.

#include <cassert>
#include <cstdint>
#include <optional>

#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
Expand Down Expand Up @@ -80,21 +80,25 @@ class RedundantManualComputationPattern
private:
LogicalResult matchAndRewrite(ManualComputationOp manualComputationOp,
PatternRewriter& rewriter) const override {
std::optional<StringRef> meshName =
getCommonMeshName(manualComputationOp.getInShardings().getShardings(),
manualComputationOp.getOutShardings().getShardings());
ArrayRef<TensorShardingAttr> inShardings =
manualComputationOp.getInShardings().getShardings();
ArrayRef<TensorShardingAttr> outShardings =
manualComputationOp.getOutShardings().getShardings();

int64_t manualAxesProduct = 1;
if (meshName.has_value()) {
MeshAttr mesh = getMeshAttr(manualComputationOp, *meshName);
assert(mesh && "unknown mesh");
if (!inShardings.empty() && !outShardings.empty()) {
MeshAttr mesh =
getCommonMesh(inShardings, outShardings, manualComputationOp);
for (StringAttr manualAxis : manualComputationOp.getManualAxes()) {
manualAxesProduct *= mesh.getAxisSize(manualAxis);
}
}

if (manualAxesProduct != 1) {
return failure();
return rewriter.notifyMatchFailure(
manualComputationOp, [](Diagnostic& diag) {
diag << "product of manual axis sizes is not 1";
});
}

mlir::InlinerInterface inliner(manualComputationOp.getContext());
Expand Down
8 changes: 8 additions & 0 deletions shardy/dialect/sdy/ir/data_flow_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,14 @@ void setBlockArgumentEdgeOwnerShardings(
shardings);
}

void setOpResultEdgeOwnerShardings(Operation* op,
ArrayRef<TensorShardingAttr> shardings) {
if (auto shardableDataFlowOp = dyn_cast<ShardableDataFlowOpInterface>(op)) {
return shardableDataFlowOp.setOpResultEdgeOwnerShardings(shardings);
}
setShardings(op, shardings);
}

DataFlowEdgeOp getDataFlowEdge(Value target) {
return DataFlowEdgeOp::getDataFlowEdgeUser(getDataFlowEdgeOwner(target));
}
Expand Down
7 changes: 6 additions & 1 deletion shardy/dialect/sdy/ir/data_flow_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,16 @@ SmallVector<Value> getDataFlowSources(DataFlowEdgeOp dataFlowEdge);
// Returns all non-edge-owner targets of the given `dataFlowEdge`.
SmallVector<Value> getNonEdgeOwnerTargets(DataFlowEdgeOp dataFlowEdge);

// Sets the block argument edge owner shardings if the `op` is a
// Sets the block argument edge owner `shardings` if the `op` is a
// `ShardableDataFlowOpInterface`.
void setBlockArgumentEdgeOwnerShardings(Operation* op,
ArrayRef<TensorShardingAttr> shardings);

// Sets the op result edge owner `shardings` if the `op` is a
// `ShardableDataFlowOpInterface`.
void setOpResultEdgeOwnerShardings(Operation* op,
ArrayRef<TensorShardingAttr> shardings);

} // namespace sdy
} // namespace mlir

Expand Down
Loading

0 comments on commit d6f0f01

Please sign in to comment.