Debugging tool to print information (especially sharding) about jax arrays

Findus23/jax-array-info

jax-array-info

This package contains four functions for debugging jax Arrays:

pip install git+https://github.com/Findus23/jax-array-info.git
from jax_array_info import sharding_info, sharding_vis, simple_array_info, print_array_stats

sharding_info(arr)

sharding_info(arr) prints general information about a jax or numpy array, with a special focus on sharding (supporting SingleDeviceSharding, GSPMDSharding, PositionalSharding, NamedSharding and PmapSharding).

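import jax
from jax.sharding import NamedSharding, PartitionSpec as P
# N = 128 in this run; `mesh` is a jax.sharding.Mesh with one axis named "gpus" spanning 8 devices
some_array = jax.numpy.zeros(shape=(N, N, N), dtype=jax.numpy.float32)
some_array = jax.device_put(some_array, NamedSharding(mesh, P(None, "gpus")))
sharding_info(some_array, "some_array")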
╭───────────────── some_array ─────────────────╮
│ shape: (128, 128, 128)                       │
│ dtype: float32                               │
│ size: 8.0 MiB                                │
│ NamedSharding: P(None, 'gpus')               │
│ axis 1 is sharded: CPU 0 contains 0:16 (1/8) │
│                    Total size: 128           │
╰──────────────────────────────────────────────╯
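
The numbers in the box above can be checked by hand. The following is a minimal sketch in plain Python (no jax needed) that assumes the axis is split evenly across the 8 devices; `shard_slice` is a hypothetical helper written for this illustration and is not part of jax-array-info.

```python
import math


def shard_slice(axis_size: int, num_devices: int, device_index: int) -> tuple[int, int]:
    """Return the half-open (start, stop) range one device holds on an evenly sharded axis."""
    if axis_size % num_devices != 0:
        raise ValueError("this sketch only handles evenly divisible axes")
    chunk = axis_size // num_devices
    return device_index * chunk, (device_index + 1) * chunk


shape = (128, 128, 128)
itemsize = 4  # bytes per float32 element

# Total size of the array in MiB (1 MiB = 2**20 bytes).
size_mib = math.prod(shape) * itemsize / 2**20

# Range of axis 1 held by the first of 8 devices.
first_shard = shard_slice(shape[1], num_devices=8, device_index=0)

print(f"size: {size_mib} MiB, device 0 holds {first_shard[0]}:{first_shard[1]}")
```

This reproduces the "8.0 MiB" and "0:16 (1/8)" values reported for some_array.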

simple_array_info()

sharding_info() uses a jax callback to make sure it can get sharding information in as many situations as possible. However, this means it breaks when used inside some kinds of functions (e.g. when using shard_map). simple_array_info() gives the same output, with the advantage of working everywhere (it is equivalent to a print) and the tradeoff that it is not guaranteed to always report correct sharding information (e.g. inside jitted functions).

print_array_stats()

Shows an overview of all currently allocated arrays, ordered by size. To save space, scalar values are grouped by dtype.

Disclaimer: This uses jax.live_arrays() to get its information. There might be allocated arrays that are missing from this view.

arr = jax.numpy.zeros(shape=(16, 16, 16))
arr2 = jax.device_put(jax.numpy.zeros(shape=(2, 16, 4)), NamedSharding(mesh, P(None, "gpus")))
scalar = jax.numpy.array(42)

print_array_stats()
             allocated jax arrays              
┏━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ size     ┃ shape        ┃ dtype   ┃      sharded      ┃
┡━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ 16.0 KiB │ (16, 16, 16) │ float32 │                   │
│ 64.0 B   │ (2, 16, 4)   │ float32 │ ✔ (512.0 B total) │
├──────────┼──────────────┼─────────┼───────────────────┤
│ 4.0 B    │ 1×s          │ int32   │                   │
├──────────┼──────────────┼─────────┼───────────────────┤
│ 16.1 KiB │              │         │                   │
└──────────┴──────────────┴─────────┴───────────────────┘
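
The sizes in this table can be reproduced with a little arithmetic. The sketch below is plain Python and only mirrors the numbers shown; `format_bytes` is a hypothetical helper, and the assumption that the last row sums the per-device size of sharded arrays is inferred from the example output (16.0 KiB + 64 B + 4 B rounds to 16.1 KiB), not taken from the library's source.

```python
import math


def format_bytes(num_bytes: float) -> str:
    """Format a byte count with binary (1024-based) units and one decimal place."""
    for unit in ("B", "KiB", "MiB", "GiB"):
        if num_bytes < 1024 or unit == "GiB":
            return f"{num_bytes:.1f} {unit}"
        num_bytes /= 1024


FLOAT32_BYTES = 4
INT32_BYTES = 4
NUM_DEVICES = 8  # assumption: the 8-device mesh used in the other examples

arr_bytes = math.prod((16, 16, 16)) * FLOAT32_BYTES   # unsharded array
arr2_total = math.prod((2, 16, 4)) * FLOAT32_BYTES    # sharded array, all devices together
arr2_per_device = arr2_total // NUM_DEVICES           # share held by a single device
scalar_bytes = INT32_BYTES                            # the single int32 scalar

# Assumed rule for the summary row: add up what one device holds.
total = arr_bytes + arr2_per_device + scalar_bytes

for label, value in [("arr", arr_bytes), ("arr2 (per device)", arr2_per_device),
                     ("arr2 (total)", arr2_total), ("scalar", scalar_bytes), ("sum", total)]:
    print(f"{label}: {format_bytes(value)}")
```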

sharding_vis(arr)

A modified version of jax.debug.visualize_array_sharding() that also supports arrays with more than 2 dimensions (by ignoring non-sharded dimensions in the visualisation until only 2 dimensions remain).

array = jax.numpy.zeros(shape=(N, N, N), dtype=jax.numpy.float32)
array = jax.device_put(array, NamedSharding(mesh, P(None, "gpus")))
sharding_vis(array)
─────────── showing dims [0, 1] from original shape (128, 128, 128) ────────────
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
│ CPU 0 │ CPU 1 │ CPU 2 │ CPU 3 │ CPU 4 │ CPU 5 │ CPU 6 │ CPU 7 │
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘
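
To illustrate the idea of ignoring non-sharded dimensions, here is a minimal sketch in plain Python. `dims_to_show` is a hypothetical function, and the order in which dimensions are dropped (last unsharded dimension first) is an assumption chosen only because it reproduces the "showing dims [0, 1]" line above; the actual implementation in sharding_vis may pick dimensions differently.

```python
def dims_to_show(sharded: list[bool], max_dims: int = 2) -> list[int]:
    """Pick which dimension indices to keep for a 2D visualisation.

    `sharded[i]` says whether dimension i is split across devices.
    """
    keep = list(range(len(sharded)))
    # Assumption: walk from the last dimension backwards and drop
    # unsharded dimensions until at most `max_dims` are left.
    for dim in reversed(range(len(sharded))):
        if len(keep) <= max_dims:
            break
        if not sharded[dim]:
            keep.remove(dim)
    return keep


# Shape (128, 128, 128) with P(None, "gpus"): only axis 1 is sharded.
print(dims_to_show([False, True, False]))
```

If more than two dimensions are sharded, this sketch returns more than two indices, since no sharded dimension is ever dropped.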

Examples

You can find many examples of how arrays can be sharded in jax, and what the resulting output looks like, in tests/test_jax.py. For examples of sharding arrays across multiple jax processes, check test_multihost.py and multihost.py.

Assigning Labels to Jax Arrays

While the output of print_array_stats() is useful for seeing which arrays are currently allocated, it does not show where each array is used. In many cases the shape lets you guess which array a row refers to, but it would be nice to also know its name. Unfortunately, jax.live_arrays() only gives us references to the currently allocated arrays, so we don't know where an array is assigned or what its variable is called. An ugly workaround still gets more useful names into the print_array_stats() output: whenever simple_array_info is called like simple_array_info(some_array, "some_array"), it modifies the some_array jax array by setting some_array._custom_label = "some_array". The label thereby becomes part of the jax Array object, and the array overview can display it.

This of course has some limitations (it only works if simple_array_info is used outside of jitted code) and is an ugly hack, so it may break things. Therefore, this feature can be disabled globally using:

from jax_array_info import sharding_info, print_array_stats, config as array_info_config

array_info_config.assign_labels_to_arrays = False

Example

import jax
from jax_array_info import sharding_info, print_array_stats

some_array = jax.numpy.zeros(shape=(128, 128, 128))
sharding_info(some_array, "some_array")

print_array_stats()
╭────── some_array ──────╮
│ shape: (128, 128, 128) │
│ dtype: float32         │
│ size: 8.0 MiB          │
│ not sharded            │
╰────────────────────────╯
                     allocated jax arrays                     
┏━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━┓
┃ size    ┃ shape           ┃ dtype   ┃ sharded ┃ label      ┃
┡━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━┩
│ 8.0 MiB │ (128, 128, 128) │ float32 │         │ some_array │
├─────────┼─────────────────┼─────────┼─────────┼────────────┤
│ 8.0 MiB │                 │         │         │            │
└─────────┴─────────────────┴─────────┴─────────┴────────────┘
