From 9825af404a0415bcef1ecaa85ae5ae836f86cc5f Mon Sep 17 00:00:00 2001 From: Bill Varcho Date: Thu, 12 Sep 2024 01:14:34 -0700 Subject: [PATCH] [SDY] Add validation to assert that sharding group values cannot cross in between (and out of) manual computation ops. PiperOrigin-RevId: 673727822 --- shardy/dialect/sdy/transforms/import/BUILD | 2 +- .../sdy/transforms/import/import_pipeline.cc | 2 +- .../dialect/sdy/transforms/import/passes.td | 29 +-- .../import/sharding_group_import.cc | 136 ++++++++++++++ .../import/sharding_group_unification.cc | 96 ---------- ...cation.mlir => sharding_group_import.mlir} | 2 +- ...ding_group_manual_computation_barrier.mlir | 170 ++++++++++++++++++ 7 files changed, 328 insertions(+), 109 deletions(-) create mode 100644 shardy/dialect/sdy/transforms/import/sharding_group_import.cc delete mode 100644 shardy/dialect/sdy/transforms/import/sharding_group_unification.cc rename shardy/dialect/sdy/transforms/import/test/{sharding_group_unification.mlir => sharding_group_import.mlir} (97%) create mode 100644 shardy/dialect/sdy/transforms/import/test/sharding_group_manual_computation_barrier.mlir diff --git a/shardy/dialect/sdy/transforms/import/BUILD b/shardy/dialect/sdy/transforms/import/BUILD index fc7ad44..b11540a 100644 --- a/shardy/dialect/sdy/transforms/import/BUILD +++ b/shardy/dialect/sdy/transforms/import/BUILD @@ -40,7 +40,7 @@ cc_library( "constant_splitter.cc", "import_maximal_sharding.cc", "import_pipeline.cc", - "sharding_group_unification.cc", + "sharding_group_import.cc", ], hdrs = [ "passes.h", diff --git a/shardy/dialect/sdy/transforms/import/import_pipeline.cc b/shardy/dialect/sdy/transforms/import/import_pipeline.cc index fc589ba..5ff4027 100644 --- a/shardy/dialect/sdy/transforms/import/import_pipeline.cc +++ b/shardy/dialect/sdy/transforms/import/import_pipeline.cc @@ -36,7 +36,7 @@ void addImportPipeline(OpPassManager& pm, StringRef dumpDirectory) { pm.addNestedPass(createConstantSplitterPass()); pm.addNestedPass(createAddDataFlowEdgesPass()); pm.addNestedPass(createApplyShardingConstraintsPass()); - pm.addPass(createShardingGroupUnificationPass()); + pm.addPass(createShardingGroupImportPass()); pm.addPass(createImportMaximalShardingPass()); GreedyRewriteConfig config; diff --git a/shardy/dialect/sdy/transforms/import/passes.td b/shardy/dialect/sdy/transforms/import/passes.td index 1156b5a..93cc197 100644 --- a/shardy/dialect/sdy/transforms/import/passes.td +++ b/shardy/dialect/sdy/transforms/import/passes.td @@ -85,17 +85,26 @@ def ConstantSplitterPass : Pass<"sdy-constant-splitter", "func::FuncOp"> { let dependentDialects = ["mlir::sdy::SdyDialect"]; } -def ShardingGroupUnificationPass : Pass<"sdy-sharding-group-unification", "ModuleOp"> { - let summary = "Combines sharding groups to reduce them to a minimum set of canonical group ids."; +def ShardingGroupImportPass : Pass<"sdy-sharding-group-import", "ModuleOp"> { + let summary = "Canonicalization and validation pass for sharding groups."; let description = [{ - Combines sharding groups using the transitive property of group membership. - - Any time that a tensor T is in a sharding group G1 *and* sharding group G2, - then we can infer that all members in G1 and G2 should be sharded in the - same way. Thus we can combine G1 and G2 into a single group. - - The set of canonical group ids after merging will be 0,1,...N-1 for the - minimum set of groups. + Applies canonicalization and validation to sharding groups upon import. + Namely these are: + + 1) Sharding Group Unification - + Combines sharding groups using the transitive property of group + membership. Any time that a tensor T is in a sharding group G1 *and* + sharding group G2, then we can infer that all members in G1 and G2 should + be sharded in the same way. Thus we can combine G1 and G2 into a single + group. The set of canonical group ids after merging will be 0,1,...N-1 + for the minimum set of groups. + + 2) Sharding Group Validation + Validates that sharding groups are well formed and conform to assumptions + within the implementation. This currently asserts that if a sharding + group contains a `Value` defined inside the block of a + `ManualComputationOp`, then all other values in that group must reside in + the same block. }]; let dependentDialects = ["mlir::sdy::SdyDialect"]; } diff --git a/shardy/dialect/sdy/transforms/import/sharding_group_import.cc b/shardy/dialect/sdy/transforms/import/sharding_group_import.cc new file mode 100644 index 0000000..8b4f2d1 --- /dev/null +++ b/shardy/dialect/sdy/transforms/import/sharding_group_import.cc @@ -0,0 +1,136 @@ +/* Copyright 2024 The Shardy Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include // IWYU pragma: keep + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/EquivalenceClasses.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/Pass.h" // IWYU pragma: keep +#include "mlir/Support/LLVM.h" +#include "shardy/dialect/sdy/ir/dialect.h" + +namespace mlir { +namespace sdy { + +#define GEN_PASS_DEF_SHARDINGGROUPIMPORTPASS +#include "shardy/dialect/sdy/transforms/import/passes.h.inc" + +namespace { + +using llvm::DenseMap; +using llvm::EquivalenceClasses; +using llvm::SmallDenseMap; +using llvm::SmallVector; + +using ValueToShardingGroup = + llvm::DenseMap>; + +void unifyShardingGroups(ValueToShardingGroup& tensorToGroups) { + if (tensorToGroups.empty()) { + return; + } + // Merge the equivalence classes of group ids which had the same tensors + // within them. (unionSets uses the default comparator and will consider the + // minimum group_id as the representative element of the equivalence class). + EquivalenceClasses shardingGroupEquivalences; + for (auto& [_, groupsForTensor] : tensorToGroups) { + const int64_t canonicalId = groupsForTensor.front().getGroupId(); + for (ShardingGroupOp group : groupsForTensor) { + shardingGroupEquivalences.unionSets(canonicalId, group.getGroupId()); + } + } + + // After merging groups we reindex the group IDs so that they take values + // from the set {0,1,...,N-1} (N is the number of equivalence classes). + // The leader element of each equivalent class corresponds to the minimum + // group_id, so by looping over the group leaders in order their reindexed + // ids can be set to maintain the same relative ordering. + int64_t reindexId = 0; + SmallDenseMap reindexMap; + for (const auto& group : shardingGroupEquivalences) { + if (group.isLeader()) { + reindexMap[group.getData()] = reindexId++; + } + } + + // Update the graph to replace group_ids with their canonical id. + for (auto& [_, groupsForTensor] : tensorToGroups) { + for (ShardingGroupOp op : groupsForTensor) { + op.setGroupId(reindexMap[shardingGroupEquivalences.getLeaderValue( + op.getGroupId())]); + } + } +} + +LogicalResult buildShardingGroupMappingAndValidateGroups( + ModuleOp module, ValueToShardingGroup& tensorToGroups) { + // Map to hold validation info for shard groups within manual computations. + DenseMap groupToManualComp; + + // While walking the graph we simultaneously build up the tensorToGroups + // mapping (which will be used for unification) while also validating the + // structure of shard groups meets expectations + WalkResult result = module.walk([&](ShardingGroupOp op) { + tensorToGroups[op.getInput()].push_back(op); + + // Validate sharding groups. All values in a group should have either: + // 1) No manual computation op parent + // 2) The same manual computation op parent. + // If a group has no manual computation op parent, 'groupToManualComp' + // will map it to nullptr and ensure all other values in that group are + // also mapped to nullptr. + auto parent = op->getParentOfType(); + int64_t groupId = op.getGroupId(); + + auto [it, inserted] = groupToManualComp.try_emplace(groupId, parent); + if (!inserted && it->second != parent) { + op.emitError( + "ShardingGroupOps values cannot cross ManualComputationOp " + "boundaries for groupId: ") + << groupId; + return WalkResult::interrupt(); + } + + return WalkResult::advance(); + }); + return failure(result.wasInterrupted()); +} + +struct ShardingGroupImportPass + : public impl::ShardingGroupImportPassBase { + using ShardingGroupImportPassBase::ShardingGroupImportPassBase; + + void runOnOperation() final { + // Extract the sharding group ids and tensor -> {group_id} mapping from the + // high level module and validate any sharding group constrainst are met. + ValueToShardingGroup tensorToGroups; + if (failed(buildShardingGroupMappingAndValidateGroups(getOperation(), + tensorToGroups))) { + signalPassFailure(); + } + + unifyShardingGroups(tensorToGroups); + } +}; + +} // namespace + +} // namespace sdy +} // namespace mlir diff --git a/shardy/dialect/sdy/transforms/import/sharding_group_unification.cc b/shardy/dialect/sdy/transforms/import/sharding_group_unification.cc deleted file mode 100644 index 1cdb695..0000000 --- a/shardy/dialect/sdy/transforms/import/sharding_group_unification.cc +++ /dev/null @@ -1,96 +0,0 @@ -/* Copyright 2024 The Shardy Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include // IWYU pragma: keep - -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/EquivalenceClasses.h" -#include "llvm/ADT/SmallVector.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Value.h" -#include "mlir/Pass/Pass.h" // IWYU pragma: keep -#include "mlir/Support/LLVM.h" -#include "shardy/dialect/sdy/ir/dialect.h" - -namespace mlir { -namespace sdy { - -#define GEN_PASS_DEF_SHARDINGGROUPUNIFICATIONPASS -#include "shardy/dialect/sdy/transforms/import/passes.h.inc" - -namespace { - -using llvm::DenseMap; -using llvm::EquivalenceClasses; -using llvm::SmallDenseMap; -using llvm::SmallVector; - -struct ShardingGroupUnificationPass - : public impl::ShardingGroupUnificationPassBase< - ShardingGroupUnificationPass> { - using ShardingGroupUnificationPassBase::ShardingGroupUnificationPassBase; - - void runOnOperation() final { - // Extract the sharding group ids and tensor -> {group_id} mapping from the - // high level module, and initialize the equivalence classes for the group - // ids present. - DenseMap> tensorToGroups; - ModuleOp module = getOperation(); - module.walk([&](ShardingGroupOp op) { - tensorToGroups[op.getInput()].push_back(op); - }); - if (tensorToGroups.empty()) { - return; - } - - // Merge the equivalence classes of group ids which had the same tensors - // within them. (unionSets uses the default comparator and will consider the - // minimum group_id as the representative element of the equivalence class). - EquivalenceClasses shardingGroupEquivalences; - for (auto& [_, groupsForTensor] : tensorToGroups) { - const int64_t canonicalId = groupsForTensor.front().getGroupId(); - for (ShardingGroupOp group : groupsForTensor) { - shardingGroupEquivalences.unionSets(canonicalId, group.getGroupId()); - } - } - - // After merging groups we reindex the group IDs so that they take values - // from the set {0,1,...,N-1} (N is the number of equivalence classes). - // The leader element of each equivalent class corresponds to the minimum - // group_id, so by looping over the group leaders in order their reindexed - // ids can be set to maintain the same relative ordering. - int64_t reindexId = 0; - SmallDenseMap reindexMap; - for (const auto& group : shardingGroupEquivalences) { - if (group.isLeader()) { - reindexMap[group.getData()] = reindexId++; - } - } - - // Update the graph to replace group_ids with their canonical id. - for (auto& [_, groupsForTensor] : tensorToGroups) { - for (ShardingGroupOp op : groupsForTensor) { - op.setGroupId(reindexMap[shardingGroupEquivalences.getLeaderValue( - op.getGroupId())]); - } - } - } -}; - -} // namespace - -} // namespace sdy -} // namespace mlir diff --git a/shardy/dialect/sdy/transforms/import/test/sharding_group_unification.mlir b/shardy/dialect/sdy/transforms/import/test/sharding_group_import.mlir similarity index 97% rename from shardy/dialect/sdy/transforms/import/test/sharding_group_unification.mlir rename to shardy/dialect/sdy/transforms/import/test/sharding_group_import.mlir index a858fe5..7cd8589 100644 --- a/shardy/dialect/sdy/transforms/import/test/sharding_group_unification.mlir +++ b/shardy/dialect/sdy/transforms/import/test/sharding_group_import.mlir @@ -1,4 +1,4 @@ -// RUN: sdy_opt -split-input-file %s -sdy-sharding-group-unification | FileCheck %s +// RUN: sdy_opt -split-input-file %s -sdy-sharding-group-import | FileCheck %s // CHECK-LABEL: sharding_groups_no_overlap func.func @sharding_groups_no_overlap(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) { diff --git a/shardy/dialect/sdy/transforms/import/test/sharding_group_manual_computation_barrier.mlir b/shardy/dialect/sdy/transforms/import/test/sharding_group_manual_computation_barrier.mlir new file mode 100644 index 0000000..4c386fc --- /dev/null +++ b/shardy/dialect/sdy/transforms/import/test/sharding_group_manual_computation_barrier.mlir @@ -0,0 +1,170 @@ +// RUN: sdy_opt %s -split-input-file -sdy-sharding-group-import -verify-diagnostics + +sdy.mesh @mesh = <["a"=2, "b"=2]> + +// Allow sharding groups where group values don't cross ManualComputationOps +// barrier. +func.func @main(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { + %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{?}, {?}]>] out_shardings=[<@mesh, [{?}, {?}]>] manual_axes={} (%arg1: tensor<8x8xf32>) { + %1 = stablehlo.add %arg1, %arg1 : tensor<8x8xf32> + sdy.sharding_group %1 group_id = 8675 : tensor<8x8xf32> + sdy.return %1 : tensor<8x8xf32> + } : (tensor<8x8xf32>) -> tensor<8x8xf32> + + sdy.sharding_group %0 group_id = 309 : tensor<8x8xf32> + func.return %0: tensor<8x8xf32> +} + +// ----- + +sdy.mesh @mesh = <["a"=2, "b"=2]> + +// Don't permit the creation of sharding groups which mix values who have parent +// ManualComputationOps with those that don't. +func.func @main(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { + %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{?}, {?}]>] out_shardings=[<@mesh, [{?}, {?}]>] manual_axes={} (%arg1: tensor<8x8xf32>) { + %1 = stablehlo.add %arg1, %arg1 : tensor<8x8xf32> + sdy.sharding_group %1 group_id = 90210 : tensor<8x8xf32> + sdy.return %1 : tensor<8x8xf32> + } : (tensor<8x8xf32>) -> tensor<8x8xf32> + + // expected-error@below {{ShardingGroupOps values cannot cross ManualComputationOp boundaries for groupId: 90210}} + sdy.sharding_group %0 group_id = 90210 : tensor<8x8xf32> + func.return %0: tensor<8x8xf32> +} + +// ----- + +sdy.mesh @mesh = <["a"=2, "b"=2]> + +// Don't permit the creation of sharding groups which have different manual +// computation op parents +func.func @main(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { + %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{?}, {?}]>] out_shardings=[<@mesh, [{?}, {?}]>] manual_axes={} (%arg1: tensor<8x8xf32>) { + %1 = stablehlo.add %arg1, %arg1 : tensor<8x8xf32> + sdy.sharding_group %1 group_id = 44094 : tensor<8x8xf32> + sdy.return %1 : tensor<8x8xf32> + } : (tensor<8x8xf32>) -> tensor<8x8xf32> + + %4 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{?}, {?}]>] out_shardings=[<@mesh, [{?}, {?}]>] manual_axes={} (%2: tensor<8x8xf32>) { + %3 = stablehlo.add %2, %2 : tensor<8x8xf32> + // expected-error@below {{ShardingGroupOps values cannot cross ManualComputationOp boundaries for groupId: 44094}} + sdy.sharding_group %3 group_id = 44094 : tensor<8x8xf32> + sdy.return %3 : tensor<8x8xf32> + } : (tensor<8x8xf32>) -> tensor<8x8xf32> + + func.return %4: tensor<8x8xf32> +} + +// ----- + +sdy.mesh @mesh = <["a"=2, "b"=2]> + +// Allow sharding groups with values that remain in the same level of a nested +// manual computation +func.func @main(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { + %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{?}, {?}]>] out_shardings=[<@mesh, [{?}, {?}]>] manual_axes={} (%arg1: tensor<8x8xf32>) { + %1 = stablehlo.add %arg1, %arg1 : tensor<8x8xf32> + %2 = stablehlo.add %1, %arg1 : tensor<8x8xf32> + + %3 = sdy.manual_computation(%2) in_shardings=[<@mesh, [{?}, {?}]>] out_shardings=[<@mesh, [{?}, {?}]>] manual_axes={} (%arg2: tensor<8x8xf32>) { + %4 = stablehlo.add %arg2, %arg2 : tensor<8x8xf32> + %5 = stablehlo.add %4, %arg2 : tensor<8x8xf32> + sdy.sharding_group %4 group_id = 1881 : tensor<8x8xf32> + sdy.sharding_group %5 group_id = 1881 : tensor<8x8xf32> + sdy.return %5 : tensor<8x8xf32> + } : (tensor<8x8xf32>) -> tensor<8x8xf32> + + sdy.sharding_group %1 group_id = 8008 : tensor<8x8xf32> + sdy.sharding_group %2 group_id = 8008 : tensor<8x8xf32> + sdy.return %1 : tensor<8x8xf32> + } : (tensor<8x8xf32>) -> tensor<8x8xf32> + func.return %0: tensor<8x8xf32> +} + +// ----- + +sdy.mesh @mesh = <["a"=2, "b"=2]> + +// Don't allow sharding groups with values at different levels of blocks within +// a nested manual computation. +func.func @main(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { + %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{?}, {?}]>] out_shardings=[<@mesh, [{?}, {?}]>] manual_axes={} (%arg1: tensor<8x8xf32>) { + %1 = stablehlo.add %arg1, %arg1 : tensor<8x8xf32> + %2 = sdy.manual_computation(%1) in_shardings=[<@mesh, [{?}, {?}]>] out_shardings=[<@mesh, [{?}, {?}]>] manual_axes={} (%arg2: tensor<8x8xf32>) { + %3 = stablehlo.add %arg2, %arg2 : tensor<8x8xf32> + sdy.sharding_group %3 group_id = 4311 : tensor<8x8xf32> + sdy.return %3 : tensor<8x8xf32> + } : (tensor<8x8xf32>) -> tensor<8x8xf32> + + // expected-error@below {{ShardingGroupOps values cannot cross ManualComputationOp boundaries for groupId: 4311}} + sdy.sharding_group %1 group_id = 4311 : tensor<8x8xf32> + sdy.return %1 : tensor<8x8xf32> + } : (tensor<8x8xf32>) -> tensor<8x8xf32> + func.return %0: tensor<8x8xf32> +} + +// ----- + +sdy.mesh @mesh = <["a"=2, "b"=2]> + +// Allow sharding groups within ManualComputationOp which has a WhileOp and +// 1) Some group members are outside of the WhileOp and some are inside +// 2) All ops have the same parent ManualComputationOp +func.func @main(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { + %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{?}, {?}]>] out_shardings=[<@mesh, [{?}, {?}]>] manual_axes={} (%arg1: tensor<8x8xf32>) { + %1 = stablehlo.add %arg1, %arg1 : tensor<8x8xf32> + %2 = stablehlo.add %1, %arg1 : tensor<8x8xf32> + sdy.sharding_group %1 group_id = 1337 : tensor<8x8xf32> + sdy.sharding_group %2 group_id = 1337 : tensor<8x8xf32> + + %3 = stablehlo.constant dense<0> : tensor + %4 = stablehlo.constant dense<1> : tensor + %5 = stablehlo.constant dense<32> : tensor + %6:2 = stablehlo.while(%iterArg = %2, %iterArg_2 = %3) : tensor<8x8xf32>, tensor + cond { + %7 = stablehlo.compare LT, %iterArg_2, %5 : (tensor, tensor) -> tensor + stablehlo.return %7 : tensor + } do { + %7 = sdy.data_flow_edge %iterArg sharding=<@mesh, [{"a"}, {}]> : tensor<8x8xf32> + %8 = stablehlo.add %iterArg_2, %4 : tensor + %9 = stablehlo.add %7, %7 : tensor<8x8xf32> + sdy.sharding_group %9 group_id = 1337 : tensor<8x8xf32> + stablehlo.return %9, %8 : tensor<8x8xf32>, tensor + } + sdy.return %2 : tensor<8x8xf32> + } : (tensor<8x8xf32>) -> tensor<8x8xf32> + func.return %0: tensor<8x8xf32> +} + +// ----- + +sdy.mesh @mesh = <["a"=2, "b"=2]> + +// Disallow sharding groups which cross the barrier of a ManualComputationOp +// (and also a while op). +func.func @main(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { + %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{?}, {?}]>] out_shardings=[<@mesh, [{?}, {?}]>] manual_axes={} (%arg1: tensor<8x8xf32>) { + %1 = stablehlo.add %arg1, %arg1 : tensor<8x8xf32> + %2 = stablehlo.add %1, %arg1 : tensor<8x8xf32> + %3 = stablehlo.constant dense<0> : tensor + %4 = stablehlo.constant dense<1> : tensor + %5 = stablehlo.constant dense<32> : tensor + %6:2 = stablehlo.while(%iterArg = %2, %iterArg_2 = %3) : tensor<8x8xf32>, tensor + cond { + %7 = stablehlo.compare LT, %iterArg_2, %5 : (tensor, tensor) -> tensor + stablehlo.return %7 : tensor + } do { + %7 = sdy.data_flow_edge %iterArg sharding=<@mesh, [{"a"}, {}]> : tensor<8x8xf32> + %8 = stablehlo.add %iterArg_2, %4 : tensor + %9 = stablehlo.add %7, %7 : tensor<8x8xf32> + sdy.sharding_group %9 group_id = 7331 : tensor<8x8xf32> + stablehlo.return %9, %8 : tensor<8x8xf32>, tensor + } + sdy.return %2 : tensor<8x8xf32> + } : (tensor<8x8xf32>) -> tensor<8x8xf32> + + // expected-error@below {{ShardingGroupOps values cannot cross ManualComputationOp boundaries for groupId: 7331}} + sdy.sharding_group %0 group_id = 7331 : tensor<8x8xf32> + func.return %0: tensor<8x8xf32> +}