Skip to content

Commit

Permalink
[SDY] preprocess sharding groups which have shardings prior to propag…
Browse files Browse the repository at this point in the history
…ation to validate there are no inter-group conflicts.

PiperOrigin-RevId: 675771176
  • Loading branch information
Varcho authored and copybara-github committed Sep 19, 2024
1 parent c1a8669 commit 82bec19
Show file tree
Hide file tree
Showing 4 changed files with 210 additions and 4 deletions.
3 changes: 3 additions & 0 deletions shardy/dialect/sdy/transforms/import/import_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ void addImportPipeline(OpPassManager& pm, StringRef dumpDirectory) {
pm.addNestedPass<func::FuncOp>(createConstantSplitterPass());
pm.addNestedPass<func::FuncOp>(createAddDataFlowEdgesPass());
pm.addNestedPass<func::FuncOp>(createApplyShardingConstraintsPass());
// The sharding group import pass must run after applying sharding
// constraints. This ensures we can detect sharding conflicts between group
// members which have pre-propagation shardings due to sharding constraints.
pm.addPass(createShardingGroupImportPass());
pm.addPass(createImportMaximalShardingPass());

Expand Down
81 changes: 77 additions & 4 deletions shardy/dialect/sdy/transforms/import/sharding_group_import.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ limitations under the License.
#include "mlir/Pass/Pass.h" // IWYU pragma: keep
#include "mlir/Support/LLVM.h"
#include "shardy/dialect/sdy/ir/dialect.h"
#include "shardy/dialect/sdy/ir/utils.h"

namespace mlir {
namespace sdy {
Expand All @@ -41,8 +42,11 @@ using llvm::SmallVector;

using ValueToShardingGroup =
llvm::DenseMap<Value, llvm::SmallVector<ShardingGroupOp>>;
using GroupIdToValues = llvm::DenseMap<int64_t, SmallVector<Value>>;
using TensorShape = ArrayRef<int64_t>;

void unifyShardingGroups(ValueToShardingGroup& tensorToGroups) {
void unifyShardingGroups(ValueToShardingGroup& tensorToGroups,
GroupIdToValues& groupIdToReindexedTensors) {
if (tensorToGroups.empty()) {
return;
}
Expand Down Expand Up @@ -75,6 +79,7 @@ void unifyShardingGroups(ValueToShardingGroup& tensorToGroups) {
for (ShardingGroupOp op : groupsForTensor) {
op.setGroupId(reindexMap[shardingGroupEquivalences.getLeaderValue(
op.getGroupId())]);
groupIdToReindexedTensors[op.getGroupId()].push_back(op.getInput());
}
}
}
Expand All @@ -83,14 +88,15 @@ LogicalResult buildShardingGroupMappingAndValidateGroups(
ModuleOp module, ValueToShardingGroup& tensorToGroups) {
// Map to hold validation info for shard groups within manual computations.
DenseMap<int64_t, ManualComputationOp> groupToManualComp;
DenseMap<int64_t, TensorShape> groupToTensorShape;

// While walking the graph we simultaneously build up the tensorToGroups
// mapping (which will be used for unification) while also validating the
// structure of shard groups meets expectations
WalkResult result = module.walk([&](ShardingGroupOp op) {
tensorToGroups[op.getInput()].push_back(op);

// Validate sharding groups. All values in a group should have either:
// All values in a sharding group should have either:
// 1) No manual computation op parent
// 2) The same manual computation op parent.
// If a group has no manual computation op parent, 'groupToManualComp'
Expand All @@ -108,11 +114,68 @@ LogicalResult buildShardingGroupMappingAndValidateGroups(
return WalkResult::interrupt();
}

// All values in a sharding group should have the same shape. It is possible
// to relax this constraint to just requiring ranks are the same (if we
// are not in conservative mode). However GSPMD requires tensor shapes to be
// equivalent, so we will maintain this stricter requirement for parity.
TensorShape ts = getTensorShape(op.getInput());
auto [ts_it, ts_inserted] = groupToTensorShape.try_emplace(groupId, ts);
if (!ts_inserted && ts_it->getSecond() != ts) {
op.emitError(
"ShardingGroupOps values must have the same shape for groupId: ")
<< groupId;
return WalkResult::interrupt();
}

return WalkResult::advance();
});
return failure(result.wasInterrupted());
}

// Validates that all members of each (already unified) sharding group carry
// consistent pre-propagation shardings, then applies each group's sharding to
// every member so propagation starts from a conflict-free state.
//
// Emits an error on the offending ShardingGroupOp and returns failure if two
// members of the same canonicalized group have different shardings.
LogicalResult validateCompatibilityAndApplyInitialShardingConstraints(
    ModuleOp module, GroupIdToValues& groupIdToValues) {
  // Maps each canonicalized group id to the first non-null sharding seen on
  // one of its members; subsequent members must match it exactly.
  DenseMap<int64_t, TensorShardingAttr> groupIdToSharding;
  // Sharding Constraints will only conflict with Sharding Groups if their
  // value is a member of some sharding group. Because of this it is
  // sufficient to only validate consistency of shardings of values in
  // ShardingGroupOps.
  WalkResult result = module.walk([&](ShardingGroupOp shardingGroupOp) {
    TensorShardingAttr sharding = getSharding(shardingGroupOp.getInput());
    // Conflicts only occur when two or more values in a group have a sharding
    // and those shardings differ; a value with no sharding cannot conflict.
    if (!sharding) {
      return WalkResult::advance();
    }

    int64_t groupId = shardingGroupOp.getGroupId();
    auto [it, inserted] = groupIdToSharding.try_emplace(groupId, sharding);
    if (!inserted && it->second != sharding) {
      shardingGroupOp.emitError(
          "Inconsistent shardings prior to propagation for ShardingGroupOps "
          "with canonicalized groupId: ")
          << groupId;
      return WalkResult::interrupt();
    }
    return WalkResult::advance();
  });

  if (result.wasInterrupted()) {
    return failure();
  }

  // Apply each group's validated sharding to all values in the group. Map
  // entries are only created for non-null shardings above, so no null check
  // is needed here.
  for (auto& [groupId, sharding] : groupIdToSharding) {
    for (Value value : groupIdToValues[groupId]) {
      setSharding(value, sharding);
    }
  }

  return success();
}

struct ShardingGroupImportPass
: public impl::ShardingGroupImportPassBase<ShardingGroupImportPass> {
using ShardingGroupImportPassBase::ShardingGroupImportPassBase;
Expand All @@ -121,12 +184,22 @@ struct ShardingGroupImportPass
// Extract the sharding group ids and tensor -> {group_id} mapping from the
// high level module and validate any sharding group constraints are met.
ValueToShardingGroup tensorToGroups;
if (failed(buildShardingGroupMappingAndValidateGroups(getOperation(),
ModuleOp module = getOperation();
if (failed(buildShardingGroupMappingAndValidateGroups(module,
tensorToGroups))) {
signalPassFailure();
}

unifyShardingGroups(tensorToGroups);
GroupIdToValues groupIdToReindexedTensors;
unifyShardingGroups(tensorToGroups, groupIdToReindexedTensors);

// This pass assumes sharding constraints are already applied to values.
// Compatibility constraints are applied after group unification to detect
// conflicts within the unified groups.
if (failed(validateCompatibilityAndApplyInitialShardingConstraints(
module, groupIdToReindexedTensors))) {
signalPassFailure();
}
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,67 @@ func.func @main(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> {
sdy.sharding_group %0 group_id = 7331 : tensor<8x8xf32>
func.return %0: tensor<8x8xf32>
}

// -----

sdy.mesh @mesh = <["a"=2, "b"=2]>

// Disallow creation of sharding groups which have values with different shapes.
func.func @main(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> {
%0 = stablehlo.add %arg0, %arg0 : tensor<8x8xf32>
// %1 has shape 8x8x1, differing from the 8x8 shape of the other members of
// group 23, which must trigger the same-shape validation error below.
%1 = stablehlo.constant dense<0.0> : tensor<8x8x1xf32>
sdy.sharding_group %arg0 group_id = 23 : tensor<8x8xf32>
sdy.sharding_group %0 group_id = 23 : tensor<8x8xf32>
// expected-error@below {{ShardingGroupOps values must have the same shape for groupId: 23}}
sdy.sharding_group %1 group_id = 23 : tensor<8x8x1xf32>
func.return %0: tensor<8x8xf32>
}

// -----

sdy.mesh @mesh = <["a"=2, "b"=2]>

// Throw error for sharding groups which have incompatible shardings inferred
// from initial constraints.
func.func @main(
%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>},
%arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>}) {
// %arg0 and %arg1 are in the same group but carry different initial
// shardings ({"a"} vs {"b"}), which must be rejected before propagation.
// Sharding Group and Sharding Constraint compatibility checks happen after
// unification + canonicalization of group ids, which is why the group id
// below (555) corresponds to group id: 0 in the check-error.
sdy.sharding_group %arg0 group_id = 555 : tensor<8x8xf32>
// expected-error@below {{Inconsistent shardings prior to propagation for ShardingGroupOps with canonicalized groupId: 0}}
sdy.sharding_group %arg1 group_id = 555 : tensor<8x8xf32>
func.return
}

// -----

sdy.mesh @mesh = <["a"=2, "b"=2]>

// Throw error for sharding groups which have incompatible shardings inferred
// from initial constraints.
func.func @main(
%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {"b"}]>},
%arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>}) {

%0 = stablehlo.constant dense<0.0> : tensor<8x8xf32>
%1 = stablehlo.constant dense<0.0> : tensor<8x8xf32>

sdy.sharding_group %arg0 group_id = 10 : tensor<8x8xf32>
sdy.sharding_group %0 group_id = 10 : tensor<8x8xf32>
sdy.sharding_group %0 group_id = 20 : tensor<8x8xf32>
sdy.sharding_group %1 group_id = 20 : tensor<8x8xf32>

// The sharding group below will cause the above sharding groups to be merged;
// by transitivity this implies that all of {%arg0, %arg1, %0, %1} should have
// the same sharding. Note that %0 and %1 are compatible by themselves but
// %arg0 and %arg1 are not due to their initial shardings.
sdy.sharding_group %1 group_id = 30 : tensor<8x8xf32>
// expected-error@below {{Inconsistent shardings prior to propagation for ShardingGroupOps with canonicalized groupId: 0}}
sdy.sharding_group %arg1 group_id = 30 : tensor<8x8xf32>
func.return
}

Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,69 @@ func.func @sharding_groups_reindex_ordering_matches_min_element_ordering(%arg0:
sdy.sharding_group %arg2 group_id = 123456 : tensor<4xf32>
func.return
}

// -----

sdy.mesh @mesh = <["a"=2, "b"=2]>

// CHECK-LABEL: set_existing_shardings_for_sharding_group_members
func.func @set_existing_shardings_for_sharding_group_members(
%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {"b"}]>},
%arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {"b"}]>}) {
// The two sharded members of group 43210 agree, so their sharding is applied
// to %0 (which initially has none), as verified by the CHECK below.
// CHECK: %cst = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {"b"}]>]>} dense<0.000000e+00> : tensor<8x8xf32>
%0 = stablehlo.constant dense<0.0> : tensor<8x8xf32>

sdy.sharding_group %arg0 group_id = 43210 : tensor<8x8xf32>
sdy.sharding_group %arg1 group_id = 43210 : tensor<8x8xf32>
sdy.sharding_group %0 group_id = 43210 : tensor<8x8xf32>
func.return
}

// -----

sdy.mesh @mesh = <["a"=2, "b"=2]>

// CHECK-LABEL: transitively_update_shardings_for_sharding_group_members
func.func @transitively_update_shardings_for_sharding_group_members(
%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>},
%arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>}) {
// Groups 10, 20 and 30 are merged via shared members, so the (consistent)
// sharding of %arg0/%arg1 is applied to both unsharded constants.
// CHECK: %cst = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}]>]>} dense<0.000000e+00> : tensor<8x8xf32>
// CHECK: %cst_0 = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}]>]>} dense<0.000000e+00> : tensor<8x8xf32>
%0 = stablehlo.constant dense<0.0> : tensor<8x8xf32>
%1 = stablehlo.constant dense<0.0> : tensor<8x8xf32>

sdy.sharding_group %arg0 group_id = 10 : tensor<8x8xf32>
sdy.sharding_group %0 group_id = 10 : tensor<8x8xf32>
sdy.sharding_group %0 group_id = 20 : tensor<8x8xf32>
sdy.sharding_group %1 group_id = 20 : tensor<8x8xf32>
sdy.sharding_group %1 group_id = 30 : tensor<8x8xf32>
sdy.sharding_group %arg1 group_id = 30 : tensor<8x8xf32>
func.return
}

