From b20928845fa5f98928a3bdd8cbf161307baec78f Mon Sep 17 00:00:00 2001
From: Xuefei Jiang
Date: Tue, 17 Sep 2024 11:07:13 -0700
Subject: [PATCH] PR #16938: Add NANOO FP8 support for collective communication
 unit tests

Imported from GitHub PR https://github.com/openxla/xla/pull/16938

This PR adds support for the NANOO FP8 data format to the collective
communication unit tests.
- For context on OCP FP8 and NANOO FP8, please refer to this comment:
  https://github.com/google/flax/pull/3993#issue-2350000228
- The unit tests in this PR mirror the GEMM unit tests introduced in the
  following PR, which handle both the OCP and NANOO FP8 formats:
  https://github.com/openxla/xla/pull/10488

Copybara import of the project:

--
0fc74ccae6cfcaf4e8627ea338ee03783af0626b by Wen Chen:

    [AMD] Added NCCL support for fp8e4m3fnuz and fp8e5m2fnuz.

--
d247af5cd33fe42698bb55ef1c18f32df8a02a21 by scxfjiang:

    refactor tests for collective comm ops

--
6f8c418b3052f7c531896bd5f8cbbc7a766ef7fc by scxfjiang:

    refactor collective comm e2e tests

--
8ecb6ecf08a1536c5b3f8ba87e0e9f8813b1b359 by scxfjiang:

    update: replace str

--
338d3af2ca1a32302fdfe9d7abee335d24539ee9 by scxfjiang:

    get rid of macros

Merging this change closes #16938

FUTURE_COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/16938 from ROCm:ci_dev_rccl_nanoo_fp8 338d3af2ca1a32302fdfe9d7abee335d24539ee9
PiperOrigin-RevId: 675635116
---
 xla/service/gpu/runtime/nccl_api.cc           |   2 +
 .../gpu/runtime/nccl_collective_thunk.cc      |   2 +
 xla/tests/collective_ops_e2e_test.cc          |  40 ++--
 xla/tests/collective_ops_test.cc              | 179 ++++++++++--------
 4 files changed, 138 insertions(+), 85 deletions(-)

diff --git a/xla/service/gpu/runtime/nccl_api.cc b/xla/service/gpu/runtime/nccl_api.cc
index 77f022da6ec64f..15949ac9cae999 100644
--- a/xla/service/gpu/runtime/nccl_api.cc
+++ b/xla/service/gpu/runtime/nccl_api.cc
@@ -112,6 +112,8 @@ static absl::StatusOr<ncclDataType_t> ToNcclDataType(PrimitiveType dtype,
     case S8:
     case F8E5M2:
     case F8E4M3FN:
+    case F8E5M2FNUZ:
+    case F8E4M3FNUZ:
       return ncclInt8;
     case PRED:
     case U8:
diff --git a/xla/service/gpu/runtime/nccl_collective_thunk.cc b/xla/service/gpu/runtime/nccl_collective_thunk.cc
index 8e075c8d01c730..fb2282c8e73ae7 100644
--- a/xla/service/gpu/runtime/nccl_collective_thunk.cc
+++ b/xla/service/gpu/runtime/nccl_collective_thunk.cc
@@ -92,6 +92,8 @@ bool IsTypeSupportedByNccl(PrimitiveType element_type,
     // they involve actual computation and not just data movement.
     case F8E5M2:
     case F8E4M3FN:
+    case F8E5M2FNUZ:
+    case F8E4M3FNUZ:
       return !IsReductionCollective(reduction_op);
     default:
       return false;
diff --git a/xla/tests/collective_ops_e2e_test.cc b/xla/tests/collective_ops_e2e_test.cc
index 1e399127318242..cecf02827a99e1 100644
--- a/xla/tests/collective_ops_e2e_test.cc
+++ b/xla/tests/collective_ops_e2e_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #include <variant>
 #include <vector>
 
+#include "absl/strings/str_replace.h"
 #include "absl/strings/string_view.h"
 #include "xla/hlo/ir/hlo_casting_utils.h"
 #include "xla/hlo/ir/hlo_instruction.h"
@@ -54,6 +55,13 @@ DeviceAssignment MakeDeviceAssn(int64_t num_replicas) {
 
 class CollectiveOpsTestE2E : public HloTestBase {
  public:
+  CollectiveOpsTestE2E() {
+    replacements_[kF8E4M3DatatypePlaceholder] =
+        IsCuda() ? "f8e4m3fn" : "f8e4m3fnuz";
+    replacements_[kF8E5M2DatatypePlaceholder] =
+        IsCuda() ? "f8e5m2" : "f8e5m2fnuz";
"f8e5m2" : "f8e5m2fnuz"; + } + bool IsCuda() { return std::holds_alternative(Capability()); } @@ -108,6 +116,13 @@ class CollectiveOpsTestE2E : public HloTestBase { /*argument_provider*/ [](int64_t, int64_t) { return nullptr; }, num_replicas, /*run_hlo_passes=*/false, &device_assignment); } + + protected: + absl::flat_hash_map replacements_; + + private: + static constexpr const char* kF8E4M3DatatypePlaceholder{"<>"}; + static constexpr const char* kF8E5M2DatatypePlaceholder{"<>"}; }; // E2E tests for collective ops. These will generally verify some HLO transform @@ -811,11 +826,11 @@ ENTRY main.12 { TEST_F(CollectiveOpsTestE2EWindowedNonWindowed, WindowedEinsumE2EAllGatherAndReduceScatterF8) { absl::string_view kModuleReplicatedStr = R"( -HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f8e4m3fn[2,16,48]{2,1,0}, f8e4m3fn[48,192]{1,0}, f8e4m3fn[192,48]{1,0}, bf16[], bf16[], bf16[], bf16[], bf16[])->bf16[2,16,48]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4 +HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(<>[2,16,48]{2,1,0}, <>[48,192]{1,0}, <>[192,48]{1,0}, bf16[], bf16[], bf16[], bf16[], bf16[])->bf16[2,16,48]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4 ENTRY main.12 { - Arg_0.1 = f8e4m3fn[2,16,48]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]} - Arg_1.2 = f8e4m3fn[48,192]{1,0} parameter(1), sharding={devices=[1,4]<=[4]} + Arg_0.1 = <>[2,16,48]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]} + Arg_1.2 = <>[48,192]{1,0} parameter(1), sharding={devices=[1,4]<=[4]} Arg_2.3 = bf16[] parameter(3) Arg_3.4 = bf16[] parameter(4) broadcast = bf16[2,16,48]{2,1,0} broadcast(Arg_2.3), dimensions={} @@ -834,12 +849,12 @@ ENTRY main.12 { constant.1 = bf16[] constant(448.) broadcast.4 = bf16[2,16,192]{2,1,0} broadcast(constant.1), dimensions={} clamp = bf16[2,16,192]{2,1,0} clamp(broadcast.3, divide, broadcast.4) - convert.2 = f8e4m3fn[2,16,192]{2,1,0} convert(clamp) + convert.2 = <>[2,16,192]{2,1,0} convert(clamp) Arg_5.6 = bf16[] parameter(6) broadcast.5 = bf16[2,16,192]{2,1,0} broadcast(Arg_5.6), dimensions={} convert.3 = bf16[2,16,192]{2,1,0} convert(convert.2) multiply.2 = bf16[2,16,192]{2,1,0} multiply(convert.3, broadcast.5) - Arg_6.7 = f8e4m3fn[192,48]{1,0} parameter(2), sharding={devices=[4,1]<=[4]} + Arg_6.7 = <>[192,48]{1,0} parameter(2), sharding={devices=[4,1]<=[4]} Arg_7.8 = bf16[] parameter(7) broadcast.6 = bf16[192,48]{1,0} broadcast(Arg_7.8), dimensions={} convert.4 = bf16[192,48]{1,0} convert(Arg_6.7) @@ -852,8 +867,9 @@ ENTRY main.12 { // Disable the dot merger pass which can prevent the creation of FP8 GEMM // Custom Calls. - CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr, - /*disable_dot_merger=*/true); + CollectiveOpsCompareWindowedNonWindowed( + absl::StrReplaceAll(kModuleReplicatedStr, replacements_), + /*disable_dot_merger=*/true); // Verify the creation of FP8 GEMM Custom Calls on Hopper and newer // architectures. 
@@ -863,7 +879,8 @@ ENTRY main.12 {
   opts.set_xla_gpu_graph_min_graph_size(200);
   opts.set_xla_gpu_enable_triton_gemm(false);
   opts.add_xla_disable_hlo_passes("dot-merger");
-  CollectiveOpsVerifyF8Matmul(kModuleReplicatedStr, opts);
+  CollectiveOpsVerifyF8Matmul(
+      absl::StrReplaceAll(kModuleReplicatedStr, replacements_), opts);
 }
 
 TEST_F(CollectiveOpsTestE2EWindowedNonWindowed,
@@ -1023,7 +1040,7 @@ while_body {
   r = bf16[32,128] bitcast(dynamic-slice.k)
   a = bf16[32,128] add(r, r), control-predecessors={constant.2559}
   // A fp8 pattern of quant-dequant before the collective AG.
-  qa = f8e4m3fn[32,128] convert(a)
+  qa = <<F8E4M3>>[32,128] convert(a)
   dqa = bf16[32,128] convert(qa)
   a_scale = bf16[] get-tuple-element(param), index=3
   a_scales = bf16[32,128] broadcast(a_scale), dimensions={}
@@ -1031,7 +1048,7 @@ while_body {
   dqa_unscaled = bf16[32,128] multiply(dqa, a_scales)
   mb = bf16[128,128] all-gather(dqa_unscaled), channel_id=1, use_global_device_ids=true, dimensions={0}, replica_groups={{0,1,2,3}}
   ma = bf16[128,128] dynamic-slice(get-tuple-element.395, select.1348, constant.2561), dynamic_slice_sizes={128,128}
-  qma = f8e4m3fn[128,128] convert(ma)
+  qma = <<F8E4M3>>[128,128] convert(ma)
   dqma = bf16[128,128] convert(qma)
   ma_scale = bf16[] get-tuple-element(param), index=4
   ma_scales = bf16[128,128] broadcast(ma_scale), dimensions={}
@@ -1061,7 +1078,8 @@ ENTRY entry {
   opts.set_xla_gpu_run_post_layout_collective_pipeliner(true);
   opts.set_xla_gpu_enable_pipelined_collectives(true);
   opts.set_xla_gpu_enable_triton_gemm(false);
-  CollectiveOpsVerifyF8Matmul(kModuleReplicatedStr, opts);
+  CollectiveOpsVerifyF8Matmul(
+      absl::StrReplaceAll(kModuleReplicatedStr, replacements_), opts);
 }
 
 TEST_F(CollectiveOpsTestE2E,
diff --git a/xla/tests/collective_ops_test.cc b/xla/tests/collective_ops_test.cc
index 9cd874c9e03c13..fcecf8f4a66cef 100644
--- a/xla/tests/collective_ops_test.cc
+++ b/xla/tests/collective_ops_test.cc
@@ -1753,80 +1753,6 @@ XLA_TEST_F(CollectiveOpsTest, AllReduceBFloat16Min) {
   }
 }
 
-XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllGather_8BitFloat)) {
-  const char* const kModuleStr = R"(
-  HloModule test
-  ENTRY test_computation {
-    a0 = f8e4m3fn[1,2] constant({{1,2}})
-    allgather = f8e4m3fn[2, 2] all-gather(a0), dimensions={0}
-    p = f8e4m3fn[4] reshape(allgather)
-    ROOT out = f32[4] convert(p)
-  }
-  )";
-  const int64_t kNumReplicas = 2;
-  HloModuleConfig config =
-      GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(kModuleStr, config));
-  TF_ASSERT_OK_AND_ASSIGN(
-      std::vector<Literal> results,
-      ExecuteReplicated(std::move(module), absl::Span<Literal* const>{},
-                        kNumReplicas,
-                        /*use_threads=*/true, /*run_hlo_passes=*/true));
-  ASSERT_EQ(results.size(), kNumReplicas);
-  for (const Literal& result : results) {
-    LiteralTestUtil::ExpectR1Equal<float>({1, 2, 1, 2}, result);
-  }
-}
-
-XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllToAll_8BitFloat)) {
-  const char* const kModuleStr = R"(
-  HloModule test
-  ENTRY test_computation {
-    a0 = f8e4m3fn[2] constant({1,2})
-    a2a = f8e4m3fn[2] all-to-all(a0), dimensions={0}
-    ROOT out = f32[2] convert(a2a)
-  }
-  )";
-  const int64_t kNumReplicas = 2;
-  HloModuleConfig config =
-      GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(kModuleStr, config));
-  TF_ASSERT_OK_AND_ASSIGN(
-      std::vector<Literal> results,
-      ExecuteReplicated(std::move(module), absl::Span<Literal* const>{},
-                        kNumReplicas,
-                        /*use_threads=*/true, /*run_hlo_passes=*/true));
-  ASSERT_EQ(results.size(), kNumReplicas);
-  LiteralTestUtil::ExpectR1Equal<float>({1, 1}, results[0]);
-  LiteralTestUtil::ExpectR1Equal<float>({2, 2}, results[1]);
-}
-
-XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(CollectivePermute_8BitFloat)) {
-  const char* const kModuleStr = R"(
-  HloModule test
-  ENTRY test_computation {
-    a0 = f8e5m2[2] constant({1,2})
-    a1 = f8e5m2[2] collective-permute(a0), source_target_pairs={{0,1}, {1,0}}
-    ROOT out = f32[2] convert(a1)
-  }
-  )";
-  const int64_t kNumReplicas = 2;
-  HloModuleConfig config =
-      GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
-  TF_ASSERT_OK_AND_ASSIGN(auto module,
-                          ParseAndReturnVerifiedModule(kModuleStr, config));
-  TF_ASSERT_OK_AND_ASSIGN(
-      std::vector<Literal> results,
-      ExecuteReplicated(std::move(module), absl::Span<Literal* const>{},
-                        kNumReplicas,
-                        /*use_threads=*/true, /*run_hlo_passes=*/true));
-  ASSERT_EQ(results.size(), kNumReplicas);
-  LiteralTestUtil::ExpectR1Equal<float>({1, 2}, results[0]);
-  LiteralTestUtil::ExpectR1Equal<float>({1, 2}, results[1]);
-}
-
 XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AsyncAllGather)) {
   const char* const kModuleStr = R"(
   HloModule test
@@ -2273,5 +2199,110 @@ body {
                                      results[1]));
 }
 
+class Fp8CollectiveOpsTest : public CollectiveOpsTest {
+ public:
+  Fp8CollectiveOpsTest() {
+    replacements_[kF8E4M3DatatypePlaceholder] =
+        IsCuda() ? "f8e4m3fn" : "f8e4m3fnuz";
+    replacements_[kF8E5M2DatatypePlaceholder] =
+        IsCuda() ? "f8e5m2" : "f8e5m2fnuz";
+  }
+
+ protected:
+  bool IsCuda() {
+    return std::holds_alternative<se::CudaComputeCapability>(Capability());
+  }
+
+  const se::GpuComputeCapability& Capability() {
+    return backend()
+        .default_stream_executor()
+        ->GetDeviceDescription()
+        .gpu_compute_capability();
+  }
+
+  absl::flat_hash_map<absl::string_view, absl::string_view> replacements_;
+
+ private:
+  static constexpr const char* kF8E4M3DatatypePlaceholder{"<<F8E4M3>>"};
+  static constexpr const char* kF8E5M2DatatypePlaceholder{"<<F8E5M2>>"};
+};
+
+XLA_TEST_F(Fp8CollectiveOpsTest, DISABLED_ON_CPU(AllGather_8BitFloat)) {
+  const char* const kModuleStr = R"(
+  HloModule test
+  ENTRY test_computation {
+    a0 = <<F8E4M3>>[1,2] constant({{1,2}})
+    allgather = <<F8E4M3>>[2, 2] all-gather(a0), dimensions={0}
+    p = <<F8E4M3>>[4] reshape(allgather)
+    ROOT out = f32[4] convert(p)
+  }
+  )";
+  const int64_t kNumReplicas = 2;
+  HloModuleConfig config =
+      GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto module, ParseAndReturnVerifiedModule(
+                       absl::StrReplaceAll(kModuleStr, replacements_), config));
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::vector<Literal> results,
+      ExecuteReplicated(std::move(module), absl::Span<Literal* const>{},
+                        kNumReplicas,
+                        /*use_threads=*/true, /*run_hlo_passes=*/true));
+  ASSERT_EQ(results.size(), kNumReplicas);
+  for (const Literal& result : results) {
+    LiteralTestUtil::ExpectR1Equal<float>({1, 2, 1, 2}, result);
+  }
+}
+
+XLA_TEST_F(Fp8CollectiveOpsTest, DISABLED_ON_CPU(AllToAll_8BitFloat)) {
+  const char* const kModuleStr = R"(
+  HloModule test
+  ENTRY test_computation {
+    a0 = <<F8E4M3>>[2] constant({1,2})
+    a2a = <<F8E4M3>>[2] all-to-all(a0), dimensions={0}
+    ROOT out = f32[2] convert(a2a)
+  }
+  )";
+  const int64_t kNumReplicas = 2;
+  HloModuleConfig config =
+      GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto module, ParseAndReturnVerifiedModule(
+                       absl::StrReplaceAll(kModuleStr, replacements_), config));
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::vector<Literal> results,
+      ExecuteReplicated(std::move(module), absl::Span<Literal* const>{},
+                        kNumReplicas,
+                        /*use_threads=*/true, /*run_hlo_passes=*/true));
+  ASSERT_EQ(results.size(), kNumReplicas);
+  LiteralTestUtil::ExpectR1Equal<float>({1, 1}, results[0]);
+  LiteralTestUtil::ExpectR1Equal<float>({2, 2}, results[1]);
+}
+
+XLA_TEST_F(Fp8CollectiveOpsTest,
+           DISABLED_ON_CPU(CollectivePermute_8BitFloat)) {
+  const char* const kModuleStr = R"(
+  HloModule test
+  ENTRY test_computation {
+    a0 = <<F8E5M2>>[2] constant({1,2})
+    a1 = <<F8E5M2>>[2] collective-permute(a0), source_target_pairs={{0,1}, {1,0}}
+    ROOT out = f32[2] convert(a1)
+  }
+  )";
+  const int64_t kNumReplicas = 2;
+  HloModuleConfig config =
+      GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto module, ParseAndReturnVerifiedModule(
+                       absl::StrReplaceAll(kModuleStr, replacements_), config));
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::vector<Literal> results,
+      ExecuteReplicated(std::move(module), absl::Span<Literal* const>{},
+                        kNumReplicas,
+                        /*use_threads=*/true, /*run_hlo_passes=*/true));
+  ASSERT_EQ(results.size(), kNumReplicas);
+  LiteralTestUtil::ExpectR1Equal<float>({1, 2}, results[0]);
+  LiteralTestUtil::ExpectR1Equal<float>({1, 2}, results[1]);
+}
+
 }  // namespace
 }  // namespace xla
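
Note: the tests above pick the FP8 dtype spelling at runtime and splice it into
the HLO text before parsing, so one test body covers the OCP types on CUDA and
the NANOO (fnuz) types on ROCm. Below is a minimal standalone sketch of that
substitution, not part of the patch; the hard-coded `is_cuda` flag is an
assumption standing in for the tests' compute-capability check, and the
placeholder spellings mirror kF8E4M3DatatypePlaceholder/kF8E5M2DatatypePlaceholder
above.

    // Standalone sketch of the dtype-placeholder substitution (assumes Abseil).
    #include <iostream>
    #include <string>

    #include "absl/container/flat_hash_map.h"
    #include "absl/strings/str_replace.h"
    #include "absl/strings/string_view.h"

    int main() {
      // Assumption: stand-in for the tests' CudaComputeCapability check.
      const bool is_cuda = false;  // e.g. a ROCm device

      // OCP spellings on CUDA; NANOO (fnuz) spellings on ROCm.
      absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
      replacements["<<F8E4M3>>"] = is_cuda ? "f8e4m3fn" : "f8e4m3fnuz";
      replacements["<<F8E5M2>>"] = is_cuda ? "f8e5m2" : "f8e5m2fnuz";

      constexpr absl::string_view kHlo = R"(
        a0 = <<F8E4M3>>[2] constant({1,2})
        a2a = <<F8E4M3>>[2] all-to-all(a0), dimensions={0}
        ROOT out = f32[2] convert(a2a)
      )";

      // Prints the HLO with the backend-appropriate FP8 dtype spelled out,
      // e.g. "f8e4m3fnuz[2]" when is_cuda is false.
      std::cout << absl::StrReplaceAll(kHlo, replacements) << "\n";
      return 0;
    }

Because the substitution happens on the HLO text, both FP8 families exercise the
same ncclInt8-based data movement added in nccl_api.cc, while reductions over
FP8 remain rejected by IsTypeSupportedByNccl.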