Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PR #16938: Add NANOO FP8 support for collaborative communication unit tests #17265

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions xla/service/gpu/runtime/nccl_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ static absl::StatusOr<ncclDataType_t> ToNcclDataType(PrimitiveType dtype,
case S8:
case F8E5M2:
case F8E4M3FN:
case F8E5M2FNUZ:
case F8E4M3FNUZ:
return ncclInt8;
case PRED:
case U8:
Expand Down
2 changes: 2 additions & 0 deletions xla/service/gpu/runtime/nccl_collective_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ bool IsTypeSupportedByNccl(PrimitiveType element_type,
// they involve actual computation and not just data movement.
case F8E5M2:
case F8E4M3FN:
case F8E5M2FNUZ:
case F8E4M3FNUZ:
return !IsReductionCollective(reduction_op);
default:
return false;
Expand Down
40 changes: 29 additions & 11 deletions xla/tests/collective_ops_e2e_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.
#include <variant>
#include <vector>

#include "absl/strings/str_replace.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_instruction.h"
Expand Down Expand Up @@ -54,6 +55,13 @@ DeviceAssignment MakeDeviceAssn(int64_t num_replicas) {

class CollectiveOpsTestE2E : public HloTestBase {
public:
CollectiveOpsTestE2E() {
replacements_[kF8E4M3DatatypePlaceholder] =
IsCuda() ? "f8e4m3fn" : "f8e4m3fnuz";
replacements_[kF8E5M2DatatypePlaceholder] =
IsCuda() ? "f8e5m2" : "f8e5m2fnuz";
}

bool IsCuda() {
return std::holds_alternative<se::CudaComputeCapability>(Capability());
}
Expand Down Expand Up @@ -108,6 +116,13 @@ class CollectiveOpsTestE2E : public HloTestBase {
/*argument_provider*/ [](int64_t, int64_t) { return nullptr; },
num_replicas, /*run_hlo_passes=*/false, &device_assignment);
}

protected:
absl::flat_hash_map<absl::string_view, absl::string_view> replacements_;

private:
static constexpr const char* kF8E4M3DatatypePlaceholder{"<<F8E4M3>>"};
static constexpr const char* kF8E5M2DatatypePlaceholder{"<<F8E5M2>>"};
};

// E2E tests for collective ops. These will generally verify some HLO transform
Expand Down Expand Up @@ -811,11 +826,11 @@ ENTRY main.12 {
TEST_F(CollectiveOpsTestE2EWindowedNonWindowed,
WindowedEinsumE2EAllGatherAndReduceScatterF8) {
absl::string_view kModuleReplicatedStr = R"(
HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f8e4m3fn[2,16,48]{2,1,0}, f8e4m3fn[48,192]{1,0}, f8e4m3fn[192,48]{1,0}, bf16[], bf16[], bf16[], bf16[], bf16[])->bf16[2,16,48]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4
HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(<<F8E4M3>>[2,16,48]{2,1,0}, <<F8E4M3>>[48,192]{1,0}, <<F8E4M3>>[192,48]{1,0}, bf16[], bf16[], bf16[], bf16[], bf16[])->bf16[2,16,48]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4

ENTRY main.12 {
Arg_0.1 = f8e4m3fn[2,16,48]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]}
Arg_1.2 = f8e4m3fn[48,192]{1,0} parameter(1), sharding={devices=[1,4]<=[4]}
Arg_0.1 = <<F8E4M3>>[2,16,48]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]}
Arg_1.2 = <<F8E4M3>>[48,192]{1,0} parameter(1), sharding={devices=[1,4]<=[4]}
Arg_2.3 = bf16[] parameter(3)
Arg_3.4 = bf16[] parameter(4)
broadcast = bf16[2,16,48]{2,1,0} broadcast(Arg_2.3), dimensions={}
Expand All @@ -834,12 +849,12 @@ ENTRY main.12 {
constant.1 = bf16[] constant(448.)
broadcast.4 = bf16[2,16,192]{2,1,0} broadcast(constant.1), dimensions={}
clamp = bf16[2,16,192]{2,1,0} clamp(broadcast.3, divide, broadcast.4)
convert.2 = f8e4m3fn[2,16,192]{2,1,0} convert(clamp)
convert.2 = <<F8E4M3>>[2,16,192]{2,1,0} convert(clamp)
Arg_5.6 = bf16[] parameter(6)
broadcast.5 = bf16[2,16,192]{2,1,0} broadcast(Arg_5.6), dimensions={}
convert.3 = bf16[2,16,192]{2,1,0} convert(convert.2)
multiply.2 = bf16[2,16,192]{2,1,0} multiply(convert.3, broadcast.5)
Arg_6.7 = f8e4m3fn[192,48]{1,0} parameter(2), sharding={devices=[4,1]<=[4]}
Arg_6.7 = <<F8E4M3>>[192,48]{1,0} parameter(2), sharding={devices=[4,1]<=[4]}
Arg_7.8 = bf16[] parameter(7)
broadcast.6 = bf16[192,48]{1,0} broadcast(Arg_7.8), dimensions={}
convert.4 = bf16[192,48]{1,0} convert(Arg_6.7)
Expand All @@ -852,8 +867,9 @@ ENTRY main.12 {

// Disable the dot merger pass which can prevent the creation of FP8 GEMM
// Custom Calls.
CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr,
/*disable_dot_merger=*/true);
CollectiveOpsCompareWindowedNonWindowed(
absl::StrReplaceAll(kModuleReplicatedStr, replacements_),
/*disable_dot_merger=*/true);

// Verify the creation of FP8 GEMM Custom Calls on Hopper and newer
// architectures.
Expand All @@ -863,7 +879,8 @@ ENTRY main.12 {
opts.set_xla_gpu_graph_min_graph_size(200);
opts.set_xla_gpu_enable_triton_gemm(false);
opts.add_xla_disable_hlo_passes("dot-merger");
CollectiveOpsVerifyF8Matmul(kModuleReplicatedStr, opts);
CollectiveOpsVerifyF8Matmul(
absl::StrReplaceAll(kModuleReplicatedStr, replacements_), opts);
}

TEST_F(CollectiveOpsTestE2EWindowedNonWindowed,
Expand Down Expand Up @@ -1023,15 +1040,15 @@ while_body {
r = bf16[32,128] bitcast(dynamic-slice.k)
a = bf16[32,128] add(r, r), control-predecessors={constant.2559}
// A fp8 pattern of quant-dequant before the collective AG.
qa = f8e4m3fn[32,128] convert(a)
qa = <<F8E4M3>>[32,128] convert(a)
dqa = bf16[32,128] convert(qa)
a_scale = bf16[] get-tuple-element(param), index=3
a_scales = bf16[32,128] broadcast(a_scale), dimensions={}
dqa_unscaled = bf16[32,128] multiply(dqa, a_scales)
mb = bf16[128,128] all-gather(dqa_unscaled), channel_id=1, use_global_device_ids=true, dimensions={0}, replica_groups={{0,1,2,3}}
ma = bf16[128,128] dynamic-slice(get-tuple-element.395, select.1348, constant.2561), dynamic_slice_sizes={128,128}

qma = f8e4m3fn[128,128] convert(ma)
qma = <<F8E4M3>>[128,128] convert(ma)
dqma = bf16[128,128] convert(qma)
ma_scale = bf16[] get-tuple-element(param), index=4
ma_scales = bf16[128,128] broadcast(ma_scale), dimensions={}
Expand Down Expand Up @@ -1061,7 +1078,8 @@ ENTRY entry {
opts.set_xla_gpu_run_post_layout_collective_pipeliner(true);
opts.set_xla_gpu_enable_pipelined_collectives(true);
opts.set_xla_gpu_enable_triton_gemm(false);
CollectiveOpsVerifyF8Matmul(kModuleReplicatedStr, opts);
CollectiveOpsVerifyF8Matmul(
absl::StrReplaceAll(kModuleReplicatedStr, replacements_), opts);
}

TEST_F(CollectiveOpsTestE2E,
Expand Down
179 changes: 105 additions & 74 deletions xla/tests/collective_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1753,80 +1753,6 @@ XLA_TEST_F(CollectiveOpsTest, AllReduceBFloat16Min) {
}
}

XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllGather_8BitFloat)) {
const char* const kModuleStr = R"(
HloModule test
ENTRY test_computation {
a0 = f8e4m3fn[1,2] constant({{1,2}})
allgather = f8e4m3fn[2, 2] all-gather(a0), dimensions={0}
p = f8e4m3fn[4] reshape(allgather)
ROOT out = f32[4] convert(p)
}
)";
const int64_t kNumReplicas = 2;
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kModuleStr, config));
TF_ASSERT_OK_AND_ASSIGN(
std::vector<Literal> results,
ExecuteReplicated(std::move(module), absl::Span<Literal* const>{},
kNumReplicas,
/*use_threads=*/true, /*run_hlo_passes=*/true));
ASSERT_EQ(results.size(), kNumReplicas);
for (const Literal& result : results) {
LiteralTestUtil::ExpectR1Equal<float>({1, 2, 1, 2}, result);
}
}

XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllToAll_8BitFloat)) {
const char* const kModuleStr = R"(
HloModule test
ENTRY test_computation {
a0 = f8e4m3fn[2] constant({1,2})
a2a = f8e4m3fn[2] all-to-all(a0), dimensions={0}
ROOT out = f32[2] convert(a2a)
}
)";
const int64_t kNumReplicas = 2;
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kModuleStr, config));
TF_ASSERT_OK_AND_ASSIGN(
std::vector<Literal> results,
ExecuteReplicated(std::move(module), absl::Span<Literal* const>{},
kNumReplicas,
/*use_threads=*/true, /*run_hlo_passes=*/true));
ASSERT_EQ(results.size(), kNumReplicas);
LiteralTestUtil::ExpectR1Equal<float>({1, 1}, results[0]);
LiteralTestUtil::ExpectR1Equal<float>({2, 2}, results[1]);
}

XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(CollectivePermute_8BitFloat)) {
const char* const kModuleStr = R"(
HloModule test
ENTRY test_computation {
a0 = f8e5m2[2] constant({1,2})
a1 = f8e5m2[2] collective-permute(a0), source_target_pairs={{0,1}, {1,0}}
ROOT out = f32[2] convert(a1)
}
)";
const int64_t kNumReplicas = 2;
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kModuleStr, config));
TF_ASSERT_OK_AND_ASSIGN(
std::vector<Literal> results,
ExecuteReplicated(std::move(module), absl::Span<Literal* const>{},
kNumReplicas,
/*use_threads=*/true, /*run_hlo_passes=*/true));
ASSERT_EQ(results.size(), kNumReplicas);
LiteralTestUtil::ExpectR1Equal<float>({1, 2}, results[0]);
LiteralTestUtil::ExpectR1Equal<float>({1, 2}, results[1]);
}

XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AsyncAllGather)) {
const char* const kModuleStr = R"(
HloModule test
Expand Down Expand Up @@ -2273,5 +2199,110 @@ body {
results[1]));
}

class Fp8CollectiveOpsTest : public CollectiveOpsTest {
public:
Fp8CollectiveOpsTest() {
replacements_[kF8E4M3DatatypePlaceholder] =
IsCuda() ? "f8e4m3fn" : "f8e4m3fnuz";
replacements_[kF8E5M2DatatypePlaceholder] =
IsCuda() ? "f8e5m2" : "f8e5m2fnuz";
}

protected:
bool IsCuda() {
return std::holds_alternative<se::CudaComputeCapability>(Capability());
}

const se::GpuComputeCapability& Capability() {
return backend()
.default_stream_executor()
->GetDeviceDescription()
.gpu_compute_capability();
}

absl::flat_hash_map<absl::string_view, absl::string_view> replacements_;

private:
static constexpr const char* kF8E4M3DatatypePlaceholder{"<<F8E4M3>>"};
static constexpr const char* kF8E5M2DatatypePlaceholder{"<<F8E5M2>>"};
};

XLA_TEST_F(Fp8CollectiveOpsTest, DISABLED_ON_CPU(AllGather_8BitFloat)) {
const char* const kModuleStr = R"(
HloModule test
ENTRY test_computation {
a0 = <<F8E4M3>>[1,2] constant({{1,2}})
allgather = <<F8E4M3>>[2, 2] all-gather(a0), dimensions={0}
p = <<F8E4M3>>[4] reshape(allgather)
ROOT out = f32[4] convert(p)
}
)";
const int64_t kNumReplicas = 2;
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
TF_ASSERT_OK_AND_ASSIGN(
auto module, ParseAndReturnVerifiedModule(
absl::StrReplaceAll(kModuleStr, replacements_), config));
TF_ASSERT_OK_AND_ASSIGN(
std::vector<Literal> results,
ExecuteReplicated(std::move(module), absl::Span<Literal* const>{},
kNumReplicas,
/*use_threads=*/true, /*run_hlo_passes=*/true));
ASSERT_EQ(results.size(), kNumReplicas);
for (const Literal& result : results) {
LiteralTestUtil::ExpectR1Equal<float>({1, 2, 1, 2}, result);
}
}

