diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch
index f803c97b5f40bf..8b137891791fe9 100755
--- a/third_party/stablehlo/temporary.patch
+++ b/third_party/stablehlo/temporary.patch
@@ -1,385 +1 @@
-diff --ruN a/stablehlo/BUILD.bazel b/stablehlo/BUILD.bazel
---- stablehlo/BUILD.bazel
-+++ stablehlo/BUILD.bazel
-@@ -1112,11 +1112,13 @@
-         "stablehlo/transforms/StablehloRefineShapes.cpp",
-         "stablehlo/transforms/VhloLegalizeToStablehlo.cpp",
-         "stablehlo/transforms/VhloToVersion.cpp",
-+        "stablehlo/transforms/passes_utils.cpp",
-     ],
-     hdrs = [
-         "stablehlo/transforms/MapStablehloToVhlo.h",
-         "stablehlo/transforms/Passes.h",
-         "stablehlo/transforms/StablehloRefineShapes.h",
-+        "stablehlo/transforms/passes_utils.h",
-     ],
-     strip_include_prefix = ".",
-     deps = [
-diff --ruN a/stablehlo/stablehlo/conversions/tosa/tests/unary.mlir b/stablehlo/stablehlo/conversions/tosa/tests/unary.mlir
---- stablehlo/stablehlo/conversions/tosa/tests/unary.mlir
-+++ stablehlo/stablehlo/conversions/tosa/tests/unary.mlir
-@@ -121,8 +121,8 @@
- 
- // CHECK-LABEL: @transpose
- func.func @transpose(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> {
--  // CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<[2, 1, 0]> : tensor<3xi64>}> : () -> tensor<3xi64>
--  // CHECK-DAG: %[[VAR1:.*]] = tosa.transpose %arg0, %[[VAR0]]
-+  // CHECK: %[[VAR0:.*]] = "tosa.const"() <{value = dense<[2, 1, 0]> : tensor<3xi32>}> : () -> tensor<3xi32>
-+  // CHECK: %[[VAR1:.*]] = tosa.transpose %arg0, %[[VAR0]]
-   %0 = "stablehlo.transpose"(%arg0) {permutation = array<i64: 2, 1, 0>} : (tensor<1x2x3xf32>) -> tensor<3x2x1xf32>
-   return %0 : tensor<3x2x1xf32>
- }
-diff --ruN a/stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp b/stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp
---- stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp
-+++ stablehlo/stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.cpp
-@@ -451,9 +451,10 @@
- 
-   auto perms = op.getPermutation();
-   auto type = RankedTensorType::get({static_cast<int64_t>(perms.size())},
--                                    rewriter.getI64Type());
-+                                    rewriter.getI32Type());
-+  std::vector<int32_t> perms_int32(perms.begin(), perms.end());
-   auto constOp = rewriter.create<tosa::ConstOp>(
--      op->getLoc(), type, DenseIntElementsAttr::get(type, perms));
-+      op->getLoc(), type, DenseIntElementsAttr::get(type, perms_int32));
-   rewriter.replaceOpWithNewOp<tosa::TransposeOp>(op, op.getType(),
-                                                  op.getOperand(), constOp);
-   return success();
-diff --ruN a/stablehlo/stablehlo/dialect/Version.cpp b/stablehlo/stablehlo/dialect/Version.cpp
---- stablehlo/stablehlo/dialect/Version.cpp
-+++ stablehlo/stablehlo/dialect/Version.cpp
-@@ -80,9 +80,9 @@
-     case CompatibilityRequirement::NONE:
-       return Version::getCurrentVersion();
-     case CompatibilityRequirement::WEEK_4:
--      return Version(1, 3, 0);  // v1.3.0 - Jul 15, 2024
-+      return Version(1, 5, 0);  // v1.5.0 - Aug 1, 2024
-     case CompatibilityRequirement::WEEK_12:
--      return Version(1, 0, 0);  // v1.0.0 - May 14, 2024
-+      return Version(1, 1, 0);  // v1.1.0 - May 30, 2024
-     case CompatibilityRequirement::MAX:
-       return Version::getMinimumVersion();
-   }
-diff --ruN a/stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir b/stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir
---- stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir
-+++ stablehlo/stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir
-@@ -3029,28 +3029,11 @@
- 
- // -----
- 
--// CHECK-LABEL: @tan_f16
--// CHECK-SAME: (%[[ARG:.*]]: tensor<f16>)
--func.func @tan_f16(%arg : tensor<f16>) -> tensor<f16> {
--  // %[[TMP_0:.*]] = stablehlo.convert [[ARG]] : (tensor<f16>) -> tensor<f32>
--  // %[[TMP_1:.*]] = stablehlo.sine %[[TMP_0]]
--  // %[[TMP_2:.*]] = stablehlo.cosine %[[TMP_0]]
--  // %[[TMP_3:.*]] = stablehlo.divide %[[TMP_1]], %[[TMP_2]]
--  // %[[TMP_4:.*]] = stablehlo.convert %[[TMP_3]] : (tensor<f32>) -> tensor<f16>
--  // return %[[TMP_4]] : tensor<f16>
--  %1 = chlo.tan %arg : tensor<f16> -> tensor<f16>
--  func.return %1 : tensor<f16>
--}
--
--// -----
--
- // CHECK-LABEL: @tan_f32
- // CHECK-SAME: (%[[ARG:.*]]: tensor<f32>)
- func.func @tan_f32(%arg : tensor<f32>) -> tensor<f32> {
--  // %[[TMP_0:.*]] = stablehlo.sine %[[ARG]]
--  // %[[TMP_1:.*]] = stablehlo.cosine %[[ARG]]
--  // %[[TMP_2:.*]] = stablehlo.divide %[[TMP_0]], %[[TMP_1]]
--  // return %[[TMP_2]] : tensor<f32>
-+  // CHECK: %[[TMP_0:.*]] = stablehlo.tan %[[ARG]] : tensor<f32>
-+  // CHECK: return %[[TMP_0]] : tensor<f32>
-   %1 = chlo.tan %arg : tensor<f32> -> tensor<f32>
-   func.return %1 : tensor<f32>
- }
-@@ -3060,22 +3043,11 @@
- 
- // CHECK-LABEL: @tan_complexf32
- // CHECK-SAME: %[[ARG0:.+]]: tensor<1xf32>, %[[ARG1:.+]]: tensor<1xf32>
- func.func @tan_complexf32(%arg0 : tensor<1xf32>, %arg1 : tensor<1xf32>) -> (tensor<1xf32>, tensor<1xf32>) {
--  // CHECK: %[[COMPLEX:.+]] = stablehlo.complex %[[ARG0]], %[[ARG1]] : tensor<1xcomplex<f32>>
--  // CHECK: %[[REAL:.+]] = stablehlo.real %[[COMPLEX]] : (tensor<1xcomplex<f32>>) -> tensor<1xf32>
--  // CHECK: %[[SINE:.+]] = stablehlo.sine %[[REAL]]
--  // CHECK: %[[COS:.+]] = stablehlo.cosine %[[REAL]]
--  // CHECK: %[[TAN:.+]] = stablehlo.divide %[[SINE]], %[[COS]]
--  // CHECK: %[[IMAG:.+]] = stablehlo.imag %[[COMPLEX]] : (tensor<1xcomplex<f32>>) -> tensor<1xf32>
--  // CHECK: %[[TANH:.+]] = stablehlo.tanh %[[IMAG]]
--  // CHECK: %[[NUM:.+]] = stablehlo.complex %[[TAN]], %[[TANH]]
--  // CHECK: %[[ONE:.+]] = stablehlo.constant dense<1.000000e+00> : tensor<1xf32>
--  // CHECK: %[[MUL:.+]] = stablehlo.multiply %[[TAN]], %[[TANH]]
--  // CHECK: %[[NEG:.+]] = stablehlo.negate %[[MUL]]
--  // CHECK: %[[DEN:.+]] = stablehlo.complex %[[ONE]], %[[NEG]]
--  // CHECK: %[[RES:.+]] = stablehlo.divide %[[NUM]], %[[DEN]]
--  // CHECK: %[[REAL:.+]] = stablehlo.real %[[RES]]
--  // CHECK: %[[IMAG:.+]] = stablehlo.imag %[[RES]]
--  // CHECK: return %[[REAL]], %[[IMAG]]
-+  // CHECK: %[[TMP_0:.*]] = stablehlo.complex %[[ARG0]], %[[ARG1]] : tensor<1xcomplex<f32>>
-+  // CHECK: %[[TMP_1:.*]] = stablehlo.tan %[[TMP_0]] : tensor<1xcomplex<f32>>
-+  // CHECK: %[[TMP_2:.*]] = stablehlo.real %[[TMP_1]] : (tensor<1xcomplex<f32>>) -> tensor<1xf32>
-+  // CHECK: %[[TMP_3:.*]] = stablehlo.imag %[[TMP_1]] : (tensor<1xcomplex<f32>>) -> tensor<1xf32>
-+  // CHECK: return %[[TMP_2]], %[[TMP_3]] : tensor<1xf32>, tensor<1xf32>
-   %0 = stablehlo.complex %arg0, %arg1 : tensor<1xcomplex<f32>>
-   %1 = chlo.tan %0 : tensor<1xcomplex<f32>> -> tensor<1xcomplex<f32>>
-   %2 = stablehlo.real %1 : (tensor<1xcomplex<f32>>) -> tensor<1xf32>
-diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir
---- stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir
-+++ stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir
-@@ -1,5 +1,5 @@
---// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file -allow-unregistered-dialect --stablehlo-create-compatibility-expander='target=1.0.0' | FileCheck %s --check-prefixes=CHECK
---// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file --stablehlo-create-compatibility-expander='target=1.6.0' | FileCheck %s --check-prefixes=CHECK-NO-DOWNGRADE
-+// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file --stablehlo-create-compatibility-expander='target=1.0.0' --chlo-legalize-to-stablehlo | FileCheck %s --check-prefixes=CHECK
-+// RUN: stablehlo-opt %s -verify-diagnostics -split-input-file --stablehlo-create-compatibility-expander='target=1.6.0' --chlo-legalize-to-stablehlo | FileCheck %s --check-prefixes=CHECK-NO-DOWNGRADE
- 
- // -----
- 
-diff --ruN a/stablehlo/stablehlo/transforms/CMakeLists.txt b/stablehlo/stablehlo/transforms/CMakeLists.txt
---- stablehlo/stablehlo/transforms/CMakeLists.txt
-+++ stablehlo/stablehlo/transforms/CMakeLists.txt
-@@ -53,6 +53,7 @@
-   StablehloRefineShapes.cpp
-   VhloLegalizeToStablehlo.cpp
-   VhloToVersion.cpp
-+  passes_utils.cpp
- 
-   DEPENDS
-   ChloDecompositionPatternsIncGen
-diff --ruN a/stablehlo/stablehlo/transforms/ChloDecompositionPatterns.td b/stablehlo/stablehlo/transforms/ChloDecompositionPatterns.td
---- stablehlo/stablehlo/transforms/ChloDecompositionPatterns.td
-+++ stablehlo/stablehlo/transforms/ChloDecompositionPatterns.td
-@@ -109,26 +109,9 @@
-     (STABLEHLO_DEFAULT_COMPARISON_TYPE)
-   )>;
- 
--// Express `tan` as
--// sine(x) / cosine(x)
--def : Pat<(CHLO_TanOp NonComplexElementType:$input),
--  (StableHLO_DivOp
--    (StableHLO_SineOp $input),
--    (StableHLO_CosineOp $input)
--  )>;
- 
--
--// Express `tan(a + bi)` as
--// (tan(a) + i tanh(b)) / (1 - i tan(a) * tanh(b))
--def : Pat<(CHLO_TanOp ComplexElementType:$input),
--  (StableHLO_DivOp
--    (StableHLO_ComplexOp
--      (CHLO_TanOp:$tan (StableHLO_RealOp $input)),
--      (StableHLO_TanhOp:$tanh (StableHLO_ImagOp $input))),
--    (StableHLO_ComplexOp
--      (StableHLO_ConstantLike<"1.0"> $tan),
--      (StableHLO_NegOp (StableHLO_MulOp $tan, $tanh)))
--  )>;
-+def : Pat<(CHLO_TanOp $input),
-+          (StableHLO_TanOp $input)>;
- 
- def : Pat<(CHLO_ConstantOp $v),
-           (StableHLO_ConstantOp $v)>;
-diff --ruN a/stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp b/stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp
---- stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp
-+++ stablehlo/stablehlo/transforms/ChloLegalizeToStablehlo.cpp
-@@ -49,6 +49,7 @@
- #include "stablehlo/dialect/ChloOps.h"
- #include "stablehlo/dialect/StablehloOps.h"
- #include "stablehlo/transforms/Passes.h"
-+#include "stablehlo/transforms/passes_utils.h"
- 
- namespace mlir {
- namespace stablehlo {
-@@ -169,29 +170,6 @@
-   patterns->add>(
-       context, args...);
--}
--
--template <typename T>
--static Value getConstantLike(OpBuilder &b, Location loc, T constant,
--                             Value val) {
--  Type ty = getElementTypeOrSelf(val.getType());
--  auto getAttr = [&]() -> Attribute {
--    if (isa<IntegerType>(ty)) return b.getIntegerAttr(ty, constant);
--    if (isa<FloatType>(ty)) return b.getFloatAttr(ty, constant);
--    if (auto complexTy = dyn_cast<ComplexType>(ty)) {
--      return complex::NumberAttr::get(complexTy, constant, 0);
--    }
--    llvm_unreachable("unhandled element type");
--  };
--  return b.create<chlo::ConstantLikeOp>(loc, cast<TypedAttr>(getAttr()),
--                                        val);
--}
--
--static Value getConstantLike(OpBuilder &b, Location loc,
--                             const APFloat &constant, Value val) {
--  Type ty = getElementTypeOrSelf(val.getType());
--  return b.create<chlo::ConstantLikeOp>(loc, b.getFloatAttr(ty, constant),
--                                        val);
- }
- 
- static Value getConstantLikeMaxFiniteValue(OpBuilder &b, Location loc,
-diff --ruN a/stablehlo/stablehlo/transforms/Passes.td b/stablehlo/stablehlo/transforms/Passes.td
---- stablehlo/stablehlo/transforms/Passes.td
-+++ stablehlo/stablehlo/transforms/Passes.td
-@@ -338,5 +338,6 @@
-   ];
-   let dependentDialects = [
-     "mlir::stablehlo::StablehloDialect",
--  ];
--}
-+    "mlir::chlo::ChloDialect",
-+  ];
-+}
-diff --ruN a/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp b/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp
---- stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp
-+++ stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp
-@@ -26,6 +26,7 @@
- #include "stablehlo/dialect/StablehloOps.h"
- #include "stablehlo/dialect/Version.h"
- #include "stablehlo/transforms/Passes.h"
-+#include "stablehlo/transforms/passes_utils.h"
- 
- namespace mlir {
- namespace stablehlo {
-@@ -37,21 +38,6 @@
- //===----------------------------------------------------------------------===//
- // Helpers.
- //===----------------------------------------------------------------------===//
--
--// Creates a constant with all ones.
--static Value createConstantWithAllOnes(OpBuilder &b, Location loc, Value val) {
--  if (!isa<FloatType>(getElementTypeOrSelf(val)))
--    llvm_unreachable("Unsupported element type, expecting float");
--
--  auto shapedTy = dyn_cast<ShapedType>(val.getType());
--  if (!shapedTy) llvm_unreachable("Unsupported shaped type.");
--
--  mlir::DenseElementsAttr elementsAttr =
--      mlir::DenseElementsAttr::get(shapedTy, 1.0);
--
--  return b.create(loc, val.getType(),
--                  elementsAttr);
--}
- 
- // Check user-specified target version.
- vhlo::Version validateTargetVersion(llvm::StringRef versionRef) {
-diff --ruN a/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td b/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td
---- stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td
-+++ stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpanderPatterns.td
-@@ -24,7 +24,8 @@
-   CPred<"!isa<ComplexType>(cast<ShapedType>($_self).getElementType())">,
-   "Non-complex element type">;
- 
--def createConstantWithAllOnes : NativeCodeCall<"createConstantWithAllOnes($_builder, $_loc, $0)">;
-+class StableHLO_ConstantLike<string value> : NativeCodeCall<
-+    "::mlir::stablehlo::getConstantLike($_builder, $_loc, " # value # ", $0)">;
- 
- // Express `tan` as
- // sine(x) / cosine(x)
-@@ -42,6 +43,6 @@
-       (StableHLO_TanOp:$tan (StableHLO_RealOp $input)),
-       (StableHLO_TanhOp:$tanh (StableHLO_ImagOp $input))),
-     (StableHLO_ComplexOp
--      (createConstantWithAllOnes $tan),
-+      (StableHLO_ConstantLike<"1.0"> $tan),
-       (StableHLO_NegOp (StableHLO_MulOp $tan, $tanh)))
-   )>;
-diff --ruN a/stablehlo/stablehlo/transforms/passes_utils.cpp b/stablehlo/stablehlo/transforms/passes_utils.cpp
---- stablehlo/stablehlo/transforms/passes_utils.cpp
-+++ stablehlo/stablehlo/transforms/passes_utils.cpp
-@@ -0,0 +1,34 @@
-+/* Copyright 2024 The StableHLO Authors. All Rights Reserved.
-+Licensed under the Apache License, Version 2.0 (the "License");
-+you may not use this file except in compliance with the License.
-+You may obtain a copy of the License at
-+   http://www.apache.org/licenses/LICENSE-2.0
-+Unless required by applicable law or agreed to in writing, software
-+distributed under the License is distributed on an "AS IS" BASIS,
-+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-+See the License for the specific language governing permissions and
-+limitations under the License.
-+==============================================================================*/
-+
-+#include "stablehlo/transforms/passes_utils.h"
-+
-+#include "mlir/IR/Builders.h"
-+#include "mlir/IR/Location.h"
-+#include "mlir/IR/TypeUtilities.h"
-+#include "mlir/IR/Types.h"
-+#include "mlir/IR/Value.h"
-+#include "mlir/Support/LLVM.h"
-+#include "stablehlo/dialect/ChloOps.h"
-+
-+namespace mlir {
-+namespace stablehlo {
-+
-+Value getConstantLike(OpBuilder &b, Location loc, const APFloat &constant,
-+                      Value val) {
-+  Type ty = getElementTypeOrSelf(val.getType());
-+  return b.create<chlo::ConstantLikeOp>(loc, b.getFloatAttr(ty, constant),
-+                                        val);
-+}
-+
-+}  // namespace stablehlo
-+}  // namespace mlir
-diff --ruN a/stablehlo/stablehlo/transforms/passes_utils.h b/stablehlo/stablehlo/transforms/passes_utils.h
---- stablehlo/stablehlo/transforms/passes_utils.h
-+++ stablehlo/stablehlo/transforms/passes_utils.h
-@@ -0,0 +1,57 @@
-+/* Copyright 2024 The StableHLO Authors. All Rights Reserved.
-+Licensed under the Apache License, Version 2.0 (the "License");
-+you may not use this file except in compliance with the License.
-+You may obtain a copy of the License at
-+   http://www.apache.org/licenses/LICENSE-2.0
-+Unless required by applicable law or agreed to in writing, software
-+distributed under the License is distributed on an "AS IS" BASIS,
-+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-+See the License for the specific language governing permissions and
-+limitations under the License.
-+==============================================================================*/
-+
-+#ifndef THIRD_PARTY_STABLEHLO_STABLEHLO_TRANSFORMS_PASSES_UTILS_H_
-+#define THIRD_PARTY_STABLEHLO_STABLEHLO_TRANSFORMS_PASSES_UTILS_H_
-+
-+#include "llvm/Support/ErrorHandling.h"
-+#include "mlir/Dialect/Complex/IR/Complex.h"
-+#include "mlir/IR/Builders.h"
-+#include "mlir/IR/BuiltinAttributeInterfaces.h"
-+#include "mlir/IR/BuiltinTypes.h"
-+#include "mlir/IR/Location.h"
-+#include "mlir/IR/TypeUtilities.h"
-+#include "mlir/IR/Types.h"
-+#include "mlir/IR/Value.h"
-+#include "mlir/Support/LLVM.h"
-+#include "stablehlo/dialect/ChloOps.h"
-+
-+namespace mlir {
-+namespace stablehlo {
-+// Add utility functions common across passes.
-+
-+// Creates a chlo::ConstantLikeOp using a splat `constant` of the same shape
-+// as `val`.
-+template <typename T>
-+Value getConstantLike(OpBuilder &b, Location loc, T constant, Value val) {
-+  Type ty = getElementTypeOrSelf(val.getType());
-+  auto getAttr = [&]() -> Attribute {
-+    if (isa<IntegerType>(ty)) return b.getIntegerAttr(ty, constant);
-+    if (isa<FloatType>(ty)) return b.getFloatAttr(ty, constant);
-+    if (auto complexTy = dyn_cast<ComplexType>(ty)) {
-+      return complex::NumberAttr::get(complexTy, constant, 0);
-+    }
-+    llvm_unreachable("unhandled element type");
-+  };
-+  return b.create<chlo::ConstantLikeOp>(loc, cast<TypedAttr>(getAttr()),
-+                                        val);
-+}
-+
-+// Creates a chlo::ConstantLikeOp using a APFloat splat `constant` of the
-+// same shape as `val`.
-+Value getConstantLike(OpBuilder &b, Location loc, const APFloat &constant,
-+                      Value val);
-+
-+}  // namespace stablehlo
-+}  // namespace mlir
-+
-+#endif  // THIRD_PARTY_STABLEHLO_STABLEHLO_TRANSFORMS_PASSES_UTILS_H_
+
diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl
index 0bdf74f3882cc3..97fd0b990fc1c7 100644
--- a/third_party/stablehlo/workspace.bzl
+++ b/third_party/stablehlo/workspace.bzl
@@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
 
 def repo():
     # LINT.IfChange
-    STABLEHLO_COMMIT = "e51fd95e5b2c28861f22dc9d609fb2a7f002124e"
-    STABLEHLO_SHA256 = "72007352d60cc42784908263ccf171f8721a40adf92a17e13b1b6893e986b8b6"
+    STABLEHLO_COMMIT = "78c753ad13ad8205cacc5fcc12418c1ac97276c7"
+    STABLEHLO_SHA256 = "b7fef892020eb465a6d1ed921160f5229398ba10acff36b6345171b9867ccc7c"
     # LINT.ThenChange(Google-internal path)
 
     tf_http_archive(
diff --git a/xla/service/gpu/runtime/nccl_api.cc b/xla/service/gpu/runtime/nccl_api.cc
index 77f022da6ec64f..15949ac9cae999 100644
--- a/xla/service/gpu/runtime/nccl_api.cc
+++ b/xla/service/gpu/runtime/nccl_api.cc
@@ -112,6 +112,8 @@ static absl::StatusOr<ncclDataType_t> ToNcclDataType(PrimitiveType dtype,
     case S8:
     case F8E5M2:
     case F8E4M3FN:
+    case F8E5M2FNUZ:
+    case F8E4M3FNUZ:
       return ncclInt8;
     case PRED:
     case U8:
diff --git a/xla/service/gpu/runtime/nccl_collective_thunk.cc b/xla/service/gpu/runtime/nccl_collective_thunk.cc
index 8e075c8d01c730..fb2282c8e73ae7 100644
--- a/xla/service/gpu/runtime/nccl_collective_thunk.cc
+++ b/xla/service/gpu/runtime/nccl_collective_thunk.cc
@@ -92,6 +92,8 @@ bool IsTypeSupportedByNccl(PrimitiveType element_type,
       // they involve actual computation and not just data movement.
     case F8E5M2:
     case F8E4M3FN:
+    case F8E5M2FNUZ:
+    case F8E4M3FNUZ:
       return !IsReductionCollective(reduction_op);
     default:
       return false;
diff --git a/xla/tests/collective_ops_e2e_test.cc b/xla/tests/collective_ops_e2e_test.cc
index 1e399127318242..cecf02827a99e1 100644
--- a/xla/tests/collective_ops_e2e_test.cc
+++ b/xla/tests/collective_ops_e2e_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #include
 #include
 
+#include "absl/strings/str_replace.h"
 #include "absl/strings/string_view.h"
 #include "xla/hlo/ir/hlo_casting_utils.h"
 #include "xla/hlo/ir/hlo_instruction.h"
@@ -54,6 +55,13 @@ DeviceAssignment MakeDeviceAssn(int64_t num_replicas) {
 
 class CollectiveOpsTestE2E : public HloTestBase {
  public:
+  CollectiveOpsTestE2E() {
+    replacements_[kF8E4M3DatatypePlaceholder] =
+        IsCuda() ? "f8e4m3fn" : "f8e4m3fnuz";
+    replacements_[kF8E5M2DatatypePlaceholder] =
+        IsCuda() ? "f8e5m2" : "f8e5m2fnuz";
+  }
+
   bool IsCuda() {
     return std::holds_alternative<se::CudaComputeCapability>(Capability());
   }
@@ -108,6 +116,13 @@ class CollectiveOpsTestE2E : public HloTestBase {
         /*argument_provider*/ [](int64_t, int64_t) { return nullptr; },
         num_replicas, /*run_hlo_passes=*/false, &device_assignment);
   }
+
+ protected:
+  absl::flat_hash_map<absl::string_view, absl::string_view> replacements_;
+
+ private:
+  static constexpr const char* kF8E4M3DatatypePlaceholder{"<<F8E4M3>>"};
+  static constexpr const char* kF8E5M2DatatypePlaceholder{"<<F8E5M2>>"};
 };
 
 // E2E tests for collective ops. These will generally verify some HLO transform
@@ -811,11 +826,11 @@ ENTRY main.12 {
 TEST_F(CollectiveOpsTestE2EWindowedNonWindowed,
        WindowedEinsumE2EAllGatherAndReduceScatterF8) {
   absl::string_view kModuleReplicatedStr = R"(
-HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f8e4m3fn[2,16,48]{2,1,0}, f8e4m3fn[48,192]{1,0}, f8e4m3fn[192,48]{1,0}, bf16[], bf16[], bf16[], bf16[], bf16[])->bf16[2,16,48]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4
+HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(<<F8E4M3>>[2,16,48]{2,1,0}, <<F8E4M3>>[48,192]{1,0}, <<F8E4M3>>[192,48]{1,0}, bf16[], bf16[], bf16[], bf16[], bf16[])->bf16[2,16,48]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4
 
 ENTRY main.12 {
-  Arg_0.1 = f8e4m3fn[2,16,48]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]}
-  Arg_1.2 = f8e4m3fn[48,192]{1,0} parameter(1), sharding={devices=[1,4]<=[4]}
+  Arg_0.1 = <<F8E4M3>>[2,16,48]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]}
+  Arg_1.2 = <<F8E4M3>>[48,192]{1,0} parameter(1), sharding={devices=[1,4]<=[4]}
   Arg_2.3 = bf16[] parameter(3)
   Arg_3.4 = bf16[] parameter(4)
   broadcast = bf16[2,16,48]{2,1,0} broadcast(Arg_2.3), dimensions={}
@@ -834,12 +849,12 @@ ENTRY main.12 {
   constant.1 = bf16[] constant(448.)
   broadcast.4 = bf16[2,16,192]{2,1,0} broadcast(constant.1), dimensions={}
   clamp = bf16[2,16,192]{2,1,0} clamp(broadcast.3, divide, broadcast.4)
-  convert.2 = f8e4m3fn[2,16,192]{2,1,0} convert(clamp)
+  convert.2 = <<F8E4M3>>[2,16,192]{2,1,0} convert(clamp)
   Arg_5.6 = bf16[] parameter(6)
   broadcast.5 = bf16[2,16,192]{2,1,0} broadcast(Arg_5.6), dimensions={}
   convert.3 = bf16[2,16,192]{2,1,0} convert(convert.2)
   multiply.2 = bf16[2,16,192]{2,1,0} multiply(convert.3, broadcast.5)
-  Arg_6.7 = f8e4m3fn[192,48]{1,0} parameter(2), sharding={devices=[4,1]<=[4]}
+  Arg_6.7 = <<F8E4M3>>[192,48]{1,0} parameter(2), sharding={devices=[4,1]<=[4]}
   Arg_7.8 = bf16[] parameter(7)
   broadcast.6 = bf16[192,48]{1,0} broadcast(Arg_7.8), dimensions={}
   convert.4 = bf16[192,48]{1,0} convert(Arg_6.7)
@@ -852,8 +867,9 @@ ENTRY main.12 {
 
   // Disable the dot merger pass which can prevent the creation of FP8 GEMM
   // Custom Calls.
-  CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr,
-                                          /*disable_dot_merger=*/true);
+  CollectiveOpsCompareWindowedNonWindowed(
+      absl::StrReplaceAll(kModuleReplicatedStr, replacements_),
+      /*disable_dot_merger=*/true);
 
   // Verify the creation of FP8 GEMM Custom Calls on Hopper and newer
   // architectures.
@@ -863,7 +879,8 @@ ENTRY main.12 {
   opts.set_xla_gpu_graph_min_graph_size(200);
   opts.set_xla_gpu_enable_triton_gemm(false);
   opts.add_xla_disable_hlo_passes("dot-merger");
-  CollectiveOpsVerifyF8Matmul(kModuleReplicatedStr, opts);
+  CollectiveOpsVerifyF8Matmul(
+      absl::StrReplaceAll(kModuleReplicatedStr, replacements_), opts);
 }
 
 TEST_F(CollectiveOpsTestE2EWindowedNonWindowed,
@@ -1023,7 +1040,7 @@ while_body {
   r = bf16[32,128] bitcast(dynamic-slice.k)
   a = bf16[32,128] add(r, r), control-predecessors={constant.2559}
   // A fp8 pattern of quant-dequant before the collective AG.
-  qa = f8e4m3fn[32,128] convert(a)
+  qa = <<F8E4M3>>[32,128] convert(a)
   dqa = bf16[32,128] convert(qa)
   a_scale = bf16[] get-tuple-element(param), index=3
   a_scales = bf16[32,128] broadcast(a_scale), dimensions={}
@@ -1031,7 +1048,7 @@ while_body {
 
   mb = bf16[128,128] all-gather(dqa_unscaled), channel_id=1, use_global_device_ids=true, dimensions={0}, replica_groups={{0,1,2,3}}
   ma = bf16[128,128] dynamic-slice(get-tuple-element.395, select.1348, constant.2561), dynamic_slice_sizes={128,128}
-  qma = f8e4m3fn[128,128] convert(ma)
+  qma = <<F8E4M3>>[128,128] convert(ma)
   dqma = bf16[128,128] convert(qma)
   ma_scale = bf16[] get-tuple-element(param), index=4
   ma_scales = bf16[128,128] broadcast(ma_scale), dimensions={}
@@ -1061,7 +1078,8 @@ ENTRY entry {
   opts.set_xla_gpu_run_post_layout_collective_pipeliner(true);
   opts.set_xla_gpu_enable_pipelined_collectives(true);
   opts.set_xla_gpu_enable_triton_gemm(false);
-  CollectiveOpsVerifyF8Matmul(kModuleReplicatedStr, opts);
+  CollectiveOpsVerifyF8Matmul(
+      absl::StrReplaceAll(kModuleReplicatedStr, replacements_), opts);
 }
 
 TEST_F(CollectiveOpsTestE2E,
diff --git a/xla/tests/collective_ops_test.cc b/xla/tests/collective_ops_test.cc
index 9cd874c9e03c13..fcecf8f4a66cef 100644
--- a/xla/tests/collective_ops_test.cc
+++ b/xla/tests/collective_ops_test.cc
@@ -1753,80 +1753,6 @@ XLA_TEST_F(CollectiveOpsTest, AllReduceBFloat16Min) {
   }
 }
 
-XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllGather_8BitFloat)) {
-  const char* const kModuleStr = R"(
-  HloModule test
-  ENTRY test_computation {
-    a0 = f8e4m3fn[1,2] constant({{1,2}})
-    allgather = f8e4m3fn[2, 2] all-gather(a0), dimensions={0}
-    p = f8e4m3fn[4] reshape(allgather)
-    ROOT out = f32[4] convert(p)
-  }
-  )";
-  const int64_t kNumReplicas = 2;
-  HloModuleConfig config =
-      GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(kModuleStr, config));
-  TF_ASSERT_OK_AND_ASSIGN(
-      std::vector<Literal> results,
-      ExecuteReplicated(std::move(module), absl::Span<Literal* const>{},
-                        kNumReplicas,
-                        /*use_threads=*/true, /*run_hlo_passes=*/true));
-  ASSERT_EQ(results.size(), kNumReplicas);
-  for (const Literal& result : results) {
-    LiteralTestUtil::ExpectR1Equal<float>({1, 2, 1, 2}, result);
-  }
-}
-
-XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllToAll_8BitFloat)) {
-  const char* const kModuleStr = R"(
-  HloModule test
-  ENTRY test_computation {
-    a0 = f8e4m3fn[2] constant({1,2})
-    a2a = f8e4m3fn[2] all-to-all(a0), dimensions={0}
-    ROOT out = f32[2] convert(a2a)
-  }
-  )";
-  const int64_t kNumReplicas = 2;
-  HloModuleConfig config =
-      GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(kModuleStr, config));
-  TF_ASSERT_OK_AND_ASSIGN(
-      std::vector<Literal> results,
-      ExecuteReplicated(std::move(module), absl::Span<Literal* const>{},
-                        kNumReplicas,
-                        /*use_threads=*/true, /*run_hlo_passes=*/true));
-  ASSERT_EQ(results.size(), kNumReplicas);
-  LiteralTestUtil::ExpectR1Equal<float>({1, 1}, results[0]);
-  LiteralTestUtil::ExpectR1Equal<float>({2, 2}, results[1]);
-}
-
-XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(CollectivePermute_8BitFloat)) {
-  const char* const kModuleStr = R"(
-  HloModule test
-  ENTRY test_computation {
-    a0 = f8e5m2[2] constant({1,2})
-    a1 = f8e5m2[2] collective-permute(a0), source_target_pairs={{0,1}, {1,0}}
-    ROOT out = f32[2] convert(a1)
-  }
-  )";
-  const int64_t kNumReplicas = 2;
-  HloModuleConfig config =
-      GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(kModuleStr, config));
-  TF_ASSERT_OK_AND_ASSIGN(
-      std::vector<Literal> results,
-      ExecuteReplicated(std::move(module), absl::Span<Literal* const>{},
-                        kNumReplicas,
-                        /*use_threads=*/true, /*run_hlo_passes=*/true));
-  ASSERT_EQ(results.size(), kNumReplicas);
-  LiteralTestUtil::ExpectR1Equal<float>({1, 2}, results[0]);
-  LiteralTestUtil::ExpectR1Equal<float>({1, 2}, results[1]);
-}
-
 XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AsyncAllGather)) {
   const char* const kModuleStr = R"(
   HloModule test
@@ -2273,5 +2199,110 @@ body {
                                      results[1]));
 }
 