// -----

sdy.mesh @mesh = <["a"=2, "b"=2]>

// CHECK-LABEL: set_existing_shards_for_disjoint_groups
func.func @set_existing_shards_for_disjoint_groups(
%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>},
%arg1: tensor<8x8xf32>,
%arg2: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"b"}]>},
%arg3: tensor<8x8xf32>) {
// Each disjoint group gets only its own member's sharding: group 11111 gets
// %arg0's, group 22222 gets %arg2's, and %2 (in no group) stays unsharded.
// CHECK: %cst = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}]>]>} dense<0.000000e+00> : tensor<8x8xf32>
%0 = stablehlo.constant dense<0.0> : tensor<8x8xf32>
// CHECK: %cst_0 = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"b"}]>]>} dense<0.000000e+00> : tensor<8x8xf32>
%1 = stablehlo.constant dense<0.0> : tensor<8x8xf32>
// CHECK: %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<8x8xf32>
%2 = stablehlo.constant dense<0.0> : tensor<8x8xf32>

sdy.sharding_group %arg0 group_id = 11111 : tensor<8x8xf32>
sdy.sharding_group %arg1 group_id = 11111 : tensor<8x8xf32>
sdy.sharding_group %0 group_id = 11111 : tensor<8x8xf32>

sdy.sharding_group %arg2 group_id = 22222 : tensor<8x8xf32>
sdy.sharding_group %arg3 group_id = 22222 : tensor<8x8xf32>
sdy.sharding_group %1 group_id = 22222 : tensor<8x8xf32>
func.return
}

0 comments on commit 82bec19

Please sign in to comment.