XLA_TEST_F(Fp8CollectiveOpsTest, DISABLED_ON_CPU(AllToAll_8BitFloat)) {
const char* const kModuleStr = R"(
HloModule test
ENTRY test_computation {
a0 = <<F8E4M3>>[2] constant({1,2})
a2a = <<F8E4M3>>[2] all-to-all(a0), dimensions={0}
ROOT out = f32[2] convert(a2a)
}
)";
const int64_t kNumReplicas = 2;
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
TF_ASSERT_OK_AND_ASSIGN(
auto module, ParseAndReturnVerifiedModule(
absl::StrReplaceAll(kModuleStr, replacements_), config));
TF_ASSERT_OK_AND_ASSIGN(
std::vector<Literal> results,
ExecuteReplicated(std::move(module), absl::Span<Literal* const>{},
kNumReplicas,
/*use_threads=*/true, /*run_hlo_passes=*/true));
ASSERT_EQ(results.size(), kNumReplicas);
LiteralTestUtil::ExpectR1Equal<float>({1, 1}, results[0]);
LiteralTestUtil::ExpectR1Equal<float>({2, 2}, results[1]);
}

XLA_TEST_F(Fp8CollectiveOpsTest, DISABLED_ON_CPU(CollectivePermute_8BitFloat)) {
const char* const kModuleStr = R"(
HloModule test
ENTRY test_computation {
a0 = <<F8E5M2>>[2] constant({1,2})
a1 = <<F8E5M2>>[2] collective-permute(a0), source_target_pairs={{0,1}, {1,0}}
ROOT out = f32[2] convert(a1)
}
)";
const int64_t kNumReplicas = 2;
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
TF_ASSERT_OK_AND_ASSIGN(
auto module, ParseAndReturnVerifiedModule(
absl::StrReplaceAll(kModuleStr, replacements_), config));
TF_ASSERT_OK_AND_ASSIGN(
std::vector<Literal> results,
ExecuteReplicated(std::move(module), absl::Span<Literal* const>{},
kNumReplicas,
/*use_threads=*/true, /*run_hlo_passes=*/true));
ASSERT_EQ(results.size(), kNumReplicas);
LiteralTestUtil::ExpectR1Equal<float>({1, 2}, results[0]);
LiteralTestUtil::ExpectR1Equal<float>({1, 2}, results[1]);
}

} // namespace
} // namespace xla
Loading