diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index ed6047cd2b8f2..509398da979e8 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1,28 +1 @@ Auto generated patch. Do not edit or delete it, even if empty. -diff -ruN --strip-trailing-cr a/clang/test/CodeGen/X86/avx10_2_512satcvtds-builtins-errors.c b/clang/test/CodeGen/X86/avx10_2_512satcvtds-builtins-errors.c ---- a/clang/test/CodeGen/X86/avx10_2_512satcvtds-builtins-errors.c -+++ b/clang/test/CodeGen/X86/avx10_2_512satcvtds-builtins-errors.c -@@ -1,4 +1,4 @@ --// RUN: %clang_cc1 -flax-vector-conversions=none -ffreestanding %s -triple=i386-unknown-unknown -target-feature +avx10.2-512 -emit-llvm -Wall -Werror -verify -+// RUN: %clang_cc1 -flax-vector-conversions=none -ffreestanding %s -triple=i386-unknown-unknown -target-feature +avx10.2-512 -Wall -Werror -verify - - #include - #include -diff -ruN --strip-trailing-cr a/clang/test/CodeGen/X86/avx10_2_512satcvtds-builtins-x64-error.c b/clang/test/CodeGen/X86/avx10_2_512satcvtds-builtins-x64-error.c ---- a/clang/test/CodeGen/X86/avx10_2_512satcvtds-builtins-x64-error.c -+++ b/clang/test/CodeGen/X86/avx10_2_512satcvtds-builtins-x64-error.c -@@ -1,4 +1,4 @@ --// RUN: %clang_cc1 -flax-vector-conversions=none -ffreestanding %s -triple=x86_64-unknown-unknown -target-feature +avx10.2-512 -emit-llvm -Wall -Werror -verify -+// RUN: %clang_cc1 -flax-vector-conversions=none -ffreestanding %s -triple=x86_64-unknown-unknown -target-feature +avx10.2-512 -Wall -Werror -verify - - #include - #include -diff -ruN --strip-trailing-cr a/clang/test/CodeGen/X86/avx10_2satcvtds-builtins-errors.c b/clang/test/CodeGen/X86/avx10_2satcvtds-builtins-errors.c ---- a/clang/test/CodeGen/X86/avx10_2satcvtds-builtins-errors.c -+++ b/clang/test/CodeGen/X86/avx10_2satcvtds-builtins-errors.c -@@ -1,4 +1,4 @@ --// RUN: %clang_cc1 -flax-vector-conversions=none -ffreestanding %s -triple=i386-unknown-unknown -target-feature +avx10.2-256 -emit-llvm -Wall -Werror -verify -+// RUN: %clang_cc1 -flax-vector-conversions=none -ffreestanding %s -triple=i386-unknown-unknown -target-feature +avx10.2-256 -Wall -Werror -verify - - unsigned long long test_mm_cvttssd(unsigned long long __A) { - return _mm_cvttssd(__A); // expected-error {{call to undeclared function '_mm_cvttssd'}} diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index aaf1f1a2db970..2f731c7a2e4da 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "f0b3287297aeeddcf030e3c1b08d05a69ad465aa" - LLVM_SHA256 = "3bc65e7a760a389f5ace1146cb2ffde724a272e97e71c8b8509149e827df6c83" + LLVM_COMMIT = "36adf8ecedb64047021265a1e1730773d3b3a9e8" + LLVM_SHA256 = "7baedfc21f67f64f054482cbe77cb3049cd4428187cd45799e10ff8eb03dc9f6" tf_http_archive( name = name, diff --git a/third_party/shardy/temporary.patch b/third_party/shardy/temporary.patch index 5ddce7071d49f..101b6b56e4ca4 100644 --- a/third_party/shardy/temporary.patch +++ b/third_party/shardy/temporary.patch @@ -1,48 +1,496 @@ -diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch -index 509398d..ed6047c 100644 ---- a/third_party/llvm/generated.patch -+++ b/third_party/llvm/generated.patch -@@ -1 +1,28 @@ - Auto generated patch. Do not edit or delete it, even if empty. -+diff -ruN --strip-trailing-cr a/clang/test/CodeGen/X86/avx10_2_512satcvtds-builtins-errors.c b/clang/test/CodeGen/X86/avx10_2_512satcvtds-builtins-errors.c -+--- a/clang/test/CodeGen/X86/avx10_2_512satcvtds-builtins-errors.c -++++ b/clang/test/CodeGen/X86/avx10_2_512satcvtds-builtins-errors.c -+@@ -1,4 +1,4 @@ -+-// RUN: %clang_cc1 -flax-vector-conversions=none -ffreestanding %s -triple=i386-unknown-unknown -target-feature +avx10.2-512 -emit-llvm -Wall -Werror -verify -++// RUN: %clang_cc1 -flax-vector-conversions=none -ffreestanding %s -triple=i386-unknown-unknown -target-feature +avx10.2-512 -Wall -Werror -verify -+ -+ #include -+ #include -+diff -ruN --strip-trailing-cr a/clang/test/CodeGen/X86/avx10_2_512satcvtds-builtins-x64-error.c b/clang/test/CodeGen/X86/avx10_2_512satcvtds-builtins-x64-error.c -+--- a/clang/test/CodeGen/X86/avx10_2_512satcvtds-builtins-x64-error.c -++++ b/clang/test/CodeGen/X86/avx10_2_512satcvtds-builtins-x64-error.c -+@@ -1,4 +1,4 @@ -+-// RUN: %clang_cc1 -flax-vector-conversions=none -ffreestanding %s -triple=x86_64-unknown-unknown -target-feature +avx10.2-512 -emit-llvm -Wall -Werror -verify -++// RUN: %clang_cc1 -flax-vector-conversions=none -ffreestanding %s -triple=x86_64-unknown-unknown -target-feature +avx10.2-512 -Wall -Werror -verify -+ -+ #include -+ #include -+diff -ruN --strip-trailing-cr a/clang/test/CodeGen/X86/avx10_2satcvtds-builtins-errors.c b/clang/test/CodeGen/X86/avx10_2satcvtds-builtins-errors.c -+--- a/clang/test/CodeGen/X86/avx10_2satcvtds-builtins-errors.c -++++ b/clang/test/CodeGen/X86/avx10_2satcvtds-builtins-errors.c -+@@ -1,4 +1,4 @@ -+-// RUN: %clang_cc1 -flax-vector-conversions=none -ffreestanding %s -triple=i386-unknown-unknown -target-feature +avx10.2-256 -emit-llvm -Wall -Werror -verify -++// RUN: %clang_cc1 -flax-vector-conversions=none -ffreestanding %s -triple=i386-unknown-unknown -target-feature +avx10.2-256 -Wall -Werror -verify -+ -+ unsigned long long test_mm_cvttssd(unsigned long long __A) { -+ return _mm_cvttssd(__A); // expected-error {{call to undeclared function '_mm_cvttssd'}} +diff --git a/shardy/dialect/sdy/ir/data_flow_utils.cc b/shardy/dialect/sdy/ir/data_flow_utils.cc +index e2cea1f..e53adf5 100644 +--- a/shardy/dialect/sdy/ir/data_flow_utils.cc ++++ b/shardy/dialect/sdy/ir/data_flow_utils.cc +@@ -31,6 +31,11 @@ namespace sdy { + + namespace { + ++bool isDataFlowOp(Operation* op) { ++ return isa(op); ++} ++ + // Gets the owning op if it is a shardable data flow op interface op. + ShardableDataFlowOpInterface getOwningShardableDataFlowOp(Value value) { + return dyn_cast(getOwningOp(value)); +@@ -71,12 +76,6 @@ Value getDataFlowEdgeOwner(OpOperand& source) { + } + + } // namespace +- +-bool isDataFlowOp(Operation* op) { +- return isa(op); +-} +- + ResultRange getDataFlowEdgeResultOwners(Operation* op) { + if (auto shardableDataFlowOp = dyn_cast(op)) { + return shardableDataFlowOp.getOpResultEdgeOwners(); +@@ -88,19 +87,6 @@ ResultRange getDataFlowEdgeResultOwners(Operation* op) { + return ResultRange(nullptr, 0); + } + +-ArrayRef getDataFlowEdgeBlockArgumentOwners(Operation* op) { +- if (auto shardableDataFlowOp = dyn_cast(op)) { +- return shardableDataFlowOp.getBlockArgumentEdgeOwners(); +- } +- return {}; +-} +- +-void setBlockArgumentEdgeOwnerShardings( +- Operation* op, ArrayRef shardings) { +- cast(op).setBlockArgumentEdgeOwnerShardings( +- shardings); +-} +- + DataFlowEdgeOp getDataFlowEdge(Value target) { + return DataFlowEdgeOp::getDataFlowEdgeUser(getDataFlowEdgeOwner(target)); + } +diff --git a/shardy/dialect/sdy/ir/data_flow_utils.h b/shardy/dialect/sdy/ir/data_flow_utils.h +index 99ffde8..992a0b8 100644 +--- a/shardy/dialect/sdy/ir/data_flow_utils.h ++++ b/shardy/dialect/sdy/ir/data_flow_utils.h +@@ -35,18 +35,10 @@ namespace sdy { + + // See `DataFlowEdgeOp` documentation for more information on data-flow edges. + +-// Returns true if the `op` defines data flow edges, e.g. , it's a +-// `ShardableDataFlowOpInterface`. +-bool isDataFlowOp(Operation* op); +- + // If `op` has data-flow edges, returns their op result edge owners (e.g., all + // results of a while/case op), otherwise returns an empty range. + ResultRange getDataFlowEdgeResultOwners(Operation* op); + +-// If `op` is a `ShardableDataFlowOpInterface` which can have block argument +-// edge owners, returns the owners, otherwise returns an empty range. +-ArrayRef getDataFlowEdgeBlockArgumentOwners(Operation* op); +- + // If `target` is a target of a data-flow edge, returns the corresponding + // `DataFlowEdgeOp`, otherwise returns `nullptr`. + DataFlowEdgeOp getDataFlowEdge(Value target); +@@ -62,11 +54,6 @@ SmallVector getDataFlowSources(DataFlowEdgeOp dataFlowEdge); + void forEachNonEdgeOwnerDataFlowTarget(DataFlowEdgeOp dataFlowEdge, + std::function fn); + +-// Sets the block argument edge owner shardings if the `op` is a +-// `ShardableDataFlowOpInterface`. +-void setBlockArgumentEdgeOwnerShardings(Operation* op, +- ArrayRef shardings); +- + } // namespace sdy + } // namespace mlir + +diff --git a/shardy/dialect/sdy/ir/parsers.cc b/shardy/dialect/sdy/ir/parsers.cc +index c48bb2a..7748bc2 100644 +--- a/shardy/dialect/sdy/ir/parsers.cc ++++ b/shardy/dialect/sdy/ir/parsers.cc +@@ -199,7 +199,7 @@ FailureOr parseFactorSymbolIndex(AsmParser& parser, + ParseResult parseSymbolIndices(AsmParser& parser, StringRef factorsStr, + SmallVector& indices) { + while (!factorsStr.empty()) { +- // TODO(bartchr): Add SDY_ASSIGN_OR_RETURN_FAILURE macro for re-returning ++ // TODO(bartchr): Add ASSIGN_OR_RETURN_FAILURE macro for re-returning + // failures. Or check if there already is one in MLIR. + FailureOr index = parseFactorSymbolIndex(parser, factorsStr); + if (failed(index)) { +diff --git a/shardy/dialect/sdy/transforms/common/macros.h b/shardy/dialect/sdy/transforms/common/macros.h +index c69c1e2..c16502a 100644 +--- a/shardy/dialect/sdy/transforms/common/macros.h ++++ b/shardy/dialect/sdy/transforms/common/macros.h +@@ -17,9 +17,8 @@ limitations under the License. + #define SHARDY_DIALECT_SDY_TRANSFORMS_COMMON_MACROS_H_ + + // Macro to assign value from std::optional or return std::nullopt. +-#define SDY_ASSIGN_OR_RETURN_IF_NULLOPT(lhs, expr) \ +- SDY_ASSIGN_OR_RETURN_IF_NULLOPT_IMPL(CONCAT_(_expr_result, __LINE__), lhs, \ +- expr) ++#define ASSIGN_OR_RETURN_IF_NULLOPT(lhs, expr) \ ++ ASSIGN_OR_RETURN_IF_NULLOPT_IMPL(CONCAT_(_expr_result, __LINE__), lhs, expr) + + // ================================================================= + // == Implementation details, do not rely on anything below here. == +@@ -28,11 +27,11 @@ limitations under the License. + #define CONCAT_INNER_(x, y) x##y + #define CONCAT_(x, y) CONCAT_INNER_(x, y) + +-#define SDY_ASSIGN_OR_RETURN_IF_NULLOPT_IMPL(result, lhs, expr) \ +- auto result = expr; \ +- if (!result.has_value()) { \ +- return std::nullopt; \ +- } \ ++#define ASSIGN_OR_RETURN_IF_NULLOPT_IMPL(result, lhs, expr) \ ++ auto result = expr; \ ++ if (!result.has_value()) { \ ++ return std::nullopt; \ ++ } \ + lhs = std::move(result).value(); + + #endif // SHARDY_DIALECT_SDY_TRANSFORMS_COMMON_MACROS_H_ +diff --git a/shardy/dialect/sdy/transforms/export/sink_data_flow_edges.cc b/shardy/dialect/sdy/transforms/export/sink_data_flow_edges.cc +index 9b2f62a..ab4d281 100644 +--- a/shardy/dialect/sdy/transforms/export/sink_data_flow_edges.cc ++++ b/shardy/dialect/sdy/transforms/export/sink_data_flow_edges.cc +@@ -15,19 +15,20 @@ limitations under the License. + + #include + #include // IWYU pragma: keep ++#include + + #include "llvm/ADT/STLExtras.h" +-#include "llvm/ADT/SmallVector.h" + #include "mlir/Dialect/Func/IR/FuncOps.h" ++#include "mlir/IR/MLIRContext.h" + #include "mlir/IR/Operation.h" + #include "mlir/IR/PatternMatch.h" + #include "mlir/IR/Value.h" +-#include "mlir/IR/ValueRange.h" +-#include "mlir/IR/Visitors.h" + #include "mlir/Pass/Pass.h" // IWYU pragma: keep ++#include "mlir/Rewrite/FrozenRewritePatternSet.h" + #include "mlir/Support/LLVM.h" ++#include "mlir/Support/LogicalResult.h" ++#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + #include "shardy/dialect/sdy/ir/constants.h" +-#include "shardy/dialect/sdy/ir/data_flow_utils.h" + #include "shardy/dialect/sdy/ir/dialect.h" + #include "shardy/dialect/sdy/ir/utils.h" + +@@ -39,85 +40,96 @@ namespace sdy { + + namespace { + +-// Gets a vector of `TensorShardingAttr` for the given edge owner. +-// +-// Each value in `edgeOwners` is the owner of a data flow edge. If the data flow +-// edge already has a sharding, we will copy the sharding. Otherwise, if one +-// of the owners in `edgeOwners` has a sharding, we create a fully open sharding +-// with the mesh name of the first such sharding for all the other values that +-// don't have a sharding. +-SmallVector getShardingsFromDataFlowEdges( +- ValueRange edgeOwners) { +- SmallVector shardings; +- shardings.reserve(edgeOwners.size()); +- +- StringRef meshName; +- for (Value edgeOwner : edgeOwners) { +- TensorShardingAttr sharding; +- if (DataFlowEdgeOp dataFlowEdgeOp = +- DataFlowEdgeOp::getDataFlowEdgeUser(edgeOwner)) { +- sharding = dataFlowEdgeOp.getShardingAttr(); +- if (sharding && meshName.empty()) { +- meshName = sharding.getMeshName(); ++// This pattern matches on a specific `DataFlowEdgeOp`, but will also sink any ++// other `DataFlowEdgeOp` whose input is defined by the same op. This way we can ++// build the `TensorShardingPerValueAttr` for the defining op once. ++class SinkDataFlowEdgesPattern : public OpRewritePattern { ++ public: ++ using OpRewritePattern::OpRewritePattern; ++ ++ private: ++ LogicalResult matchAndRewrite(DataFlowEdgeOp dataFlowEdgeOp, ++ PatternRewriter& rewriter) const override { ++ Operation* defOp = dataFlowEdgeOp.getInput().getDefiningOp(); ++ if (!defOp) { ++ // `dataFlowEdgeOp` takes a block argument, we ignore the sharding of ++ // `dataFlowEdgeOp` since a block argument can't have a sharding attached. ++ // TODO(tomnatan): we might need to revisit this for future use cases. ++ rewriter.replaceOp(dataFlowEdgeOp, dataFlowEdgeOp.getInput()); ++ return success(); ++ } ++ ++ SmallVector shardings(defOp->getNumResults()); ++ ++ // For each result of `defOp` that is used by a `DataFlowEdgeOp`: ++ // - If the `DataFlowEdgeOp` has a sharding, add it to `shardings`. ++ // - Replace the `DataFlowEdgeOp` with its input. ++ // ++ // In addition, stores the mesh name of first encountered sharding, as we ++ // need a mesh name to replace missing shardings with fully replicated ++ // shardings. Note that it's ok to pick an arbitrary mesh if there are ++ // multiple, as we are creating fully replicated shardings. ++ StringRef meshName; ++ for (auto [index, result] : llvm::enumerate(defOp->getResults())) { ++ // We can assume a `DataFlowEdgeOp` will be the only user of its input. ++ DataFlowEdgeOp dataFlowEdgeOp = ++ DataFlowEdgeOp::getDataFlowEdgeUser(result); ++ if (!dataFlowEdgeOp) { ++ continue; ++ } ++ if (TensorShardingAttr sharding = dataFlowEdgeOp.getShardingAttr()) { ++ shardings[index] = sharding; ++ if (meshName.empty()) { ++ meshName = sharding.getMeshName(); ++ } + } ++ rewriter.replaceOp(dataFlowEdgeOp, dataFlowEdgeOp.getInput()); + } +- shardings.push_back(sharding); +- } +- if (meshName.empty()) { +- return {}; +- } +- // There is at least one `DataFlowEdgeOp` with a sharding. +- // Replace all empty shardings with fully open shardings. +- // NOTE: this will replace the existing edgeOwner's sharding, if any, though +- // this shouldn't happen as as `sdy-add-data-flow-edges` would have copied it. +- for (auto [sharding, edgeOwner] : llvm::zip_equal(shardings, edgeOwners)) { +- if (!sharding) { +- sharding = TensorShardingAttr::getFullyOpen( +- edgeOwner.getContext(), getTensorRank(edgeOwner), meshName); ++ ++ if (!meshName.empty()) { ++ // There is at least one `DataFlowEdgeOp` with a sharding. ++ // Replace all empty shardings with fully open shardings. ++ for (auto [sharding, result] : ++ llvm::zip(shardings, defOp->getResults())) { ++ if (!sharding) { ++ sharding = getOrCreateSharding(result, meshName); ++ } ++ } ++ defOp->setAttr(kShardingAttr, TensorShardingPerValueAttr::get( ++ defOp->getContext(), shardings)); + } ++ ++ return success(); + } +- return shardings; +-} ++}; + + struct SinkDataFlowEdgesPass + : public impl::SinkDataFlowEdgesPassBase { + using SinkDataFlowEdgesPassBase::SinkDataFlowEdgesPassBase; + ++ LogicalResult initialize(MLIRContext* context) final { ++ config.useTopDownTraversal = true; ++ config.enableRegionSimplification = ++ mlir::GreedySimplifyRegionLevel::Disabled; ++ config.maxIterations = 2; ++ ++ RewritePatternSet patternsInternal(context); ++ patternsInternal.add(context); ++ patterns = std::move(patternsInternal); ++ ++ return success(); ++ } ++ + void runOnOperation() final { +- func::FuncOp funcOp = getOperation(); +- IRRewriter rewriter(funcOp); +- // Copy the sharding from data flow edges to the data flow ops. +- funcOp.walk([&](Operation* op) { +- // Since we are doing the walk in preorder with a forward iterator, ops +- // are walked before their users and regions. Since `DataFlowEdgeOp` can +- // only appear inside the data flow op's region or as its user, we always +- // encounter the data flow op before their data flow edges. This means it +- // is safe to erase the `DataFlowEdgeOp` at this point. We need the skip +- // at the end because it's a condition to erase the op. See the +- // documentation for `Operation::walk` for more details. +- if (isa(op)) { +- DataFlowEdgeOp dataFlowEdgeOp = cast(op); +- rewriter.replaceOp(dataFlowEdgeOp, dataFlowEdgeOp.getInput()); +- return WalkResult::skip(); +- } +- if (!isDataFlowOp(op)) { +- return WalkResult::advance(); +- } +- if (SmallVector blockArgShardings = +- getShardingsFromDataFlowEdges( +- getDataFlowEdgeBlockArgumentOwners(op)); +- !blockArgShardings.empty()) { +- setBlockArgumentEdgeOwnerShardings(op, blockArgShardings); +- } +- if (SmallVector resultShardings = +- getShardingsFromDataFlowEdges(getDataFlowEdgeResultOwners(op)); +- !resultShardings.empty()) { +- op->setAttr(kShardingAttr, TensorShardingPerValueAttr::get( +- op->getContext(), resultShardings)); +- } +- return WalkResult::advance(); +- }); ++ if (failed( ++ applyPatternsAndFoldGreedily(getOperation(), patterns, config))) { ++ signalPassFailure(); ++ } + } ++ ++ private: ++ FrozenRewritePatternSet patterns; ++ GreedyRewriteConfig config; + }; + + } // namespace +diff --git a/shardy/dialect/sdy/transforms/export/test/sink_data_flow_edges.mlir b/shardy/dialect/sdy/transforms/export/test/sink_data_flow_edges.mlir +index eee582c..194b28b 100644 +--- a/shardy/dialect/sdy/transforms/export/test/sink_data_flow_edges.mlir ++++ b/shardy/dialect/sdy/transforms/export/test/sink_data_flow_edges.mlir +@@ -103,7 +103,7 @@ func.func @all_edges_have_sharding(%arg0: tensor<32x96xf32>, %arg1: tensor<32x96 + func.func @missing_edge(%arg0: tensor<32x96xf32>, %arg1: tensor<32x96xf32>) + -> (tensor<32x96xf32>, tensor<32x96xf32>) { + // CHECK-NEXT: %[[OPT_BARRIER:.*]]:2 = stablehlo.optimization_barrier +- // CHECK-SAME: {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{?}, {?}]>, <@mesh, [{"a", ?}, {}]>]>} ++ // CHECK-SAME: {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"b"}, {?}]>, <@mesh, [{"a", ?}, {}]>]>} + // CHECK-NEXT: return %[[OPT_BARRIER]]#0, %[[OPT_BARRIER]]#1 + %0:2 = stablehlo.optimization_barrier + {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"b"}, {?}]>, <@mesh, [{?}, {}]>]>} +@@ -134,7 +134,7 @@ func.func @sharding_overrided(%arg0: tensor<32x96xf32>, %arg1: tensor<32x96xf32> + func.func @edge_missing_sharding(%arg0: tensor<32x96xf32>, %arg1: tensor<32x96xf32>) + -> (tensor<32x96xf32>, tensor<32x96xf32>) { + // CHECK-NEXT: %[[OPT_BARRIER:.*]]:2 = stablehlo.optimization_barrier +- // CHECK-SAME: {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{?}, {?}]>, <@mesh, [{"a", ?}, {}]>]>} ++ // CHECK-SAME: {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"b"}, {?}]>, <@mesh, [{"a", ?}, {}]>]>} + // CHECK-NEXT: return %[[OPT_BARRIER]]#0, %[[OPT_BARRIER]]#1 + %0:2 = stablehlo.optimization_barrier + {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"b"}, {?}]>, <@mesh, [{?}, {}]>]>} +diff --git a/shardy/dialect/sdy/transforms/import/add_data_flow_edges.cc b/shardy/dialect/sdy/transforms/import/add_data_flow_edges.cc +index ee9ee1d..4f5ac97 100644 +--- a/shardy/dialect/sdy/transforms/import/add_data_flow_edges.cc ++++ b/shardy/dialect/sdy/transforms/import/add_data_flow_edges.cc +@@ -63,7 +63,12 @@ struct AddDataFlowEdgesPass + funcOp.walk([&](Operation* op) { + // Add the data flow edges for result owners and block argument owners. + addDataFlowEdges(getDataFlowEdgeResultOwners(op), rewriter); +- addDataFlowEdges(getDataFlowEdgeBlockArgumentOwners(op), rewriter); ++ if (auto shardableDataFlowOpInterface = ++ dyn_cast(op)) { ++ addDataFlowEdges( ++ shardableDataFlowOpInterface.getBlockArgumentEdgeOwners(), ++ rewriter); ++ } + }); + } + }; +diff --git a/shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.cc b/shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.cc +index bb05f11..812dc73 100644 +--- a/shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.cc ++++ b/shardy/dialect/sdy/transforms/propagation/basic_factor_propagation.cc +@@ -47,8 +47,8 @@ std::optional getPrefixWithoutOverlap( + AxisRefAttr axisRef, ArrayRef otherAxisRefs) { + AxisRefAttr result = axisRef; + for (AxisRefAttr otherAxisRef : otherAxisRefs) { +- SDY_ASSIGN_OR_RETURN_IF_NULLOPT( +- result, result.getPrefixWithoutOverlap(otherAxisRef)); ++ ASSIGN_OR_RETURN_IF_NULLOPT(result, ++ result.getPrefixWithoutOverlap(otherAxisRef)); + } + return result; + } +@@ -62,9 +62,9 @@ BasicFactorPropagation::compatiblePrefixNoConflictsAcrossFactors( + AxisRefAttr result = axisRef; + for (const auto& [otherFactorIndex, shardings] : factorIndexToSharding) { + if (otherFactorIndex != factorIndex) { +- SDY_ASSIGN_OR_RETURN_IF_NULLOPT( ++ ASSIGN_OR_RETURN_IF_NULLOPT( + result, getPrefixWithoutOverlap(result, shardings.overflowAxes)); +- SDY_ASSIGN_OR_RETURN_IF_NULLOPT( ++ ASSIGN_OR_RETURN_IF_NULLOPT( + result, getPrefixWithoutOverlap(result, shardings.axisRefs)); + } + } +@@ -78,8 +78,8 @@ BasicFactorPropagation::compatiblePrefixNoConflictsWithinFactor( + int64_t factorSize) const { + AxisRefAttr result = axisRef; + +- SDY_ASSIGN_OR_RETURN_IF_NULLOPT( +- result, getPrefixWithoutOverlap(result, replicatedAxes)); ++ ASSIGN_OR_RETURN_IF_NULLOPT(result, ++ getPrefixWithoutOverlap(result, replicatedAxes)); + + ArrayRef factorAxes = factorSharding.axisRefs; + if (llvm::any_of(factorAxes, [&](AxisRefAttr shardingAxis) { +@@ -323,9 +323,9 @@ std::optional BasicFactorPropagation::compatiblePrefix( + const FactorIndexToSharding& factorIndexToSharding = + tensorFactorSharding.factorIndexToSharding; + +- SDY_ASSIGN_OR_RETURN_IF_NULLOPT( +- AxisRefAttr result, compatiblePrefixNoConflictsAcrossFactors( +- axisRef, factorIndexToSharding, factorIndex)); ++ ASSIGN_OR_RETURN_IF_NULLOPT(AxisRefAttr result, ++ compatiblePrefixNoConflictsAcrossFactors( ++ axisRef, factorIndexToSharding, factorIndex)); + + auto factorShardingIt = factorIndexToSharding.find(factorIndex); + if (factorShardingIt == factorIndexToSharding.end()) { +@@ -351,7 +351,7 @@ std::optional BasicFactorPropagation::compatiblePrefix( + for (const TensorFactorShardings& tensorFactorSharding : + llvm::concat(projection.getOperands(), + projection.getResults())) { +- SDY_ASSIGN_OR_RETURN_IF_NULLOPT( ++ ASSIGN_OR_RETURN_IF_NULLOPT( + result, compatiblePrefix(result, tensorFactorSharding, factorIndex, + shardedSize, factorSize)); + } diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index 2f731c7..aaf1f1a 100644 +index aab5a85..2f731c7 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" -- LLVM_COMMIT = "36adf8ecedb64047021265a1e1730773d3b3a9e8" -- LLVM_SHA256 = "7baedfc21f67f64f054482cbe77cb3049cd4428187cd45799e10ff8eb03dc9f6" -+ LLVM_COMMIT = "f0b3287297aeeddcf030e3c1b08d05a69ad465aa" -+ LLVM_SHA256 = "3bc65e7a760a389f5ace1146cb2ffde724a272e97e71c8b8509149e827df6c83" +- LLVM_COMMIT = "3cd01371e007b2a8fe32e5d8ce1154057e5e1c2e" +- LLVM_SHA256 = "3d1ee3e896689b5ff2e8cc547e554c59bc70d1101ede9f25be9ca53d9dc409b9" ++ LLVM_COMMIT = "36adf8ecedb64047021265a1e1730773d3b3a9e8" ++ LLVM_SHA256 = "7baedfc21f67f64f054482cbe77cb3049cd4428187cd45799e10ff8eb03dc9f6" tf_http_archive( name = name, +diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch +index a506670..0f83544 100755 +--- a/third_party/stablehlo/temporary.patch ++++ b/third_party/stablehlo/temporary.patch +@@ -1,29 +1,33 @@ + diff --ruN a/stablehlo/stablehlo/conversions/tosa/tests/unary.mlir b/stablehlo/stablehlo/conversions/tosa/tests/unary.mlir + --- stablehlo/stablehlo/conversions/tosa/tests/unary.mlir + +++ stablehlo/stablehlo/conversions/tosa/tests/unary.mlir +-@@ -119,13 +119,15 @@ +- return %0 : tensor<10xf32> +- } ++@@ -121,8 +121,8 @@ + +--// CHECK-LABEL: @transpose +--func.func @transpose(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { ++ // CHECK-LABEL: @transpose ++ func.func @transpose(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { + - // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<[2, 1, 0]> : tensor<3xi64>}> : () -> tensor<3xi64> + - // CHECK-DAG: %[[VAR1:.*]] = tosa.transpose %arg0, %[[VAR0]] +-- %0 = "stablehlo.transpose"(%arg0) {permutation = array} : (tensor<1x2x3xf32>) -> tensor<3x2x1xf32> +-- return %0 : tensor<3x2x1xf32> +--} +-+// TODO: https://github.com/llvm/llvm-project/pull/108133 breaks the test, +-+// need to investigate this. +-+// disableCHECK-LABEL: @transpose +-+// func.func @transpose(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { +-+ // disableCHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<[2, 1, 0]> : tensor<3xi64>}> : () -> tensor<3xi64> +-+ // disableCHECK-DAG: %[[VAR1:.*]] = tosa.transpose %arg0, %[[VAR0]] +-+ // %0 = "stablehlo.transpose"(%arg0) {permutation = array} : (tensor<1x2x3xf32>) -> tensor<3x2x1xf32> +-+ // return %0 : tensor<3x2x1xf32> +-+// } +++ // CHECK: %[[VAR0:.*]] = "tosa.const"() <{value = dense<[2, 1, 0]> : tensor<3xi32>}> : () -> tensor<3xi32> +++ // CHECK: %[[VAR1:.*]] = tosa.transpose %arg0, %[[VAR0]] ++ %0 = "stablehlo.transpose"(%arg0) {permutation = array} : (tensor<1x2x3xf32>) -> tensor<3x2x1xf32> ++ return %0 : tensor<3x2x1xf32> ++ } ++diff --ruN a/stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp b/stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp ++--- stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp +++++ stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp ++@@ -451,9 +451,10 @@ + +- // CHECK-LABEL: @while +- func.func @while(%arg0: tensor) -> tensor { ++ auto perms = op.getPermutation(); ++ auto type = RankedTensorType::get({static_cast(perms.size())}, ++- rewriter.getI64Type()); +++ rewriter.getI32Type()); +++ std::vector perms_int32(perms.begin(), perms.end()); ++ auto constOp = rewriter.create( ++- op->getLoc(), type, DenseIntElementsAttr::get(type, perms)); +++ op->getLoc(), type, DenseIntElementsAttr::get(type, perms_int32)); ++ rewriter.replaceOpWithNewOp(op, op.getType(), ++ op.getOperand(), constOp); ++ return success(); + diff --ruN a/stablehlo/stablehlo/dialect/Version.cpp b/stablehlo/stablehlo/dialect/Version.cpp + --- stablehlo/stablehlo/dialect/Version.cpp + +++ stablehlo/stablehlo/dialect/Version.cpp diff --git a/third_party/shardy/workspace.bzl b/third_party/shardy/workspace.bzl index e80cf84d1a1e8..8b5f84dffb563 100644 --- a/third_party/shardy/workspace.bzl +++ b/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "1142aa484b8e11440f609859a6a55edfb733f869" - SHARDY_SHA256 = "ecb2037f2e098488d2061f988b6088d8834460e61fe73ffc32ea5e9c6692b96c" + SHARDY_COMMIT = "17543e47ff8adb535322df2a9ad9e3171c9741c9" + SHARDY_SHA256 = "13ee573236d481bc7996a01c44ece6e70ad17a3b643e925155c830f4fbc09094" tf_http_archive( name = "shardy", diff --git a/third_party/tsl/third_party/llvm/generated.patch b/third_party/tsl/third_party/llvm/generated.patch index ed6047cd2b8f2..509398da979e8 100644 --- a/third_party/tsl/third_party/llvm/generated.patch +++ b/third_party/tsl/third_party/llvm/generated.patch @@ -1,28 +1 @@ Auto generated patch. Do not edit or delete it, even if empty. -diff -ruN --strip-trailing-cr a/clang/test/CodeGen/X86/avx10_2_512satcvtds-builtins-errors.c b/clang/test/CodeGen/X86/avx10_2_512satcvtds-builtins-errors.c ---- a/clang/test/CodeGen/X86/avx10_2_512satcvtds-builtins-errors.c -+++ b/clang/test/CodeGen/X86/avx10_2_512satcvtds-builtins-errors.c -@@ -1,4 +1,4 @@ --// RUN: %clang_cc1 -flax-vector-conversions=none -ffreestanding %s -triple=i386-unknown-unknown -target-feature +avx10.2-512 -emit-llvm -Wall -Werror -verify -+// RUN: %clang_cc1 -flax-vector-conversions=none -ffreestanding %s -triple=i386-unknown-unknown -target-feature +avx10.2-512 -Wall -Werror -verify - - #include - #include -diff -ruN --strip-trailing-cr a/clang/test/CodeGen/X86/avx10_2_512satcvtds-builtins-x64-error.c b/clang/test/CodeGen/X86/avx10_2_512satcvtds-builtins-x64-error.c ---- a/clang/test/CodeGen/X86/avx10_2_512satcvtds-builtins-x64-error.c -+++ b/clang/test/CodeGen/X86/avx10_2_512satcvtds-builtins-x64-error.c -@@ -1,4 +1,4 @@ --// RUN: %clang_cc1 -flax-vector-conversions=none -ffreestanding %s -triple=x86_64-unknown-unknown -target-feature +avx10.2-512 -emit-llvm -Wall -Werror -verify -+// RUN: %clang_cc1 -flax-vector-conversions=none -ffreestanding %s -triple=x86_64-unknown-unknown -target-feature +avx10.2-512 -Wall -Werror -verify - - #include - #include -diff -ruN --strip-trailing-cr a/clang/test/CodeGen/X86/avx10_2satcvtds-builtins-errors.c b/clang/test/CodeGen/X86/avx10_2satcvtds-builtins-errors.c ---- a/clang/test/CodeGen/X86/avx10_2satcvtds-builtins-errors.c -+++ b/clang/test/CodeGen/X86/avx10_2satcvtds-builtins-errors.c -@@ -1,4 +1,4 @@ --// RUN: %clang_cc1 -flax-vector-conversions=none -ffreestanding %s -triple=i386-unknown-unknown -target-feature +avx10.2-256 -emit-llvm -Wall -Werror -verify -+// RUN: %clang_cc1 -flax-vector-conversions=none -ffreestanding %s -triple=i386-unknown-unknown -target-feature +avx10.2-256 -Wall -Werror -verify - - unsigned long long test_mm_cvttssd(unsigned long long __A) { - return _mm_cvttssd(__A); // expected-error {{call to undeclared function '_mm_cvttssd'}} diff --git a/third_party/tsl/third_party/llvm/workspace.bzl b/third_party/tsl/third_party/llvm/workspace.bzl index aaf1f1a2db970..2f731c7a2e4da 100644 --- a/third_party/tsl/third_party/llvm/workspace.bzl +++ b/third_party/tsl/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "f0b3287297aeeddcf030e3c1b08d05a69ad465aa" - LLVM_SHA256 = "3bc65e7a760a389f5ace1146cb2ffde724a272e97e71c8b8509149e827df6c83" + LLVM_COMMIT = "36adf8ecedb64047021265a1e1730773d3b3a9e8" + LLVM_SHA256 = "7baedfc21f67f64f054482cbe77cb3049cd4428187cd45799e10ff8eb03dc9f6" tf_http_archive( name = name,