Failed to launch ROCm kernel on MI250x #247

Closed
ffrancesco94 opened this issue Feb 28, 2025 · 12 comments
Labels
bug Something isn't working Under Investigation

Comments

@ffrancesco94

Description

Hi,
I tried to use this build of JAX because the upstream one was giving me some trouble. I wrote a simple reproducer that trains a PINN (physics-informed neural network) with JAX (I can upload it if you need it), and I get the following error:

2025-02-28 15:59:06.281591: W external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1035] Compiling 113 configs for 27 fusions on a single thread.
F0228 15:59:16.766740  124342 stream_executor_util.cc:515] Non-OK-status: kernel->Launch(se::ThreadDim(threads_per_block, 1, 1), se::BlockDim(blocks_per_grid, 1, 1), stream, buffer, host_buffer_bytes, static_cast<int64_t>(buffer.size()))
Status: INTERNAL: Failed to launch ROCm kernel: RepeatBufferKernel with block dimensions: 256x1x1: hipError_t(303)
*** Check failure stack trace: ***
    @     0x152a0e493d54  absl::lts_20230802::log_internal::LogMessage::SendToLog()
    @     0x152a0e4936f4  absl::lts_20230802::log_internal::LogMessage::Flush()
    @     0x152a0e4941b9  absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
    @     0x152a0ae00837  xla::gpu::InitializeTypedBuffer<>()
    @     0x152a0adfb0af  xla::primitive_util::FloatingPointTypeSwitch<>()
    @     0x152a0adf94c3  xla::gpu::InitializeBuffer()
    @     0x152a0adf5634  stream_executor::RedzoneAllocator::CreateBuffer()
    @     0x152a09702536  xla::gpu::RedzoneBuffers::CreateInputs()
    @     0x152a097021df  xla::gpu::RedzoneBuffers::FromInstruction()
    @     0x152a096f2b6a  xla::gpu::GemmFusionAutotunerImpl::MeasurePerformance()
    @     0x152a096f3ba7  xla::gpu::GemmFusionAutotunerImpl::Profile()
    @     0x152a096f46d9  xla::gpu::GemmFusionAutotunerImpl::Autotune()
    @     0x152a096f729b  xla::gpu::GemmFusionAutotuner::Run()
    @     0x152a0970dabf  xla::HloPassPipeline::RunHelper()
    @     0x152a0970ad15  xla::HloPassPipeline::RunPassesInternal<>()
    @     0x152a0970a66a  xla::HloPassPipeline::Run()
    @     0x152a060004ee  xla::gpu::GpuCompiler::OptimizeHloPostLayoutAssignment()
    @     0x152a05ff2ab1  xla::gpu::AMDGPUCompiler::OptimizeHloPostLayoutAssignment()
    @     0x152a05ffb644  xla::gpu::GpuCompiler::OptimizeHloModule()
    @     0x152a060041c2  xla::gpu::GpuCompiler::RunHloPasses()
    @     0x152a05fdba6f  xla::Service::BuildExecutable()
    @     0x152a05f54187  xla::LocalService::CompileExecutables()
    @     0x152a05f4ebf4  xla::LocalClient::Compile()
    @     0x152a05ef00fb  xla::PjRtStreamExecutorClient::CompileInternal()
    @     0x152a05ef13d0  xla::PjRtStreamExecutorClient::Compile()
    @     0x152a05e5d44a  std::__detail::__variant::__gen_vtable_impl<>::__visit_invoke()
    @     0x152a05e4cddc  pjrt::PJRT_Client_Compile()
    @     0x152b0980676d  xla::InitializeArgsAndCompile()
    @     0x152b09806e9e  xla::PjRtCApiClient::Compile()
    @     0x152b11258f38  xla::ifrt::PjRtLoadedExecutable::Create()
    @     0x152b11263010  xla::ifrt::PjRtCompiler::Compile()
    @     0x152b100e4cff  xla::PyClient::CompileIfrtProgram()
    @     0x152b100e595b  xla::PyClient::Compile()
    @     0x152b100eca26  nanobind::detail::func_create<>()::{lambda()#1}::__invoke()
    @     0x152b11230ec8  nanobind::detail::nb_func_vectorcall_complex()
    @     0x152b1ae1e82c  nanobind::detail::nb_bound_method_vectorcall()
    @     0x55628df359dc  PyObject_Vectorcall

I installed this JAX build following these instructions. Interestingly, I hit a similar error a while back while experimenting, when the jax-rocm60-pjrt and jax-rocm60-plugin packages were missing, but both are installed now. I installed the packages in a Singularity container and ran from there, and the GPU does seem to be detected correctly.
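Since the earlier occurrence of this error came down to missing jax-rocm60-pjrt / jax-rocm60-plugin packages, that cause can be ruled out from inside the container with a quick standard-library check. A minimal sketch: the distribution names are the ones named in this issue, and treating a plugin/pjrt version that differs from jaxlib's as suspicious is an assumption on our part, not a documented rule.

```python
from importlib import metadata
from typing import Callable, Dict, List, Optional

# Distribution names mentioned in this issue (ROCm 6.0 plugin variants).
PACKAGES = ("jax", "jaxlib", "jax-rocm60-plugin", "jax-rocm60-pjrt")


def installed_version(name: str) -> Optional[str]:
    """Installed version of a distribution, or None if it is not installed."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def check_rocm_jax_install(
    lookup: Callable[[str], Optional[str]] = installed_version,
) -> Dict[str, object]:
    """Report missing ROCm JAX packages and plugin versions that differ from jaxlib."""
    versions = {name: lookup(name) for name in PACKAGES}
    missing: List[str] = [name for name, v in versions.items() if v is None]
    jaxlib = versions["jaxlib"]
    # Assumption: plugin and pjrt wheels are released in lockstep with jaxlib.
    mismatched: List[str] = [
        name
        for name in PACKAGES[2:]
        if versions[name] is not None and jaxlib is not None and versions[name] != jaxlib
    ]
    return {"versions": versions, "missing": missing, "mismatched": mismatched}


if __name__ == "__main__":
    print(check_rocm_jax_install())
```

The `lookup` parameter exists only so the check can be exercised without the packages installed; inside the container the default reads the real installed metadata.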

System info (python version, jaxlib version, accelerator, etc.)

The accelerator is an AMD MI250X.

python -c "import jax; jax.print_environment_info()" > jaxinfo.txt
/usr/share/libdrm/amdgpu.ids: No such file or directory
Singularity> cat jaxinfo.txt
jax:    0.5.0
jaxlib: 0.5.0
numpy:  2.2.3
python: 3.11.11 | packaged by conda-forge | (main, Dec  5 2024, 14:17:24) [GCC 13.3.0]
device info: AMD Radeon Graphics-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='nid007976', release='5.14.21-150500.55.49_13.0.56-cray_shasta_c', version='#1 SMP Mon Mar 4 14:19:49 UTC 2024 (9d8355b)', machine='x86_64')
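Both failures reported in this thread end in a fatal XLA status line carrying a raw HIP error code (`hipError_t(303)` above, and `hipError_t(98)` in a later comment). A small sketch for extracting the kernel name and code from such a line and giving the code a name. The table is a partial copy of HIP's `hipError_t` enum as we understand it from `hip_runtime_api.h`; verify the values against the headers of your installed ROCm version before relying on them.

```python
import re
from typing import Optional

# Assumed subset of HIP's hipError_t enum (hip_runtime_api.h); check your headers.
HIP_ERROR_NAMES = {
    98: "hipErrorInvalidDeviceFunction",
    209: "hipErrorNoBinaryForGpu",
    302: "hipErrorSharedObjectSymbolNotFound",
    303: "hipErrorSharedObjectInitFailed",
}

_CODE_RE = re.compile(r"hipError_t\((\d+)\)")
# Matches the two message shapes seen in this thread.
_KERNEL_RE = re.compile(r"(?:launch ROCm kernel: |Could not create )(\w+)")


def describe_hip_failure(line: str) -> Optional[dict]:
    """Extract the kernel name and HIP error code from an XLA status line."""
    code_match = _CODE_RE.search(line)
    if code_match is None:
        return None  # no HIP error code on this line
    code = int(code_match.group(1))
    kernel_match = _KERNEL_RE.search(line)
    return {
        "kernel": kernel_match.group(1) if kernel_match else None,
        "code": code,
        "name": HIP_ERROR_NAMES.get(code, "unknown (not in this table)"),
    }


if __name__ == "__main__":
    status = ("Status: INTERNAL: Failed to launch ROCm kernel: RepeatBufferKernel "
              "with block dimensions: 256x1x1: hipError_t(303)")
    print(describe_hip_failure(status))
```

If the table is right, 98 (invalid device function) and 303 (shared object initialization failed) would both be consistent with the hypothesis, raised further down, that the shipped kernels were not built for the MI250X's architecture; that is an inference, not a confirmed diagnosis.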
ffrancesco94 added the bug (Something isn't working) label on Feb 28, 2025
@ppanchad-amd

Hi @ffrancesco94. Internal ticket has been created to investigate this issue. Thanks!

@ksebaz

ksebaz commented Mar 4, 2025

I could imagine this is already fixed by 1c9f907, which unfortunately made it into neither the 0.5.0 release nor the released wheels (checking with roc-obj-ls, those seem to be built for gfx900, which might not be enough for the MI250). I built from the rocm-jaxlib-v0.5.0 branch, which has moved on and includes the relevant fixes, and for my set of tests at least, these recurring launch errors are fixed for the moment.
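The gfx900 observation can be checked against the `roc-obj-ls` listing of a wheel's shared libraries. A sketch that pulls the gfx targets out of that listing and checks for `gfx90a`, the architecture of the MI200 series (MI250/MI250X). It assumes the offload bundles appear as `...amdgcn-amd-amdhsa--gfxNNN` entries, as they typically do, and takes the listing as text rather than invoking the tool itself.

```python
import re
from typing import Set

# Bundle IDs such as "hipv4-amdgcn-amd-amdhsa--gfx90a:sramecc+:xnack-" (assumed
# format); only the bare gfx name is captured, feature suffixes are ignored.
_GFX_RE = re.compile(r"amdgcn-amd-amdhsa--(gfx[0-9a-f]+)")


def gfx_targets(roc_obj_ls_output: str) -> Set[str]:
    """Collect the gfx architectures listed in roc-obj-ls output."""
    return set(_GFX_RE.findall(roc_obj_ls_output))


def supports(roc_obj_ls_output: str, arch: str = "gfx90a") -> bool:
    """True if code objects for `arch` are present (gfx90a = MI200 series)."""
    return arch in gfx_targets(roc_obj_ls_output)
```

The input could come from something like `subprocess.run(["roc-obj-ls", path_to_so], capture_output=True, text=True).stdout`, where `path_to_so` is whichever plugin library you want to inspect; `rocminfo` reports the GPU's own gfx name for comparison.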

@ffrancesco94
Author

I see. So the way to go for now is to compile everything from source? Or will a minor release be published soon?
By the way, how do the wheels published in this repo's GitHub releases differ from the ones on PyPI? I saw that on PyPI the jax-rocm60-pjrt and jax-rocm60-plugin packages are only available up to jax==0.4.38, so I'm a bit confused about which wheels are the "preferred" way to go.

@ksebaz

ksebaz commented Mar 4, 2025

To clarify: the comment above only reflects my experience as a user. I adjusted the wording a bit, since on rereading it sounded overly confident :-)

@ffrancesco94
Author

Aah, I see. But it's still very good to get that input, so thank you, at least I know that I can compile things myself for the time being!

@Ruturaj4

Ruturaj4 commented Mar 4, 2025

I see. So the way to go for now is to compile everything from source? Or will a minor release be published soon? By the way, in which way are the wheels published in the GitHub release of this repo different from the ones I find on PyPI? I saw there that the jax-rocm60-pjrt and jax-rocm60-plugin packages are available up to jax==0.4.38, but I'm a bit confused on which wheels are the "preferred" way to go.

Thanks for letting us know. I will upload the latest wheels shortly.

@ffrancesco94
Author

ffrancesco94 commented Mar 4, 2025

Does that mean the ones built here are the same as the ones on PyPI?

edit:typo

Ruturaj4 self-assigned this on Mar 5, 2025
@tcgu-amd

tcgu-amd commented Mar 6, 2025

That would mean that the ones built here are the same as the ones on PyPI?

edit:typo

I think the ones on PyPI might be behind; building from this repo should yield the latest publicly available JAX build for ROCm.

@phambinhfin

@Ruturaj4 For jax0.0.5 with ROCm 6.3.1, I am also getting a similar error:

F0314 17:16:06.322534     361 stream_executor_util.cc:503] Could not create RepeatBufferKernel: INTERNAL: Failed call to hipGetFuncBySymbol: hipError_t(98)
*** Check failure stack trace: ***
    @     0x7f39255141d4  absl::lts_20230802::log_internal::LogMessage::SendToLog()
    @     0x7f3925513b74  absl::lts_20230802::log_internal::LogMessage::Flush()
    @     0x7f3925514639  absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
    @     0x7f3921e7f19c  xla::gpu::InitializeTypedBuffer<>()
    @     0x7f3921e79a3f  xla::primitive_util::FloatingPointTypeSwitch<>()
    @     0x7f3921e77e53  xla::gpu::InitializeBuffer()
    @     0x7f3921e73fc4  stream_executor::RedzoneAllocator::CreateBuffer()
    @     0x7f3920780b06  xla::gpu::RedzoneBuffers::CreateInputs()
    @     0x7f39207807af  xla::gpu::RedzoneBuffers::FromInstruction()
    @     0x7f392077113a  xla::gpu::GemmFusionAutotunerImpl::MeasurePerformance()
    @     0x7f3920772177  xla::gpu::GemmFusionAutotunerImpl::Profile()
    @     0x7f3920772ca9  xla::gpu::GemmFusionAutotunerImpl::Autotune()
    @     0x7f392077586b  xla::gpu::GemmFusionAutotuner::Run()
    @     0x7f392078c08f  xla::HloPassPipeline::RunHelper()
    @     0x7f39207892e5  xla::HloPassPipeline::RunPassesInternal<>()
    @     0x7f3920788c3a  xla::HloPassPipeline::Run()
    @     0x7f391d02112e  xla::gpu::GpuCompiler::OptimizeHloPostLayoutAssignment()
    @     0x7f391d0136f1  xla::gpu::AMDGPUCompiler::OptimizeHloPostLayoutAssignment()
    @     0x7f391d01c284  xla::gpu::GpuCompiler::OptimizeHloModule()
    @     0x7f391d024e02  xla::gpu::GpuCompiler::RunHloPasses()
    @     0x7f391cffc6af  xla::Service::BuildExecutable()
    @     0x7f391cf74dc7  xla::LocalService::CompileExecutables()
    @     0x7f391cf6f834  xla::LocalClient::Compile()
    @     0x7f391cf10d3b  xla::PjRtStreamExecutorClient::CompileInternal()
    @     0x7f391cf12010  xla::PjRtStreamExecutorClient::Compile()
    @     0x7f391ce7e08a  std::__detail::__variant::__gen_vtable_impl<>::__visit_invoke()
    @     0x7f391ce6da1c  pjrt::PJRT_Client_Compile()
    @     0x7f3a0ea525ad  xla::InitializeArgsAndCompile()
    @     0x7f3a0ea52cde  xla::PjRtCApiClient::Compile()
    @     0x7f3a164a4d78  xla::ifrt::PjRtLoadedExecutable::Create()
    @     0x7f3a164aee50  xla::ifrt::PjRtCompiler::Compile()
    @     0x7f3a15330b3f  xla::PyClient::CompileIfrtProgram()
    @     0x7f3a1533179b  xla::PyClient::Compile()
    @     0x7f3a15338866  nanobind::detail::func_create<>()::{lambda()#1}::__invoke()
    @     0x7f3a1647cd08  nanobind::detail::nb_func_vectorcall_complex()
    @     0x7f3abcaec82c  nanobind::detail::nb_bound_method_vectorcall()
    @     0x7f3abd715de8  PyObject_Vectorcall

This happens when trying to run:

import flax
import jax
import jax.numpy as jnp
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.common import recipe

BATCH = 32
SEQLEN = 128
HIDDEN = 1024

# Initialize RNG and inputs.
rng = jax.random.PRNGKey(0)
init_rng, data_rng = jax.random.split(rng)
inp = jax.random.normal(data_rng, [BATCH, SEQLEN, HIDDEN], jnp.float32)

# Create an FP8 recipe. Note: All input args are optional.
fp8_recipe = recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.HYBRID)

# Enable autocasting for the forward pass
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    model = te_flax.DenseGeneral(features=HIDDEN)

    def loss_fn(params, other_vars, inp):
      out = model.apply({'params':params, **other_vars}, inp)
      return jnp.mean(out)

    # Initialize models.
    variables = model.init(init_rng, inp)

@ffrancesco94
Author

@phambinhfin For me, the cause was that the driver on the host system was older (ROCm 6.0.x or 6.2.x). Not sure whether that applies in your case.
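The explanation above boils down to comparing the ROCm release on the host with the ROCm userspace inside the container. A minimal heuristic sketch, assuming both versions are already known as strings (for example read from `/opt/rocm/.info/version` on each side, a path present on typical ROCm installs but an assumption here). It compares only major.minor; AMD's actual driver/userspace compatibility rules are more nuanced than strict ordering, so a `True` result is a hint to investigate, not a verdict.

```python
from typing import Tuple


def parse_rocm_version(text: str) -> Tuple[int, int]:
    """Turn strings like "6.3.1" or "6.0.x" into a (major, minor) tuple."""
    major, minor = text.strip().split(".")[:2]
    return int(major), int(minor)


def host_driver_too_old(host: str, container: str) -> bool:
    """Heuristic: flag a host ROCm release older than the container's."""
    return parse_rocm_version(host) < parse_rocm_version(container)


if __name__ == "__main__":
    # Versions quoted in this thread: an older host stack vs. ROCm 6.3.1.
    print(host_driver_too_old("6.0.x", "6.3.1"))
```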

@Ruturaj4

Ruturaj4 commented Apr 10, 2025

@ffrancesco94 can you try installing jax 0.5.0 using pip? We have now fixed the PyPI issue, and our plugins are up.

@tcgu-amd

@ffrancesco94, this issue will be closed for now due to inactivity. Please feel free to follow up below if the issue persists. Thanks!
