The function lambda x: jnp.log(x == 1) is not jit-able on metal and raises the following error message.
/AppleInternal/Library/BuildRoots/b11baf73-9ee0-11ef-b7b4-7aebe1f78c73/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Utility/MPSLibrary.mm:570: failed assertion `MPSKernel MTLComputePipelineStateCache unable to load function ndArrayIdentity.
Compiler encountered an internal error: (null)
The expression above is of course equivalent to lambda x: jnp.where(x == 1, 0, -jnp.inf), but it appears, for example, in the evaluation of the log density of the delta distribution in numpyro (here). I couldn't find a simpler expression that raised the same error.
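Below is a minimal sketch of a reproduction script, reconstructed from the log output below. The names `f` and `jitted_f` and the two inputs come from the log; everything else is an assumption and not the original identity_bug.py. It uses `jax.default_backend()` instead of the deprecated `jax.lib.xla_bridge.get_backend()` that the log warns about.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Log of a boolean mask: jnp.log promotes the booleans to floats,
    # so True -> log(1.0) = 0.0 and False -> log(0.0) = -inf.
    return jnp.log(x == 1)


jitted_f = jax.jit(f)

print(f"running on {jax.default_backend()} ...")
inputs = (jnp.array([0, 1]), jnp.array([1.0, 1.0]))
# Eager calls first, then the jitted ones, matching the order in the logs.
# Per the report, the first jitted call aborts on METAL but works on CPU.
for name, fn in (("f", f), ("jitted_f", jitted_f)):
    for x in inputs:
        print(f"{name}({x}) -> ...", fn(x))
print("done")
```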
Running on Metal, I get the following.
$ python identity_bug.py
WARNING:2025-01-16 17:19:57,590:jax._src.xla_bridge:1000: Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1737065997.590723 34394534 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1
systemMemory: 16.00 GB
maxCacheSize: 5.33 GB
I0000 00:00:1737065997.605677 34394534 service.cc:145] XLA service 0x128fbfd70 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1737065997.605694 34394534 service.cc:153] StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1737065997.607700 34394534 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1737065997.607742 34394534 mps_client.cc:384] XLA backend will use up to 11452858368 bytes on device 0 for SimpleAllocator.
/Users/till/git/tycho/playground/identity_bug.py:11: DeprecationWarning: jax.lib.xla_bridge.get_backend is deprecated; use jax.extend.backend.get_backend.
print(f"running on {xla_bridge.get_backend().platform} ...")
running on METAL ...
f([0 1]) -> ... [-inf 0.]
f([1. 1.]) -> ... [0. 0.]
jitted_f([0 1]) -> ... /AppleInternal/Library/BuildRoots/b11baf73-9ee0-11ef-b7b4-7aebe1f78c73/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Utility/MPSLibrary.mm:570: failed assertion `MPSKernel MTLComputePipelineStateCache unable to load function ndArrayIdentity.
Compiler encountered an internal error: (null)'
[1]    52590 abort      python identity_bug.py
Restricting to CPU gives the following.
$ JAX_PLATFORM_NAME=cpu python identity_bug.py
WARNING:2025-01-16 17:22:04,614:jax._src.xla_bridge:1000: Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1737066124.615053 34398748 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1
systemMemory: 16.00 GB
maxCacheSize: 5.33 GB
I0000 00:00:1737066124.629618 34398748 service.cc:145] XLA service 0x112537590 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1737066124.629654 34398748 service.cc:153] StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1737066124.631511 34398748 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1737066124.631542 34398748 mps_client.cc:384] XLA backend will use up to 11452858368 bytes on device 0 for SimpleAllocator.
/Users/till/git/tycho/playground/identity_bug.py:11: DeprecationWarning: jax.lib.xla_bridge.get_backend is deprecated; use jax.extend.backend.get_backend.
print(f"running on {xla_bridge.get_backend().platform} ...")
running on cpu ...
f([0 1]) -> ... [-inf 0.]
f([1. 1.]) -> ... [0. 0.]
jitted_f([0 1]) -> ... [-inf 0.]
jitted_f([1. 1.]) -> ... [0. 0.]
done
I0000 00:00:1737066124.861558 34398748 mps_client.h:209] MetalClient destroyed.
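Besides setting an environment variable before launch, a computation can be pinned to the CPU from within Python. This is a minimal sketch, assuming a JAX version that provides the `jax.default_device` context manager and `jax.Array.devices()`; that it sidesteps the Metal abort when the Metal plugin is installed is an assumption, not something shown in the logs above.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.log(x == 1)


cpu = jax.devices("cpu")[0]
# Inside this context, newly created arrays and jitted computations on
# uncommitted inputs are placed on the CPU device instead of the default one.
with jax.default_device(cpu):
    y = jax.jit(f)(jnp.array([0, 1]))

print(y, y.devices())
```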
This seems to be specific to the combination of equality testing and jnp.log, e.g., the following all work on both backends:
lambda x: 1 / (x == 1)
lambda x: jnp.exp(x == 1)
# This is equivalent to jnp.log(x == 1) because log(1 / x) = -log(x).
lambda x: -jnp.log(1 / (x == 1))
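To check that these rewrites agree with the original expression, the sketch below compares them under `jax.jit` on whatever backend is active (CPU in a default install). The function names are hypothetical. The third formulation relies on `1 / False` evaluating to `inf` in floating point, so that `-log(inf)` gives `-inf`, while `1 / True` gives `-log(1) = -0.0`, which compares equal to `0.0`.

```python
import jax
import jax.numpy as jnp


def log_mask(x):
    # The expression that failed under jit on METAL in the report.
    return jnp.log(x == 1)


def where_mask(x):
    # Same values without taking the log of a boolean.
    return jnp.where(x == 1, 0.0, -jnp.inf)


def neg_log_reciprocal(x):
    # -log(1 / b) == log(b): True -> -log(1) = -0.0, False -> -log(inf) = -inf.
    return -jnp.log(1 / (x == 1))


x = jnp.array([0.0, 1.0, 2.0])
reference = jax.jit(log_mask)(x)
for fn in (where_mask, neg_log_reciprocal):
    # -inf == -inf and -0.0 == 0.0, so exact equality holds.
    assert bool(jnp.array_equal(jax.jit(fn)(x), reference))
print("all formulations agree:", reference)
```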
System info (python version, jaxlib version, accelerator, etc.)
…`Delta` sampling site during initialization. (#1950)
* Use `jnp.where` for `Delta.log_prob` (cf. jax-ml/jax#25935).
* Raise explicit error for unobserved `Delta` sample sites during initialization.
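The first bullet describes replacing the boolean log with `jnp.where`. Here is a minimal sketch of that idea for a point mass; `delta_log_prob` is a hypothetical helper, not numpyro's `Delta.log_prob`, and it ignores event dimensions and any extra log-density term the real distribution may carry.

```python
import jax
import jax.numpy as jnp


def delta_log_prob(value, v):
    # Hypothetical helper: the log density of a point mass at `v` is 0 where
    # `value == v` and -inf elsewhere. Per the report, numpyro previously
    # computed the equivalent of jnp.log(value == v), which hit the Metal abort.
    return jnp.where(value == v, 0.0, -jnp.inf)


v = jnp.array([1.0, 2.0])
print(jax.jit(delta_log_prob)(jnp.array([1.0, 3.0]), v))
```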