lambda x: jnp.log(x == 1) is not jit-able on metal. #25935

Open
tillahoffmann opened this issue Jan 16, 2025 · 1 comment
Labels: Apple GPU (Metal) plugin, bug

@tillahoffmann (Contributor)

Description

The function `lambda x: jnp.log(x == 1)` cannot be JIT-compiled on Metal: calling the jitted function aborts the process with the following error message.

```
/AppleInternal/Library/BuildRoots/b11baf73-9ee0-11ef-b7b4-7aebe1f78c73/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Utility/MPSLibrary.mm:570: failed assertion `MPSKernel MTLComputePipelineStateCache unable to load function ndArrayIdentity.
        Compiler encountered an internal error: (null)
```

The expression above is, of course, equivalent to `lambda x: jnp.where(x == 1, 0, -jnp.inf)`, but it appears, for example, in the evaluation of the log density of the delta distribution in NumPyro (`Delta.log_prob`). I couldn't find a simpler expression that raises the same error.
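As a sketch of the `jnp.where` rewrite (the function names `log_indicator` and `log_indicator_where` are hypothetical), the following compares the eager original with a jitted `where` version on the same inputs as the reproduction script. It only checks that the two agree numerically; that the `where` form avoids the Metal abort is an assumption here, not something this snippet demonstrates.

```python
import jax
import jax.numpy as jnp
import numpy as np


def log_indicator(x):
    # Form from the report: log of a boolean mask (False -> -inf, True -> 0).
    return jnp.log(x == 1)


def log_indicator_where(x):
    # Rewritten form: select 0 or -inf directly, so no log is evaluated at all.
    return jnp.where(x == 1, 0.0, -jnp.inf)


jitted_where = jax.jit(log_indicator_where)

# Compare the eager original with the jitted rewrite on the report's inputs.
for x in [jnp.arange(2), jnp.ones(2)]:
    expected = np.asarray(log_indicator(x))
    actual = np.asarray(jitted_where(x))
    np.testing.assert_array_equal(actual, expected)
    print(x, "->", actual)
```

The eager original is used as the reference because, per the output below, eager evaluation of `f` works on both backends.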

Here's a script to reproduce the issue.

```python
# identity_bug.py
import jax
from jax.lib import xla_bridge
from jax import numpy as jnp


def f(x):
    return jnp.log(x == 1)  # log of a boolean mask: False -> -inf, True -> 0


args = [jnp.arange(2), jnp.ones(2)]  # integer and float inputs
print(f"running on {xla_bridge.get_backend().platform} ...")
fns = {"f": f, "jitted_f": jax.jit(f)}  # eager vs. jitted
for key, fn in fns.items():
    for x in args:
        print(f"{key}({x}) -> ...", end=" ", flush=True)  # flush before a possible abort
        print(fn(x), flush=True)
print("done")
```
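To see what `jax.jit(f)` actually hands to the backend, one can print the traced jaxpr and the lowered StableHLO. This is a debugging sketch rather than a fix: `jax.make_jaxpr` and `jax.jit(f).lower(x).as_text()` only trace and lower, so they are assumed not to reach the MPS compilation step that aborts. The sketch also swaps the deprecated `xla_bridge.get_backend()` call for `jax.default_backend()`.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.log(x == 1)


x = jnp.arange(2)

# Reports the default platform without the deprecated
# jax.lib.xla_bridge.get_backend() call.
print("backend:", jax.default_backend())

# The jaxpr lists the primitives f traces to: an `eq` comparison, a
# `convert_element_type` (jnp.log promotes its boolean input to a
# floating-point dtype), and `log`.
jaxpr = jax.make_jaxpr(f)(x)
print(jaxpr)

# The StableHLO module produced by lowering jax.jit(f) for this input.
hlo_text = jax.jit(f).lower(x).as_text()
print(hlo_text)
```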

Running on Metal, I get the following.

```
$ python identity_bug.py
WARNING:2025-01-16 17:19:57,590:jax._src.xla_bridge:1000: Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1737065997.590723 34394534 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB

I0000 00:00:1737065997.605677 34394534 service.cc:145] XLA service 0x128fbfd70 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1737065997.605694 34394534 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1737065997.607700 34394534 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1737065997.607742 34394534 mps_client.cc:384] XLA backend will use up to 11452858368 bytes on device 0 for SimpleAllocator.
/Users/till/git/tycho/playground/identity_bug.py:11: DeprecationWarning: jax.lib.xla_bridge.get_backend is deprecated; use jax.extend.backend.get_backend.
  print(f"running on {xla_bridge.get_backend().platform} ...")
running on METAL ...
f([0 1]) -> ... [-inf   0.]
f([1. 1.]) -> ... [0. 0.]
jitted_f([0 1]) -> ... /AppleInternal/Library/BuildRoots/b11baf73-9ee0-11ef-b7b4-7aebe1f78c73/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Utility/MPSLibrary.mm:570: failed assertion `MPSKernel MTLComputePipelineStateCache unable to load function ndArrayIdentity.
        Compiler encountered an internal error: (null)
'
[1]    52590 abort      python identity_bug.py
```

Restricting to CPU gives the following.

```
$ JAX_PLATFORM_NAME=cpu python identity_bug.py
WARNING:2025-01-16 17:22:04,614:jax._src.xla_bridge:1000: Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1737066124.615053 34398748 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB

I0000 00:00:1737066124.629618 34398748 service.cc:145] XLA service 0x112537590 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1737066124.629654 34398748 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1737066124.631511 34398748 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1737066124.631542 34398748 mps_client.cc:384] XLA backend will use up to 11452858368 bytes on device 0 for SimpleAllocator.
/Users/till/git/tycho/playground/identity_bug.py:11: DeprecationWarning: jax.lib.xla_bridge.get_backend is deprecated; use jax.extend.backend.get_backend.
  print(f"running on {xla_bridge.get_backend().platform} ...")
running on cpu ...
f([0 1]) -> ... [-inf   0.]
f([1. 1.]) -> ... [0. 0.]
jitted_f([0 1]) -> ... [-inf   0.]
jitted_f([1. 1.]) -> ... [0. 0.]
done
I0000 00:00:1737066124.861558 34398748 mps_client.h:209] MetalClient destroyed.
```
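Rather than restricting the whole process with `JAX_PLATFORM_NAME=cpu`, a single jitted function can be pinned to the CPU by committing its inputs there, since `jax.jit` runs on the device its committed arguments live on. A minimal sketch, assuming the CPU backend is initialized alongside the Metal plugin (the log above shows both backends being set up in one process):

```python
import jax
import jax.numpy as jnp
import numpy as np


def f(x):
    return jnp.log(x == 1)


# First CPU device; assumes the CPU backend is available next to Metal.
cpu = jax.devices("cpu")[0]
jitted_f = jax.jit(f)

# Committing the input to the CPU device makes jit compile and run f there,
# while other computations keep using the default backend.
x_cpu = jax.device_put(np.arange(2, dtype=np.int32), cpu)
y = jitted_f(x_cpu)
print(y, y.devices())
```

This keeps Metal as the default for everything else, at the cost of a host transfer for the pinned inputs.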

This seems to be specific to combining an equality test with `jnp.log`; for example, the following all work on both backends:

```python
lambda x: 1 / (x == 1)
lambda x: jnp.exp(x == 1)
# This is equivalent to jnp.log(x == 1) because log(1 / x) = - log(x).
lambda x: - jnp.log(1 / (x == 1))
```
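The identity in the comment can be checked numerically: for `x = [0, 1]`, `1 / (x == 1)` is `[inf, 1]`, so negating its log gives back `[-inf, 0]`. The sketch below (function names hypothetical) compares the eager original with the jitted reciprocal form; the only difference is a `-0.0` in place of `0.0`, which compares equal.

```python
import jax
import jax.numpy as jnp
import numpy as np


def direct(x):
    return jnp.log(x == 1)


def via_reciprocal(x):
    # 1 / mask maps [False, True] to [inf, 1.]; log gives [inf, 0.];
    # negation yields [-inf, -0.].
    return -jnp.log(1 / (x == 1))


for x in [jnp.arange(2), jnp.ones(2)]:
    a = np.asarray(direct(x))  # eager, which works on both backends per the report
    b = np.asarray(jax.jit(via_reciprocal)(x))
    np.testing.assert_array_equal(a, b)
    print(x, "->", b)
```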

System info (python version, jaxlib version, accelerator, etc.)

```
Metal device set to: Apple M1
jax:    0.4.38
jaxlib: 0.4.38
numpy:  2.2.1
python: 3.11.5 (main, Dec  8 2023, 17:04:09) [Clang 15.0.0 (clang-1500.0.40.1)]
device info: Metal-1, 1 local devices"
process_count: 1
platform: uname_result(system='Darwin', node='Tills-MacBook-Pro-3.local', release='24.2.0', version='Darwin Kernel Version 24.2.0: Fri Dec  6 18:40:14 PST 2024; root:xnu-11215.61.5~2/RELEASE_ARM64_T8103', machine='arm64')


systemMemory: 16.00 GB
maxCacheSize: 5.33 GB
```
tillahoffmann added the bug label on Jan 16, 2025
tillahoffmann added a commit to tillahoffmann/numpyro that referenced this issue Jan 16, 2025
fehiepsi pushed a commit to pyro-ppl/numpyro that referenced this issue Jan 17, 2025
…`Delta` sampling site during initialization. (#1950)

* Use `jnp.where` for `Delta.log_prob` (cf. jax-ml/jax#25935).

* Raise explicit error for unobserved `Delta` sample sites during initialization.
@shuhand0 (Collaborator)

Thx, we are looking into it.
