import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # In this code, `x_ref`, `y_ref` and `o_ref` are (8,)-shaped `Ref`s
    x = x_ref[:]
    y = y_ref[:]
    o_ref[:] = x + y

x, y = jnp.arange(8), jnp.arange(8, 16)
add = pl.pallas_call(add_kernel, out_shape=jax.ShapeDtypeStruct((8,), jnp.int32))
add(x, y)
Hi,
Installing the nightly builds of jax, jaxlib, and jax-triton causes the following error:
Package        Version                           Editable project location
-------------- --------------------------------- -------------------------
absl-py        1.4.0
filelock       3.12.2
jax            0.4.15
jax-triton     0.1.4                             /home/mehdi/Repos/jax-triton
jaxlib         0.4.15.dev20230822+cuda12.cudnn89
ml-dtypes      0.2.0
numpy          1.25.2
opt-einsum     3.3.0
pip            22.0.2
scipy          1.11.2
setuptools     59.6.0
triton-nightly 2.1.0.dev20230714011643
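
For anyone trying to reproduce this, the versions listed above can also be collected from inside Python. This is a minimal sketch, assuming the distribution names match the ones in the list; `PACKAGES` and `installed_versions` are hypothetical names introduced here for illustration and are not part of jax or jax-triton.

```python
from importlib import metadata

# Distribution names taken from the package list above (an assumption:
# your environment may use different names, e.g. `triton` instead of
# `triton-nightly`).
PACKAGES = ["jax", "jaxlib", "jax-triton", "triton-nightly", "numpy"]


def installed_versions(names):
    """Map each distribution name to its version string, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Not installed in the current environment.
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(PACKAGES).items():
        print(f"{name:<16}{version or 'not installed'}")
```

Comparing this output between a working and a failing environment may help narrow down which nightly build introduced the mismatch.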