From c32e2b920cf5d2f513576f72dd2a88ce9f97a193 Mon Sep 17 00:00:00 2001 From: Xuefei Jiang Date: Mon, 16 Sep 2024 04:09:44 -0700 Subject: [PATCH] PR #17058: Replace "Navi" with corresponding public product names Imported from GitHub PR https://github.com/openxla/xla/pull/17058 The term 'Navi' is an internal product name used exclusively within AMD and should not appear in public projects. This PR replaces those 'Navi' names with the corresponding public product names. Copybara import of the project: -- 47248a044861ffe4d2f129674bab11916626e5a3 by scxfjiang : scrub navi Merging this change closes #17058 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/17058 from ROCm:ci_scrub_navi_symbol 47248a044861ffe4d2f129674bab11916626e5a3 PiperOrigin-RevId: 675089917 --- xla/stream_executor/device_description.h | 13 ++++++++----- xla/stream_executor/rocm/rocm_driver.cc | 8 ++++++-- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/xla/stream_executor/device_description.h b/xla/stream_executor/device_description.h index 3dfb79b3486ca..99d7f1ce5d83c 100644 --- a/xla/stream_executor/device_description.h +++ b/xla/stream_executor/device_description.h @@ -184,16 +184,19 @@ class RocmComputeCapability { return absl::c_count(kList, gfx_version()) != 0; } - bool navi21() const { return gfx_version() == "gfx1030"; } + bool gfx10_rx68xx() const { return gfx_version() == "gfx1030"; } - bool navi31() const { return gfx_version() == "gfx1100"; } + bool gfx10_rx69xx() const { return gfx_version() == "gfx1030"; } + + bool gfx11_rx7900() const { return gfx_version() == "gfx1100"; } bool has_nhwc_layout_support() const { return gfx9_mi100_or_later(); } bool has_bf16_dtype_support() const { return gfx9_mi100_or_later(); } bool has_fast_fp16_support() const { - return gfx9_mi100_or_later() || navi21() || navi31(); + return gfx9_mi100_or_later() || gfx10_rx68xx() || gfx10_rx69xx() || + gfx11_rx7900(); } bool has_mfma_instr_support() const { return gfx9_mi100_or_later(); } @@ -235,8 +238,8 @@ class RocmComputeCapability { "gfx908", // MI100 "gfx90a", // MI200 "gfx940", "gfx941", "gfx942", // MI300 - "gfx1030", // Navi21 - "gfx1100" // Navi31 + "gfx1030", // RX68xx / RX69xx + "gfx1100" // RX7900 }; }; diff --git a/xla/stream_executor/rocm/rocm_driver.cc b/xla/stream_executor/rocm/rocm_driver.cc index 9e838abfabcbf..3381b97c5553e 100644 --- a/xla/stream_executor/rocm/rocm_driver.cc +++ b/xla/stream_executor/rocm/rocm_driver.cc @@ -1594,12 +1594,16 @@ bool GetReservedMemory(uint64_t* reserve) { const uint64_t RESERVED_GFX908 = 1048576 * 512; const uint64_t RESERVED_GFX9_X = 1048576 * 1024; const uint64_t RESERVED_GFX10_X = 1048576 * 512; - if (compute_capability.gfx_version() == "gfx908") { + const uint64_t RESERVED_GFX11_X = 1048576 * 512; + if (compute_capability.gfx9_mi100()) { *reserve = RESERVED_GFX908; } else if (compute_capability.gfx9_mi200_or_later()) { *reserve = RESERVED_GFX9_X; - } else if (compute_capability.navi21() || compute_capability.navi31()) { + } else if (compute_capability.gfx10_rx68xx() || + compute_capability.gfx10_rx69xx()) { *reserve = RESERVED_GFX10_X; + } else if (compute_capability.gfx11_rx7900()) { + *reserve = RESERVED_GFX11_X; } return true;