Skip to content

Commit

Permalink
[SDY] Add validation to assert that sharding group values cannot cros…
Browse files Browse the repository at this point in the history
…s in between (and out of) manual computation ops.

PiperOrigin-RevId: 673727822
  • Loading branch information
Varcho authored and copybara-github committed Sep 17, 2024
1 parent 23fca62 commit 9825af4
Show file tree
Hide file tree
Showing 7 changed files with 328 additions and 109 deletions.
2 changes: 1 addition & 1 deletion shardy/dialect/sdy/transforms/import/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ cc_library(
"constant_splitter.cc",
"import_maximal_sharding.cc",
"import_pipeline.cc",
"sharding_group_unification.cc",
"sharding_group_import.cc",
],
hdrs = [
"passes.h",
Expand Down
2 changes: 1 addition & 1 deletion shardy/dialect/sdy/transforms/import/import_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ void addImportPipeline(OpPassManager& pm, StringRef dumpDirectory) {
pm.addNestedPass<func::FuncOp>(createConstantSplitterPass());
pm.addNestedPass<func::FuncOp>(createAddDataFlowEdgesPass());
pm.addNestedPass<func::FuncOp>(createApplyShardingConstraintsPass());
pm.addPass(createShardingGroupUnificationPass());
pm.addPass(createShardingGroupImportPass());
pm.addPass(createImportMaximalShardingPass());

GreedyRewriteConfig config;
Expand Down
29 changes: 19 additions & 10 deletions shardy/dialect/sdy/transforms/import/passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -85,17 +85,26 @@ def ConstantSplitterPass : Pass<"sdy-constant-splitter", "func::FuncOp"> {
let dependentDialects = ["mlir::sdy::SdyDialect"];
}

def ShardingGroupUnificationPass : Pass<"sdy-sharding-group-unification", "ModuleOp"> {
let summary = "Combines sharding groups to reduce them to a minimum set of canonical group ids.";
def ShardingGroupImportPass : Pass<"sdy-sharding-group-import", "ModuleOp"> {
let summary = "Canonicalization and validation pass for sharding groups.";
let description = [{
Combines sharding groups using the transitive property of group membership.

Any time that a tensor T is in a sharding group G1 *and* sharding group G2,
then we can infer that all members in G1 and G2 should be sharded in the
same way. Thus we can combine G1 and G2 into a single group.

The set of canonical group ids after merging will be 0,1,...N-1 for the
minimum set of groups.
Applies canonicalization and validation to sharding groups upon import.
Namely these are:

1) Sharding Group Unification -
Combines sharding groups using the transitive property of group
membership. Any time that a tensor T is in a sharding group G1 *and*
sharding group G2, then we can infer that all members in G1 and G2 should
be sharded in the same way. Thus we can combine G1 and G2 into a single
group. The set of canonical group ids after merging will be 0,1,...N-1
for the minimum set of groups.

2) Sharding Group Validation
Validates that sharding groups are well formed and conform to assumptions
within the implementation. This currently asserts that if a sharding
group contains a `Value` defined inside the block of a
`ManualComputationOp`, then all other values in that group must reside in
the same block.
}];
let dependentDialects = ["mlir::sdy::SdyDialect"];
}
Expand Down
136 changes: 136 additions & 0 deletions shardy/dialect/sdy/transforms/import/sharding_group_import.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/* Copyright 2024 The Shardy Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <memory> // IWYU pragma: keep

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Pass/Pass.h" // IWYU pragma: keep
#include "mlir/Support/LLVM.h"
#include "shardy/dialect/sdy/ir/dialect.h"

namespace mlir {
namespace sdy {

#define GEN_PASS_DEF_SHARDINGGROUPIMPORTPASS
#include "shardy/dialect/sdy/transforms/import/passes.h.inc"

namespace {

using llvm::DenseMap;
using llvm::EquivalenceClasses;
using llvm::SmallDenseMap;
using llvm::SmallVector;

using ValueToShardingGroup =
llvm::DenseMap<Value, llvm::SmallVector<ShardingGroupOp>>;

void unifyShardingGroups(ValueToShardingGroup& tensorToGroups) {
if (tensorToGroups.empty()) {
return;
}
// Merge the equivalence classes of group ids which had the same tensors
// within them. (unionSets uses the default comparator and will consider the
// minimum group_id as the representative element of the equivalence class).
EquivalenceClasses<int64_t> shardingGroupEquivalences;
for (auto& [_, groupsForTensor] : tensorToGroups) {
const int64_t canonicalId = groupsForTensor.front().getGroupId();
for (ShardingGroupOp group : groupsForTensor) {
shardingGroupEquivalences.unionSets(canonicalId, group.getGroupId());
}
}

// After merging groups we reindex the group IDs so that they take values
// from the set {0,1,...,N-1} (N is the number of equivalence classes).
// The leader element of each equivalent class corresponds to the minimum
// group_id, so by looping over the group leaders in order their reindexed
// ids can be set to maintain the same relative ordering.
int64_t reindexId = 0;
SmallDenseMap<int64_t, int64_t> reindexMap;
for (const auto& group : shardingGroupEquivalences) {
if (group.isLeader()) {
reindexMap[group.getData()] = reindexId++;
}
}

// Update the graph to replace group_ids with their canonical id.
for (auto& [_, groupsForTensor] : tensorToGroups) {
for (ShardingGroupOp op : groupsForTensor) {
op.setGroupId(reindexMap[shardingGroupEquivalences.getLeaderValue(
op.getGroupId())]);
}
}
}

LogicalResult buildShardingGroupMappingAndValidateGroups(
ModuleOp module, ValueToShardingGroup& tensorToGroups) {
// Map to hold validation info for shard groups within manual computations.
DenseMap<int64_t, ManualComputationOp> groupToManualComp;

// While walking the graph we simultaneously build up the tensorToGroups
// mapping (which will be used for unification) while also validating the
// structure of shard groups meets expectations
WalkResult result = module.walk([&](ShardingGroupOp op) {
tensorToGroups[op.getInput()].push_back(op);

// Validate sharding groups. All values in a group should have either:
// 1) No manual computation op parent
// 2) The same manual computation op parent.
// If a group has no manual computation op parent, 'groupToManualComp'
// will map it to nullptr and ensure all other values in that group are
// also mapped to nullptr.
auto parent = op->getParentOfType<ManualComputationOp>();
int64_t groupId = op.getGroupId();

auto [it, inserted] = groupToManualComp.try_emplace(groupId, parent);
if (!inserted && it->second != parent) {
op.emitError(
"ShardingGroupOps values cannot cross ManualComputationOp "
"boundaries for groupId: ")
<< groupId;
return WalkResult::interrupt();
}

return WalkResult::advance();
});
return failure(result.wasInterrupted());
}

struct ShardingGroupImportPass
: public impl::ShardingGroupImportPassBase<ShardingGroupImportPass> {
using ShardingGroupImportPassBase::ShardingGroupImportPassBase;

void runOnOperation() final {
// Extract the sharding group ids and tensor -> {group_id} mapping from the
// high level module and validate any sharding group constrainst are met.
ValueToShardingGroup tensorToGroups;
if (failed(buildShardingGroupMappingAndValidateGroups(getOperation(),
tensorToGroups))) {
signalPassFailure();
}

unifyShardingGroups(tensorToGroups);
}
};

} // namespace

} // namespace sdy
} // namespace mlir
96 changes: 0 additions & 96 deletions shardy/dialect/sdy/transforms/import/sharding_group_unification.cc

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: sdy_opt -split-input-file %s -sdy-sharding-group-unification | FileCheck %s
// RUN: sdy_opt -split-input-file %s -sdy-sharding-group-import | FileCheck %s

// CHECK-LABEL: sharding_groups_no_overlap
func.func @sharding_groups_no_overlap(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) {
Expand Down
Loading

0 comments on commit 9825af4

Please sign in to comment.