+class Fp8CollectiveOpsTest : public CollectiveOpsTest {
+ public:
+  Fp8CollectiveOpsTest() {
+    replacements_[kF8E4M3DatatypePlaceholder] =
+        IsCuda() ? "f8e4m3fn" : "f8e4m3fnuz";
+    replacements_[kF8E5M2DatatypePlaceholder] =
+        IsCuda() ? "f8e5m2" : "f8e5m2fnuz";
+  }
+
+ protected:
+  bool IsCuda() {
+    return std::holds_alternative<se::CudaComputeCapability>(Capability());
+  }
+
+  const se::GpuComputeCapability& Capability() {
+    return backend()
+        .default_stream_executor()
+        ->GetDeviceDescription()
+        .gpu_compute_capability();
+  }
+
+  absl::flat_hash_map<absl::string_view, absl::string_view> replacements_;
+
+ private:
+  static constexpr const char* kF8E4M3DatatypePlaceholder{"<<F8E4M3>>"};
+  static constexpr const char* kF8E5M2DatatypePlaceholder{"<<F8E5M2>>"};
+};
+
+XLA_TEST_F(Fp8CollectiveOpsTest, DISABLED_ON_CPU(AllGather_8BitFloat)) {
+  const char* const kModuleStr = R"(
+  HloModule test
+  ENTRY test_computation {
+    a0 = <<F8E4M3>>[1,2] constant({{1,2}})
+    allgather = <<F8E4M3>>[2, 2] all-gather(a0), dimensions={0}
+    p = <<F8E4M3>>[4] reshape(allgather)
+    ROOT out = f32[4] convert(p)
+  }
+  )";
+  const int64_t kNumReplicas = 2;
+  HloModuleConfig config =
+      GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto module, ParseAndReturnVerifiedModule(
+                       absl::StrReplaceAll(kModuleStr, replacements_), config));
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::vector<Literal> results,
+      ExecuteReplicated(std::move(module), absl::Span<Literal* const>{},
+                        kNumReplicas,
+                        /*use_threads=*/true, /*run_hlo_passes=*/true));
+  ASSERT_EQ(results.size(), kNumReplicas);
+  for (const Literal& result : results) {
+    LiteralTestUtil::ExpectR1Equal<float>({1, 2, 1, 2}, result);
+  }
+}
+
+XLA_TEST_F(Fp8CollectiveOpsTest, DISABLED_ON_CPU(AllToAll_8BitFloat)) {
+  const char* const kModuleStr = R"(
+  HloModule test
+  ENTRY test_computation {
+    a0 = <<F8E4M3>>[2] constant({1,2})
+    a2a = <<F8E4M3>>[2] all-to-all(a0), dimensions={0}
+    ROOT out = f32[2] convert(a2a)
+  }
+  )";
+  const int64_t kNumReplicas = 2;
+  HloModuleConfig config =
+      GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto module, ParseAndReturnVerifiedModule(
+                       absl::StrReplaceAll(kModuleStr, replacements_), config));
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::vector<Literal> results,
+      ExecuteReplicated(std::move(module), absl::Span<Literal* const>{},
+                        kNumReplicas,
+                        /*use_threads=*/true, /*run_hlo_passes=*/true));
+  ASSERT_EQ(results.size(), kNumReplicas);
+  LiteralTestUtil::ExpectR1Equal<float>({1, 1}, results[0]);
+  LiteralTestUtil::ExpectR1Equal<float>({2, 2}, results[1]);
+}
+
+XLA_TEST_F(Fp8CollectiveOpsTest, DISABLED_ON_CPU(CollectivePermute_8BitFloat)) {
+  const char* const kModuleStr = R"(
+  HloModule test
+  ENTRY test_computation {
+    a0 = <<F8E5M2>>[2] constant({1,2})
+    a1 = <<F8E5M2>>[2] collective-permute(a0), source_target_pairs={{0,1}, {1,0}}
+    ROOT out = f32[2] convert(a1)
+  }
+  )";
+  const int64_t kNumReplicas = 2;
+  HloModuleConfig config =
+      GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto module, ParseAndReturnVerifiedModule(
+                       absl::StrReplaceAll(kModuleStr, replacements_), config));
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::vector<Literal> results,
+      ExecuteReplicated(std::move(module), absl::Span<Literal* const>{},
+                        kNumReplicas,
+                        /*use_threads=*/true, /*run_hlo_passes=*/true));
+  ASSERT_EQ(results.size(), kNumReplicas);
+  LiteralTestUtil::ExpectR1Equal<float>({1, 2}, results[0]);
+  LiteralTestUtil::ExpectR1Equal<float>({1, 2}, results[1]);
+}
+
 }  // namespace
 }  // namespace xla