
GPU pallas_call loses compiler params during second call when double jit-wrapped #25714

Open
hanzhi713 opened this issue Jan 3, 2025 · 1 comment
Labels
bug Something isn't working


hanzhi713 commented Jan 3, 2025

Description

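If you add a `print(triton_params)` right after this line in the Triton lowering of `pallas_call`:

```python
triton_params = compiler_params.get("triton", compiler_params)
```

and run the reproducer below, you'll get:

```
{'num_warps': 8}
{}
```

That is, the second lowering no longer sees `num_warps=8`. This causes performance problems in production, because kernels are compiled without the compiler params that were passed to `pallas_call`.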

Reproducer

```python
# Requires a GPU backend (pallas_call is lowered through Triton).
from functools import partial

import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp


def add_vectors_kernel(x_ref, y_ref, o_ref):
  x, y = x_ref[...], y_ref[...]
  o_ref[...] = x + y


@partial(jax.jit, static_argnames="z")
def dummy(x, y, z):
  # z is static, so each new value of z retraces and re-lowers `dummy`.
  x += z
  return add_vectors(x, y)


@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
  return pl.pallas_call(
      add_vectors_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
      compiler_params=dict(triton=dict(num_warps=8)),
  )(x, y)


a = jnp.arange(8)
b = jnp.arange(8)
dummy(a, b, 1)  # lowering sees {'num_warps': 8}
dummy(a, b, 2)  # lowering sees {} (num_warps is lost)
```
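Why does the double jit-wrapping matter? A plausible reading: because `z` is static, the second `dummy` call retraces and re-lowers the outer function, but the inner `jax.jit` reuses its cached trace of `add_vectors`, so the lowering runs a second time on the very same `compiler_params` dict object. The CPU-only sketch below uses hypothetical `inner`/`outer` functions (no Pallas involved) and counts how often each Python body actually runs; it assumes the inner jit's trace cache hits when it is called again with identical input shapes and dtypes.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Python side effects run only while JAX traces a function, so these
# counters record how many times each body was traced.
trace_count = {"inner": 0, "outer": 0}


@jax.jit
def inner(x, y):
  trace_count["inner"] += 1
  return x + y


@partial(jax.jit, static_argnames="z")
def outer(x, y, z):
  trace_count["outer"] += 1
  return inner(x + z, y)


a = jnp.arange(8)
outer(a, a, 1)
outer(a, a, 2)  # new static z: `outer` is retraced...
print(trace_count)  # ...while `inner` is expected to reuse its cached trace
```

If `inner` is traced only once, any Python object captured in its jaxpr (such as a `compiler_params` dict) is shared by both lowerings of `outer`, so an in-place mutation during the first lowering is visible to the second.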

System info (python version, jaxlib version, accelerator, etc.)

jax                                      0.4.39.dev20250102
jax-cuda12-pjrt                          0.4.39.dev20250102
jax-cuda12-plugin                        0.4.39.dev20250102
jaxlib                                   0.4.39.dev20250102
hanzhi713 added the bug label Jan 3, 2025
justinjfu self-assigned this Jan 6, 2025
copybara-service bot pushed a commit that referenced this issue Jan 6, 2025
…ton lowering.

Addresses: #25714
PiperOrigin-RevId: 712566709
justinjfu (Collaborator) commented

Looks like a silly bug: the params are mutated because we `pop` `'num_warps'` instead of using `get`. #25735 should address this.
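The failure mode described in that comment can be modeled in plain Python. This is a sketch, not JAX's actual lowering code: `cached_params` is a hypothetical stand-in for the params dict held by a cached trace, `lower_with_pop` and `lower_with_get` are hypothetical lowering functions, and the fallback of 4 warps is an assumed default for illustration. `lower_with_pop` mirrors the bug (it removes `num_warps` from the shared dict, so a second lowering falls back to the default), while `lower_with_get` reads the value without mutating anything.

```python
import functools

DEFAULT_NUM_WARPS = 4  # assumed fallback value, for illustration only


@functools.lru_cache(maxsize=None)
def cached_params():
  # Hypothetical stand-in for the params dict held by a cached trace:
  # every call returns the *same* dict object.
  return {"triton": {"num_warps": 8}}


def lower_with_pop(compiler_params):
  # Buggy pattern: pop() removes the key from the shared dict in place.
  triton_params = compiler_params.get("triton", compiler_params)
  return triton_params.pop("num_warps", DEFAULT_NUM_WARPS)


def lower_with_get(compiler_params):
  # Fixed pattern: get() reads the key and leaves the dict untouched.
  triton_params = compiler_params.get("triton", compiler_params)
  return triton_params.get("num_warps", DEFAULT_NUM_WARPS)


print([lower_with_get(cached_params()) for _ in range(2)])  # [8, 8]
print([lower_with_pop(cached_params()) for _ in range(2)])  # [8, 4]
print(cached_params())  # {'triton': {}}
```

If a lowering genuinely needs `pop` (for example, to detect leftover unrecognized keys), popping from a shallow copy such as `dict(triton_params)` keeps the shared object intact.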

copybara-service bot pushed a commit that referenced this issue Jan 7, 2025
…ton lowering.

Addresses: #25714
PiperOrigin-RevId: 712930760