This package contains a set of functions for debugging jax Arrays:
pip install git+https://github.com/Findus23/jax-array-info.git
from jax_array_info import sharding_info, sharding_vis, simple_array_info, print_array_stats
sharding_info(arr)
prints general information about a jax or numpy array, with special focus on sharding (supporting SingleDeviceSharding, GSPMDSharding, PositionalSharding, NamedSharding and PmapSharding)
some_array = jax.numpy.zeros(shape=(N, N, N), dtype=jax.numpy.float32)
some_array = jax.device_put(some_array, NamedSharding(mesh, P(None, "gpus")))
sharding_info(some_array, "some_array")
╭───────────────── some_array ─────────────────╮
│ shape: (128, 128, 128) │
│ dtype: float32 │
│ size: 8.0 MiB │
│ NamedSharding: P(None, 'gpus') │
│ axis 1 is sharded: CPU 0 contains 0:16 (1/8) │
│ Total size: 128 │
╰──────────────────────────────────────────────╯
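The numbers in the box can be checked by hand: a float32 array of shape (128, 128, 128) needs 128³ · 4 bytes = 8.0 MiB, and splitting axis 1 across 8 devices leaves each device with a slice of 128 / 8 = 16 elements, so CPU 0 holds indices 0:16. A quick sketch of that arithmetic (the device count of 8 is an assumption matching the output above):

```python
import math

shape = (128, 128, 128)
itemsize = 4  # bytes per float32 element

# Total size of the array in MiB.
n_elements = math.prod(shape)
size_mib = n_elements * itemsize / 2**20
print(f"size: {size_mib} MiB")  # size: 8.0 MiB

# Slice held by the first device when axis 1 is split over 8 devices.
n_devices = 8
per_device = shape[1] // n_devices
print(f"axis 1 is sharded: CPU 0 contains 0:{per_device} (1/{n_devices})")
```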
sharding_info() uses a jax callback to make sure it can get sharding information in as many situations as possible. But this means it is broken when used inside some types of functions (e.g. when using shard_map). simple_array_info() gives the same output, with the advantage of working everywhere (it is equivalent to a print) and the tradeoff that it is not guaranteed to always report correct sharding information (e.g. inside jitted functions).
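Conceptually, simple_array_info() can afford to work everywhere because it only reads attributes off whatever array object it is handed and prints them. A minimal stand-in (not the package's actual implementation) for numpy arrays could look like this:

```python
import numpy as np

def minimal_array_info(arr, name="array"):
    # Stripped-down stand-in for simple_array_info(): it only reads
    # attributes and prints them, so it works anywhere print() works.
    print(f"--- {name} ---")
    print(f"shape: {arr.shape}")
    print(f"dtype: {arr.dtype}")
    print(f"size: {arr.nbytes / 1024} KiB")

minimal_array_info(np.zeros((16, 16, 16), dtype=np.float32), "arr")
# last line prints: size: 16.0 KiB
```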
print_array_stats() shows a nice overview of all currently allocated arrays, ordered by their size. To save space, scalar values are grouped by dtype.
Disclaimer: This uses jax.live_arrays() to get its information, so there might be allocated arrays that are missing from this view.
arr = jax.numpy.zeros(shape=(16, 16, 16))
arr2 = jax.device_put(jax.numpy.zeros(shape=(2, 16, 4)), NamedSharding(mesh, P(None, "gpus")))
scalar = jax.numpy.array(42)
print_array_stats()
allocated jax arrays
┏━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ size ┃ shape ┃ dtype ┃ sharded ┃
┡━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ 16.0 KiB │ (16, 16, 16) │ float32 │ │
│ 64.0 B │ (2, 16, 4) │ float32 │ ✔ (512.0 B total) │
├──────────┼──────────────┼─────────┼───────────────────┤
│ 4.0 B │ 1×scalar │ int32 │ │
├──────────┼──────────────┼─────────┼───────────────────┤
│ 16.1 KiB │ │ │ │
└──────────┴──────────────┴─────────┴───────────────────┘
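The aggregation behind this table can be sketched without jax: every live array contributes (shape, dtype, nbytes), arrays with a non-empty shape become individual rows sorted by size, and scalars are folded into one row per dtype. The list below is a hypothetical stand-in for what jax.live_arrays() would report:

```python
from collections import defaultdict

# Hypothetical stand-in for jax.live_arrays(): (shape, dtype, nbytes) triples.
live = [
    ((16, 16, 16), "float32", 16384),
    ((2, 16, 4), "float32", 512),
    ((), "int32", 4),  # a scalar
]

# Non-scalar arrays: one row each, largest first.
rows = sorted(
    [(nbytes, shape, dtype) for shape, dtype, nbytes in live if shape],
    reverse=True,
)

# Scalars: grouped by dtype to save space.
scalars = defaultdict(lambda: [0, 0])  # dtype -> [count, total bytes]
for shape, dtype, nbytes in live:
    if not shape:
        scalars[dtype][0] += 1
        scalars[dtype][1] += nbytes

for nbytes, shape, dtype in rows:
    print(f"{nbytes:>6} B  {shape}  {dtype}")
for dtype, (count, nbytes) in scalars.items():
    print(f"{nbytes:>6} B  {count}×scalar  {dtype}")
print(f"total: {sum(n for *_, n in live)} B")
```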
sharding_vis() is a modified version of jax.debug.visualize_array_sharding() that also supports arrays with more than 2 dimensions (by ignoring non-sharded dimensions in the visualisation until reaching 2 dimensions)
array = jax.numpy.zeros(shape=(N, N, N), dtype=jax.numpy.float32)
array = jax.device_put(array, NamedSharding(mesh, P(None, "gpus")))
sharding_vis(array)
─────────── showing dims [0, 1] from original shape (128, 128, 128) ────────────
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
│ │ │ │ │ │ │ │ │
│ │ │ │ │ │ │ │ │
│ │ │ │ │ │ │ │ │
│ │ │ │ │ │ │ │ │
│ CPU 0 │ CPU 1 │ CPU 2 │ CPU 3 │ CPU 4 │ CPU 5 │ CPU 6 │ CPU 7 │
│ │ │ │ │ │ │ │ │
│ │ │ │ │ │ │ │ │
│ │ │ │ │ │ │ │ │
│ │ │ │ │ │ │ │ │
└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘
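The header line "showing dims [0, 1] from original shape (128, 128, 128)" follows from the rule described above: sharded dimensions are always kept, and non-sharded dimensions are dropped until at most two remain. One way that selection could work is sketched below (dropping from the last axis first is an assumption, not necessarily the package's exact rule):

```python
def dims_to_show(shape, sharded_axes, max_dims=2):
    # Keep every sharded axis; drop non-sharded axes (last first)
    # until at most max_dims dimensions remain.
    dims = list(range(len(shape)))
    for axis in reversed(range(len(shape))):
        if len(dims) <= max_dims:
            break
        if axis not in sharded_axes:
            dims.remove(axis)
    return dims

print(dims_to_show((128, 128, 128), sharded_axes={1}))  # [0, 1]
```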
You can find many examples of how arrays can be sharded in jax and what the output looks like in tests/test_jax.py. For examples of sharding arrays across multiple jax processes, check test_multihost.py and multihost.py.
While the output of print_array_stats() is very useful for seeing which arrays are currently allocated, it is missing the connection to where each array is currently used. While in many cases the shape allows guessing which array a row is referring to, it would be nice to also know its name.
Unfortunately, we only get references to the currently allocated arrays from jax.live_arrays(), which means we don't know where an array is assigned or what the variable is called.
We can use an ugly workaround to still get more useful names in the print_array_stats() output:
Whenever simple_array_info is used like simple_array_info(some_array, "some_array"), we modify the some_array jax array by setting some_array._custom_label = "some_array". This label thereby becomes part of the jax Array object, and we can use it in our array overview.
This of course has some limitations (it only works if simple_array_info is used outside of jitted code) and is an ugly hack, so it may break some things. Therefore this feature can be disabled globally using
from jax_array_info import sharding_info, print_array_stats, config as array_info_config
array_info_config.assign_labels_to_arrays = False
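The mechanism itself is plain attribute assignment on the array object. A sketch with a hypothetical stand-in class (the package assigns the attribute on a real jax Array, which is exactly why it is fragile):

```python
class FakeArray:
    # Hypothetical stand-in for a jax Array, just to show the mechanism.
    pass

def labelled_info(arr, name=None, assign_labels=True):
    # Mirrors the hack: remember the user-supplied name on the object itself.
    if name is not None and assign_labels:
        arr._custom_label = name

arr = FakeArray()
labelled_info(arr, "some_array")
print(getattr(arr, "_custom_label", "<unlabelled>"))  # some_array
```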
import jax
from jax_array_info import sharding_info, print_array_stats
some_array = jax.numpy.zeros(shape=(128, 128, 128))
sharding_info(some_array, "some_array")
print_array_stats()
╭────── some_array ──────╮
│ shape: (128, 128, 128) │
│ dtype: float32 │
│ size: 8.0 MiB │
│ not sharded │
╰────────────────────────╯
allocated jax arrays
┏━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━┓
┃ size ┃ shape ┃ dtype ┃ sharded ┃ label ┃
┡━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━┩
│ 8.0 MiB │ (128, 128, 128) │ float32 │ │ some_array │
├─────────┼─────────────────┼─────────┼─────────┼────────────┤
│ 8.0 MiB │ │ │ │ │
└─────────┴─────────────────┴─────────┴─────────┴────────────┘