diff --git a/.github/container/Dockerfile.base b/.github/container/Dockerfile.base index 50fda91a2..e1d2abc1f 100644 --- a/.github/container/Dockerfile.base +++ b/.github/container/Dockerfile.base @@ -3,6 +3,7 @@ ARG BASE_IMAGE=nvidia/cuda:12.6.2-devel-ubuntu22.04 ARG GIT_USER_NAME="JAX Toolbox" ARG GIT_USER_EMAIL=jax@nvidia.com ARG CLANG_VERSION=18 +ARG JAX_TOOLBOX_REF ############################################################################### ## Obtain GCP's NCCL TCPx plugin @@ -30,6 +31,7 @@ ARG BASE_IMAGE ARG GIT_USER_EMAIL ARG GIT_USER_NAME ARG CLANG_VERSION +ARG JAX_TOOLBOX_REF ENV CUDA_BASE_IMAGE=${BASE_IMAGE} ############################################################################### @@ -110,7 +112,7 @@ RUN <<"EOF" bash -ex git config --global user.name "${GIT_USER_NAME}" git config --global user.email "${GIT_USER_EMAIL}" EOF -RUN mkdir -p /opt/pip-tools.d +RUN mkdir -p /opt/pip-tools.d /opt/pip-tools-post-install.d ADD --chmod=777 \ git-clone.sh \ pip-finalize.sh \ @@ -141,7 +143,6 @@ COPY --from=tcpx-installer /var/lib/tcpx/lib64 ${TCPX_LIBRARY_PATH} ############################################################################### ADD install-nsight.sh /usr/local/bin -ADD nsys-2024.5-tid-export.patch /opt/nvidia RUN install-nsight.sh ############################################################################### @@ -183,7 +184,7 @@ ENV PATH=/opt/amazon/efa/bin:${PATH} ADD install-nccl-sanity-check.sh /usr/local/bin ADD nccl-sanity-check.cu /opt RUN install-nccl-sanity-check.sh -ADD jax-nccl-test parallel-launch /usr/local/bin +ADD jax-nccl-test parallel-launch /usr/local/bin/ ############################################################################### ## Add the systemcheck to the entrypoint. @@ -199,23 +200,11 @@ COPY check-shm.sh /opt/nvidia/entrypoint.d/ # COPY gcp-autoconfig.sh /opt/nvidia/entrypoint.d/ ############################################################################### -## Add helper scripts for profiling with Nsight Systems -## -## The scripts saved to /opt/jax_nsys are embedded in the output archives -## written by nsys-jax, while the nsys-jax and nsys-jax-combine scripts are -## only used inside the containers. -############################################################################### -ADD nsys-jax nsys-jax-combine /usr/local/bin/ -ADD jax_nsys/ /opt/jax_nsys -# The jax_nsys package should be installed inside the containers, so nsys-jax -# can eagerly execute analysis recipes (--nsys-jax-analysis) in the container -# environment, without an extra layer of virtual environment indirection. -RUN echo "-e /opt/jax_nsys/python/jax_nsys" > /opt/pip-tools.d/requirements-nsys-jax.in -# This should be embedded in output archives and be runnable inside containers -RUN ln -s /opt/jax_nsys/install-protoc /usr/local/bin/ -# Should be available for execution inside the containers, should not be -# embedded in the output archives. -ADD jax_nsys_tests/ /opt/jax_nsys_tests +## Install the nsys-jax JAX/XLA-aware profiling scripts, patch Nsight Systems +############################################################################### + +ADD install-nsys-jax.sh /usr/local/bin +RUN install-nsys-jax.sh ${JAX_TOOLBOX_REF} ############################################################################### ## Copy manifest file to the container diff --git a/.github/container/install-nsight.sh b/.github/container/install-nsight.sh index 4aa001cf1..dc0ef92cb 100755 --- a/.github/container/install-nsight.sh +++ b/.github/container/install-nsight.sh @@ -16,14 +16,3 @@ apt-get install -y nsight-compute nsight-systems-cli-2024.6.1 apt-get clean rm -rf /var/lib/apt/lists/* - -for NSYS in /opt/nvidia/nsight-systems-cli/2024.5.1 /opt/nvidia/nsight-systems-cli/2024.6.1; do - if [[ -d "${NSYS}" ]]; then - # * can match at least sbsa-armv8 and x86 - (cd ${NSYS}/target-linux-*/python/packages && git apply < /opt/nvidia/nsys-2024.5-tid-export.patch) - fi -done - -# Install extra dependencies needed for `nsys recipe ...` commands. These are -# used by the nsys-jax wrapper script. -ln -s $(dirname $(realpath $(command -v nsys)))/python/packages/nsys_recipe/requirements/common.txt /opt/pip-tools.d/requirements-nsys-recipe.in diff --git a/.github/container/install-nsys-jax.sh b/.github/container/install-nsys-jax.sh new file mode 100755 index 000000000..37bef8728 --- /dev/null +++ b/.github/container/install-nsys-jax.sh @@ -0,0 +1,32 @@ +#!/bin/bash +set -exo pipefail + +REF="$1" +if [[ -z "${REF}" ]]; then + echo "$0: " + exit 1 +fi + +# Install extra dependencies needed for `nsys recipe ...` commands. These are +# used by the nsys-jax wrapper script. +NSYS_DIR=$(dirname $(realpath $(command -v nsys))) +ln -s ${NSYS_DIR}/python/packages/nsys_recipe/requirements/common.txt /opt/pip-tools.d/requirements-nsys-recipe.in + +# Install the nsys-jax package, which includes nsys-jax, nsys-jax-combine, +# install-protoc (called from pip-finalize.sh), and nsys-jax-patch-nsys as well as the +# nsys_jax Python library. +URL="git+https://github.com/NVIDIA/JAX-Toolbox.git@${REF}#subdirectory=.github/container/nsys_jax&egg=nsys-jax" +echo "-e '${URL}'" > /opt/pip-tools.d/requirements-nsys-jax.in + +# protobuf will be installed at least as a dependency of nsys_jax in the base +# image, but the installed version is likely to be influenced by other packages. +echo "install-protoc /usr/local" > /opt/pip-tools-post-install.d/protoc +chmod 755 /opt/pip-tools-post-install.d/protoc + +# Make sure flamegraph.pl is available +echo "install-flamegraph /usr/local" > /opt/pip-tools-post-install.d/flamegraph +chmod 755 /opt/pip-tools-post-install.d/flamegraph + +# Make sure Nsight Systems Python patches are installed if needed +echo "nsys-jax-patch-nsys" > /opt/pip-tools-post-install.d/patch-nsys +chmod 755 /opt/pip-tools-post-install.d/patch-nsys diff --git a/.github/container/jax_nsys/install-protoc b/.github/container/jax_nsys/install-protoc deleted file mode 100755 index 71c8882be..000000000 --- a/.github/container/jax_nsys/install-protoc +++ /dev/null @@ -1,65 +0,0 @@ -#!/usr/bin/env python -import argparse -import google.protobuf -import io -import os -import platform -import requests -import zipfile - -parser = argparse.ArgumentParser( - "Install a version of the protoc compiler that is compatible with the google.protobuf runtime" -) -parser.add_argument( - "prefix", help="Output prefix under which to install protoc", type=str -) -args = parser.parse_args() - -s = requests.Session() -s.mount("https://", requests.adapters.HTTPAdapter(max_retries=5)) - -# protobuf versioning is complicated, see protocolbuffers/protobuf#11123 for more -# discussion. For older versions, when the versioning scheme was aligned, try and -# install a protoc with the same version as google.protobuf. For newer versions, given -# google.protobuf version X.Y.Z install protoc version Y.Z as described in -# https://protobuf.dev/support/version-support -runtime_version = tuple(map(int, google.protobuf.__version__.split("."))) -if runtime_version < (3, 21): - # old versioning scheme, try and install a matching protoc version - protoc_version = runtime_version -else: - # new versioning scheme, runtime minor.patch should be the protoc version - protoc_version = runtime_version[1:] - -# Install the given protobuf version -ver = ".".join(map(str, protoc_version)) -system = platform.system().lower() -machine = platform.machine() -system = {"darwin": "osx"}.get(system, system) -machine = { - "aarch64": "aarch_64", - "arm64": "aarch_64", -}.get(machine, machine) -# Apple Silicon can handle universal and x86_64 if it needs to. -machines = { - ("osx", "aarch_64"): ["aarch_64", "universal_binary", "x86_64"], -}.get((system, machine), [machine]) -for machine in machines: - r = s.get( - f"https://github.com/protocolbuffers/protobuf/releases/download/v{ver}/protoc-{ver}-{system}-{machine}.zip" - ) - if r.status_code == 404: - # assume this means the architecture is not available - continue -else: - r.raise_for_status() - -with zipfile.ZipFile(io.BytesIO(r.content)) as z: - for name in z.namelist(): - if ".." in name: - continue - if name.startswith("bin/") or name.startswith("include/"): - z.extract(name, path=args.prefix) - -# Make sure the protoc binary is executable -os.chmod(os.path.join(args.prefix, "bin", "protoc"), 0o755) diff --git a/.github/container/jax_nsys/install.sh b/.github/container/jax_nsys/install.sh deleted file mode 100755 index dcb7c0aa0..000000000 --- a/.github/container/jax_nsys/install.sh +++ /dev/null @@ -1,38 +0,0 @@ -#!/bin/bash -# -# Usage: ./install.sh [optional arguments to virtualenv] -# -# If it doesn't already exist, this creates a virtual environment named -# `nsys_jax_env` in the current directory and installs Jupyter Lab and the -# dependencies of the Analysis.ipynb notebook that is shipped alongside this -# script inside the output archives of the `nsys-jax` wrapper. -# -# The expectation is that those archives will be copied and extracted on a -# laptop or workstation, and this installation script will be run there, while -# the `nsys-jax` wrapper is executed on a remote GPU cluster. -set -ex -SCRIPT_DIR=$(cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) -VIRTUALENV="${SCRIPT_DIR}/nsys_jax_venv" -if [[ ! -d "${VIRTUALENV}" ]]; then - # Let `virtualenv` find/choose a Python. Currently >=3.10 is supported. - virtualenv -p 3.12 -p 3.11 -p 3.10 "$@" "${VIRTUALENV}" - . "${VIRTUALENV}/bin/activate" - python -m pip install -U pip - if ! python -c "import google.protobuf" > /dev/null 2>&1 || ! command -v protoc > /dev/null; then - python -m pip install protobuf requests - "${SCRIPT_DIR}/install-protoc" "${VIRTUALENV}" - fi - # matplotlib is a dependency of Analysis.ipynb but not jax_nsys - python -m pip install jupyterlab matplotlib - python -m pip install -e "${SCRIPT_DIR}/python/jax_nsys" - curl -o "${VIRTUALENV}/bin/flamegraph.pl" https://raw.githubusercontent.com/brendangregg/FlameGraph/master/flamegraph.pl - chmod 755 "${VIRTUALENV}/bin/flamegraph.pl" -else - echo "Virtual environment already exists, not installing anything..." -fi -if [ -z ${NSYS_JAX_INSTALL_SKIP_LAUNCH+x} ]; then - echo "Launching: cd ${SCRIPT_DIR} && ${VIRTUALENV}/bin/python -m jupyterlab Analysis.ipynb" - cd "${SCRIPT_DIR}" && "${VIRTUALENV}/bin/python" -m jupyterlab Analysis.ipynb -else - echo "Skipping launch of jupyterlab due to NSYS_JAX_INSTALL_SKIP_LAUNCH" -fi diff --git a/.github/container/jax_nsys/python/jax_nsys/pyproject.toml b/.github/container/jax_nsys/python/jax_nsys/pyproject.toml deleted file mode 100644 index e55a7867d..000000000 --- a/.github/container/jax_nsys/python/jax_nsys/pyproject.toml +++ /dev/null @@ -1,17 +0,0 @@ -[project] -name = "jax-nsys" -dynamic = ["version"] -dependencies = [ - "ipython", - "numpy", - "pandas", - "protobuf", # a compatible version of protoc needs to be installed out-of-band - "pyarrow", - "requests", # for install-protoc - "uncertainties", # communication analysis recipe -] -requires-python = ">= 3.10" -[project.optional-dependencies] -test = [ - "pytest" -] diff --git a/.github/container/nsys-jax b/.github/container/nsys-jax deleted file mode 100755 index d05c308ad..000000000 --- a/.github/container/nsys-jax +++ /dev/null @@ -1,815 +0,0 @@ -#!/usr/bin/env python -import argparse -from concurrent.futures import FIRST_EXCEPTION, ThreadPoolExecutor, wait -from contextlib import contextmanager -from glob import glob, iglob -import lzma -import os -import os.path as osp -import pandas as pd # type: ignore -import queue -import re -import shlex -import shutil -import sqlite3 -import subprocess -import sys -import tempfile -import time -import traceback -import zipfile - - -# Expand %q{ENV_VAR} if the variable is defined. -def expand(string: str, skip_missing=True) -> str: - missing = set() - - def rep(x): - if len(x.group(1)) % 2 == 0: - return x.group(0) - if x.group(2) not in os.environ: - missing.add(x.group(2)) - return x.group(0) - return x.group(1)[:-1] + os.environ[x.group(2)] - - expanded = re.sub(r"([%]+)q\{(.*?)\}", rep, string).replace("%%", "%") - if not skip_missing and missing: - raise Exception(f"{missing} not defined when expanding '{string}'") - return expanded - - -# Wrapper-specific arguments. This also handles -h and --help. -parser = argparse.ArgumentParser( - allow_abbrev=False, - usage=( - "nsys-jax [-h] [--nsys-jax-condition EXPRESSION] [--nsys-jax-analysis A1 " - "[--nsys-jax-analysis-arg=A1_ARG1 [--nsys-jax-analysis-arg=A1_ARG2 ...]] " - "[--nsys-jax-analysis A2 [--nsys-jax-analysis-arg=A2_ARG1 ...]] [-o OUTPUT | " - "--output OUTPUT] [-f | --force-overwrite] [nsys profile arguments ...] [--] " - "executable [executable arguments ...]" - ), - description=( - "`nsys-jax` is a wrapper for `nsys profile` that collects additional metadata " - "that are specific to JAX and XLA, post-processes the profile data, and " - "produces a compressed .zip archive containing the relevant files." - ), - epilog=( - "NOTE: if the executable arguments include a literal `--` then the optional " - "`--` shown in the usage message MUST be passed to disambiguate. This is also " - "required when extra nsys profile arguments are passed." - ), -) -parser.add_argument( - "--nsys-jax-analysis", - action="append", - dest="analysis", - help=( - "Post-processing analysis script to execute after report collection. This can " - "be the name of a bundled recipe, or the path to a Python script. The script " - "will be passed any arguments specified via --nsys-jax-analysis-arg, followed " - "by a single positional argument, which is the path to a directory of the " - "same structure as the extracted output archive." - ), - type=lambda x: ("script", x), -) -parser.add_argument( - "--nsys-jax-analysis-arg", - action="append", - dest="analysis", - help="Extra arguments to pass to analysis scripts specified via --nsys-jax-analysis", - type=lambda x: ("arg", x), -) -parser.add_argument( - "--nsys-jax-condition", - help=( - "Bash expression that will be expanded to determine if this instance " - "of nsys-jax should actually launch nsys. Example: " - "--nsys-jax-condition='$SLURM_LOCALID == 0' to only profile the first " - "process on each node. The expression is evaluated inside [[ ... ]]." - ), -) -parser.add_argument( - "-f", - "--force-overwrite", - action="store_true", - help="This must be passed for nsys-jax to overwrite an existing output archive.", -) -parser.add_argument( - "-o", - "--output", - help=( - "Output filename, if this has an .nsys-rep or .zip suffix it will be removed " - "to yield ROOT, and the output archive will be ROOT.zip, which will contain a " - "ROOT.nsys-rep." - ), -) - - -def shuffle_analysis_arg(analysis): - if analysis is None: - return [] - # [Script(A), Arg(A1), Arg(A2), Script(B), Arg(B1)] becomes [[A, A1, A2], [B, B1]] - out, current = [], [] - for t, x in analysis: - if t == "script": - if len(current): - out.append(current) - current = [x] - else: - assert t == "arg" and len(current) - current.append(x) - if len(current): - out.append(current) - return out - - -nsys_jax_flags, unknown_args = parser.parse_known_args(sys.argv) -nsys_jax_flags.analysis = shuffle_analysis_arg(nsys_jax_flags.analysis) -# Remove the name of the nsys-jax wrapper -nsys_flags_and_cmd = unknown_args[1:] -# This can have two forms: -# exe [exe args ...] -# [nsys args ...] -- exe [exe args ...] -# where the second one must be used if `exe args` contains `--`, even if no nsys args -# are passed. -try: - limit = nsys_flags_and_cmd.index("--") - nsys_flags = nsys_flags_and_cmd[:limit] - application = nsys_flags_and_cmd[limit + 1 :] -except ValueError: - # No --, everything is the application - nsys_flags = [] - application = nsys_flags_and_cmd - -if len(application) == 0: - parser.print_help() - raise Exception("No application to profile") - -if shutil.which(application[0]) is None: - parser.print_help() - raise Exception(f"{application[0]} not found by shutil.which") - -enable_profiling = True -if nsys_jax_flags.nsys_jax_condition is not None: - enable_profiling = ( - subprocess.run( - ["/bin/bash", "-c", f"[[ {nsys_jax_flags.nsys_jax_condition} ]]"], - shell=False, - ).returncode - == 0 - ) - -if nsys_jax_flags.output is None: - # There was not an explicit output location; generate one. There may be - # multiple processes racing to do this. - archive_handle, archive_name = tempfile.mkstemp( - dir=os.getcwd(), prefix="nsys-jax-report-", suffix=".zip" - ) - # Re-open it based on name later, mkstemp is just a way of avoiding races - os.close(archive_handle) - # No -f / --force-overwrite needed in this case - archive_name_can_be_overwritten = True -else: - # Explicit output location was given in `nsys_jax_flags.output`, transform that - # into the .zip-suffixed verison of it. - archive_name = ( - expand(nsys_jax_flags.output.removesuffix(".nsys-rep").removesuffix(".zip")) - + ".zip" - ) - if not osp.isabs(archive_name): - nsys_output = osp.join(os.getcwd(), archive_name) - archive_name_can_be_overwritten = nsys_jax_flags.force_overwrite - -# We will write /final/output/path/name.zip, and it will contain name.nsys-rep, -# but we do not instruct nsys to write that to /final/output/path/name.nsys-rep -# so that more of the processing can happen on a faster, more local filesystem. -report_name = osp.basename(archive_name).removesuffix(".zip") + ".nsys-rep" -tmp_dir = tempfile.mkdtemp() -tmp_rep = osp.join(tmp_dir, report_name) -nsys_flags += ["--output", tmp_rep] - -# If --nsys-jax-analysis is used, we also construct a local directory mirroring -# the extracted archive structure. TODO: clean this up -mirror_dir = None if len(nsys_jax_flags.analysis) == 0 else tempfile.mkdtemp() - - -def override_nsys_default(arg, value): - if any(x.startswith(f"--{arg}=") for x in nsys_flags): - return - nsys_flags.append(f"--{arg}={value}") - - -# Override some Nsight Systems defaults, but don't block setting them explicitly. -override_nsys_default("cuda-graph-trace", "node") -override_nsys_default("cpuctxsw", "none") -override_nsys_default("python-sampling", "true") -# TODO: consider dropping osrt from here -override_nsys_default("trace", "cublas,cuda,cudnn,cusolver,nvtx,osrt") - -# Modified environment in which to run the application -env = os.environ.copy() - -# Stop stack traces from being truncated in the metadata passed to XLA unless -# the option was explicitly set. -if "JAX_TRACEBACK_IN_LOCATIONS_LIMIT" not in env: - env["JAX_TRACEBACK_IN_LOCATIONS_LIMIT"] = "-1" - -# Disable the compilation cache so that we get the full set of .pb files -if "JAX_ENABLE_COMPILATION_CACHE" not in env: - env["JAX_ENABLE_COMPILATION_CACHE"] = "false" - -# Get the existing XLA_FLAGS and parse them into a dictionary. -xla_flag_list = shlex.split(env.get("XLA_FLAGS", "")) -xla_flags = {} -for flag in xla_flag_list: - assert flag.startswith("--") - bits = flag[2:].split("=", maxsplit=1) - name, value = bits[0], bits[1] if len(bits) > 1 else None - assert name not in xla_flags - xla_flags[name] = value - - -def as_list(flags): - return [f"--{n}" if v is None else f"--{n}={v}" for n, v in flags.items()] - - -assert xla_flag_list == as_list(xla_flags) - - -def as_bool(s): - """String -> bool conversion following XLA's semantics.""" - if s.lower() == "true" or s == "1": - return True - if s.lower() == "false" or s == "0": - return False - raise Exception("Could not convert '{}' to bool".format(s)) - - -# Enable dumping protobufs unless it was explicitly disabled -if "xla_dump_hlo_as_proto" not in xla_flags: - xla_flags["xla_dump_hlo_as_proto"] = "true" - -proto_dump_enabled = as_bool(xla_flags["xla_dump_hlo_as_proto"]) - -# For simplicity, impose our directory structure on the dump from XLA -if proto_dump_enabled: - if "xla_dump_to" in xla_flags: - print(f"WARNING: --xla_dump_to={xla_flags['xla_dump_to']} being overriden") - xla_flags["xla_dump_to"] = osp.join(tmp_dir, "dump") -else: - print("WARNING: protobuf dump explicitly disabled, things will break") - -# Serialise the modified XLA flags. shlex.join is tempting, but doesn't seem to -# get the right result for --xla_dump_hlo_pass_re=.*, as it adds extra quotes. -env["XLA_FLAGS"] = " ".join(as_list(xla_flags)) - -# Run the application in nsys -# TODO: consider being more fault-tolerant? -# The Nsight Systems command prefix -nsys = [ - "nsys", - "profile", -] + nsys_flags -subprocess.run((nsys if enable_profiling else []) + application, check=True, env=env) - -# If we skipped profiling the application, there is nothing more to be done. -if not enable_profiling: - sys.exit(0) - -# Check the output report was written and is new -if not osp.exists(tmp_rep): - raise Exception(f"Could not find output file: {tmp_rep}") - - -# Use deflate compression -compress_deflate = {"compress_type": zipfile.ZIP_DEFLATED} -# Do not compress (if the file is already compressed) -compress_none: dict[str, int] = {} - - -def copy_proto_files_to_tmp(tmp_dir, xla_dir="/opt/xla"): - """ - Copy .proto files from XLA into a temporary directory under `tmp_dir`. - - TODO: install .proto files as part of `jaxlib`, so this can work without - the XLA sources being available under `xla_dir` e.g. as part of a - generic `pip` installation of JAX. - - Returns: (name of temporary directory, list of relative .proto paths) - """ - start = time.time() - proto_dir = osp.join(tmp_dir, "protos") - tsl_dir = osp.join(xla_dir, "third_party", "tsl") - proto_files = [] - for p, root in [("tsl/**/*.proto", tsl_dir), ("xla/**/*.proto", xla_dir)]: - for proto in iglob(p, recursive=True, root_dir=root): - proto_files.append(proto) - dst_dir = osp.join(proto_dir, osp.dirname(proto)) - if not osp.isdir(dst_dir): - os.makedirs(dst_dir) - shutil.copy(osp.join(root, proto), osp.join(proto_dir, proto)) - print(f"{archive_name}: gathered .proto files in {time.time()-start:.2f}s") - return proto_dir, proto_files - - -def run_nsys_recipe(recipe, report_file, tmp_dir, output_queue): - """ - Post-process a .nsys-rep file into a .parquet file for offline analysis. - This is currently implemented using the given nsys recipe. - """ - start = time.time() - recipe_output = osp.join(tmp_dir, recipe) - subprocess.run( - [ - "nsys", - "recipe", - recipe, - "--input", - report_file, - "--output", - recipe_output, - ], - check=True, - ) - for ofile in iglob(recipe + "/**", recursive=True, root_dir=tmp_dir): - full_path = osp.join(tmp_dir, ofile) - # glob("/does-not-exist/**", recursive=True) == ['/does-not-exist/'] - if osp.isdir(full_path) or not osp.exists(full_path): - continue - output_queue.put((ofile, full_path, compress_none)) - print(f"{archive_name}: post-processing finished in {time.time()-start:.2f}s") - - -def compress_and_archive(prefix, file, output_queue): - """ - Read prefix+file, compress it, queue the compressed bytes for archival - without further compression. - """ - with open(osp.join(prefix, file), "rb") as ifile: - output_queue.put((file + ".xz", lzma.compress(ifile.read()), compress_none)) - - -def run_nsys_stats_report(report, report_file, tmp_dir, output_queue): - """ - Run a stats recipe on an .nsys-rep file (that has probably already been - exported to .sqlite). - """ - start = time.time() - subprocess.run( - [ - "nsys", - "stats", - "--report", - report, - "--input", - report_file, - # avoid race conditions with other reports/etc. - "--sqlite", - osp.splitext(report_file)[0] + "-" + report + ".sqlite", - "--output", - osp.join(tmp_dir, "report"), - ], - check=True, - ) - for ofile in iglob("report_" + report + ".csv", root_dir=tmp_dir): - compress_and_archive(tmp_dir, ofile, output_queue) - print(f"{archive_name}: post-processing finished in {time.time()-start:.2f}s") - - -def save_device_stream_thread_names(tmp_dir, report, output_queue): - """ - Extract extra information from the SQLite dump that is needed to map projected NVTX - ranges to global device IDs. - """ - start = time.time() - assert report.endswith(".nsys-rep"), f"{report} had an unexpected suffix" - db_file = report.removesuffix(".nsys-rep") + "-metadata.sqlite" - subprocess.run( - [ - "nsys", - "export", - "--type", - "sqlite", - "--tables", - "StringIds,TARGET_INFO_GPU,TARGET_INFO_NVTX_CUDA_DEVICE,TARGET_INFO_SYSTEM_ENV,ThreadNames", - "--output", - db_file, - report, - ], - check=True, - ) - assert os.path.exists(db_file) - con = sqlite3.connect(db_file) - cur = con.cursor() - - def table_to_parquet(query, index, filename, columns=None, index_name=None): - res = cur.execute(query) - if columns is None: - columns = [x[0] for x in res.description] - df = pd.DataFrame(res, columns=columns).set_index(index, verify_integrity=True) - if index_name is not None: - df.index.name = index_name - df.to_parquet(osp.join(tmp_dir, filename)) - output_queue.put((filename, osp.join(tmp_dir, filename), compress_none)) - - # Extract {(pid, tid): (name, priority)} map; PID/TID arithmetic comes from - # https://docs.nvidia.com/nsight-systems/UserGuide/index.html#common-sqlite-examples - table_to_parquet( - r""" - SELECT - StringIds.value AS Name, - ThreadNames.priority AS Priority, - ThreadNames.globalTid / 0x1000000 % 0x1000000 AS PID, - ThreadNames.globalTid % 0x1000000 AS TID - FROM ThreadNames - INNER JOIN StringIds ON ThreadNames.nameId=StringIds.id""", - ["PID", "TID"], - "thread-metadata.parquet", - ) - # Extract high level metadata about the profiling session, including the hostname - table_to_parquet( - "SELECT name, nameEnum, value FROM TARGET_INFO_SYSTEM_ENV", - "nameEnum", - "system-metadata.parquet", - ) - - def table_exists(table_name): - return ( - cur.execute( - f"SELECT 1 FROM sqlite_master WHERE type='table' AND name='{table_name}'" - ).fetchall() - != [] - ) - - # Cannot write device-metadata.parquet if no device activity was profiled. - if table_exists("TARGET_INFO_GPU") and table_exists("TARGET_INFO_NVTX_CUDA_DEVICE"): - # Extract {device_id: metadata_and_name} map, making sure to pick up the name that - # XLA assigns via NVTX - def table_columns(table_name): - return [ - (table_name, x[0]) - for x in cur.execute(f"SELECT * FROM {table_name} LIMIT 1").description - ] - - table_to_parquet( - """ - SELECT * FROM TARGET_INFO_GPU - INNER JOIN TARGET_INFO_NVTX_CUDA_DEVICE ON TARGET_INFO_GPU.cuDevice = TARGET_INFO_NVTX_CUDA_DEVICE.deviceId""", - ("TARGET_INFO_GPU", "cuDevice"), - "device-metadata.parquet", - columns=pd.MultiIndex.from_tuples( - table_columns("TARGET_INFO_GPU") - + table_columns("TARGET_INFO_NVTX_CUDA_DEVICE") - ), - index_name="cuDevice", - ) - else: - print("WARNING: NOT writing device metadata, no device activity profiled?") - print(f"{archive_name}: extracted device/thread names in {time.time()-start:.2f}s") - - -def copy_jax_nsys_files(input_dir, output_queue): - """ - Gather files from `input_dir` and queue them for archival. - """ - # Gather the files from /opt/jax_nsys that should be bundled into the archive. - for file in iglob("**", recursive=True, root_dir=input_dir): - full_path = osp.join(input_dir, file) - if osp.isdir(full_path): - continue - if file.startswith("python/jax_nsys/jax_nsys/__pycache__") or file.startswith( - "python/jax_nsys/jax_nsys.egg-info" - ): - continue - output_queue.put((file, full_path, compress_deflate)) - - -def find_pb_files_in_tmp(tmp_dir): - """ - Return a prefix + list of relative paths to Protobuf files dumped by XLA. - """ - return tmp_dir, glob("dump/*.pb", root_dir=tmp_dir) + glob( - "dump/*.pbtxt", root_dir=tmp_dir - ) - - -def gather_source_files( - proto_dir, proto_files, pb_file_prefix, pb_file_list, output_queue -): - """ - Given a directory containing the required .proto files (`proto_dir`) and a - prefix (`pb_file_prefix`) and list of relative paths to .pb files - (`pb_file_list`), extract a list of source code files referred to by the - XLA metadata and embed those source code files in the output archive. - """ - start = time.time() - # .hlo.pb are used to gather source code to be embedded - hlo_pb_files = [ - osp.join(pb_file_prefix, x) for x in pb_file_list if x.endswith(".hlo.pb") - ] - with tempfile.TemporaryDirectory() as tmp_dir: - # Compile the .proto files - subprocess.run( - ["protoc", f"-I={proto_dir}", f"--python_out={tmp_dir}"] + proto_files, - check=True, - cwd=proto_dir, - ) - # Collect the set of referenced source files - sys.path.insert(0, tmp_dir) - from xla.service import hlo_pb2 - - hlo = hlo_pb2.HloProto() - src_files = set() - for hlo_pb_file in hlo_pb_files: - with open(hlo_pb_file, "rb") as f: - hlo.ParseFromString(f.read()) - src_files |= set(hlo.hlo_module.stack_frame_index.file_names) - sys.path.remove(tmp_dir) - if len(src_files) == 0: - print("WARNING: no source files were gathered") - # Copy these files into the output archive. - for src_file in src_files: - if src_file == "": - # This can appear due to python -c "...", for example. - continue - assert osp.isabs(src_file), f"{src_file} is not absolute" - output_queue.put(("sources" + src_file, src_file, compress_deflate)) - print(f"{archive_name}: gathered source code in {time.time()-start:.2f}s") - - -def execute_analysis_scripts(mirror_dir, analysis_scripts): - """ - Execute any post-processing scripts passed via --nsys-jax-analysis, - returning a list of output files that should be added to the output - archive. - """ - if len(analysis_scripts) == 0: - return [], 0 - - assert mirror_dir is not None - output = [] - exit_code = 0 - used_slugs = set() - for analysis in analysis_scripts: - script, args = analysis[0], analysis[1:] - # If --nsys-jax-analysis is the name of a bundled analysis script, use that. Otherwise it should be a file that exists. - search = [ - osp.join( - mirror_dir, - "python", - "jax_nsys_analysis", - script + ".py", - ), - script, - ] - candidates = list(filter(osp.exists, search)) - assert len(candidates), f"Could not find analysis script, tried {search}" - args.append(mirror_dir) - analysis_command = [sys.executable, candidates[0]] + args - # Derive a unique name slug from the analysis script name - slug = osp.basename(candidates[0]).removesuffix(".py") - n, suffix = 1, "" - while slug + suffix in used_slugs: - suffix = f"-{n}" - n += 1 - slug += suffix - used_slugs.add(slug) - working_dir = osp.join(mirror_dir, "analysis", slug) - os.makedirs(working_dir, exist_ok=True) - print( - f"Running analysis script: {shlex.join(analysis_command)} in {working_dir}" - ) - result = subprocess.run( - analysis_command, - cwd=working_dir, - ) - if result.returncode != 0: - exit_code = result.returncode - # Gather output files of the scrpt - for path in iglob("**", recursive=True, root_dir=working_dir): - output.append( - (osp.join("analysis", slug, path), osp.join(working_dir, path)) - ) - return output, exit_code - - -def write_output_file(to_process, mirror_dir, analysis_scripts): - """ - Write the output archive (`archive_name`) by consuming entries from the - queue until a `None` sentinel value is seen. If `mirror_dir` is not None - then populate it with symlinks/files as necessary to create a structure - equivalent to the output archive. - """ - start = time.time() - with zipfile.ZipFile( - archive_name, "w" if archive_name_can_be_overwritten else "x" - ) as archive: - while True: - timeout = 30 - try: - item = to_process.get(timeout=timeout) - to_process.task_done() - if item is None: - # This is the sentinel value instructing us to exit. - assert to_process.empty() - break - path_in_archive, content, kwargs = item - mirror_path = None - if mirror_dir is not None: - mirror_path = osp.join(mirror_dir, path_in_archive) - os.makedirs(osp.dirname(mirror_path), exist_ok=True) - if isinstance(content, bytes): - archive.writestr(path_in_archive, content, **kwargs) - if mirror_path is not None: - with open(mirror_path, "wb") as mfile: - mfile.write(content) - else: - archive.write(content, arcname=path_in_archive, **kwargs) - if mirror_path is not None: - os.symlink(content, mirror_path) - except queue.Empty: - print(f"{archive_name}: output stalled ({timeout}s heartbeat)") - # Execute analysis scripts so their outputs can be bundled in the archive - # before it is closed - analysis_outputs, exit_code = execute_analysis_scripts( - mirror_dir, analysis_scripts - ) - for path_in_archive, local_path in analysis_outputs: - archive.write(filename=local_path, arcname=path_in_archive) - os.chmod(archive_name, 0o644) - print(f"{archive_name}: wrote in {time.time()-start:.2f}s") - if exit_code != 0: - print("Exiting due to analysis script errors") - sys.exit(exit_code) - - -def process_pb_files(pb_future): - """ - Queue .pb and .pbtxt files for inclusion in the output archive. - """ - pb_file_prefix, pb_file_list = pb_future.result() - for pb_file in pb_file_list: - futures.append( - executor.submit( - compress_and_archive, pb_file_prefix, pb_file, files_to_archive - ) - ) - - -def process_pb_and_proto_files(pb_future, proto_future, output_queue, futures): - """ - Queue .proto files for inclusion in the output archive and trigger - gathering source code files once .pb/.pbtxt/.proto files are available. - """ - # Block for completion of copy_proto_files_to_tmp - proto_dir, proto_files = proto_future.result() - # Queue them for inclusion in the output archive - for proto_file in proto_files: - output_queue.put( - ( - osp.join("protos", proto_file), - osp.join(proto_dir, proto_file), - compress_deflate, - ) - ) - # Wait to have pb files too - pb_file_prefix, pb_file_list = pb_future.result() - # Submit work that depends on the proto directory - futures.append( - executor.submit( - gather_source_files, - proto_dir, - proto_files, - pb_file_prefix, - pb_file_list, - files_to_archive, - ) - ) - - -# Orchestrate post-processing steps: -# - collect Python source files: -# - collect list of .proto files -# - copy them to a temp dir -# - extract list of Python source files from .pb/.pbtxt files using that dir -# - save those source files to the archive -# - save the .proto files in the temp dir to the archive -# - save .pb/.pbtxt files: -# - gather a list of these -# - compress them individually -# - add the compressed versions to the output archive w/out extra compression -# - save the .nsys-rep file to the output archive with compression -# - post-process the .nsys-rep -# - convert .nsys-rep -> .parquet in the temp dir with nsys recipe -# - save the .parquet file to the output archive w/out extra compression -# - copy the contents of /opt/jax_nsys into the output archive - -# Element format: (path_in_archive, Path or bytes, ZipFile.write* kwargs) -files_to_archive: queue.Queue = queue.Queue() - - -@contextmanager -def output_thread(executor: ThreadPoolExecutor): - """ - Launch the output worker on context manager entry, signal that it should - exit on context manager exit. - """ - try: - # Spawn a worker to actually write the output file, consuming entries - # in output_queue. - future = executor.submit( - write_output_file, - files_to_archive, - mirror_dir, - nsys_jax_flags.analysis, - ) - yield future - finally: - # Signal via the output queue that the worker should exit. - files_to_archive.put(None) - # Make sure any errors from the output thread are surfaced - future.result() - - -exit_code = 0 -with ThreadPoolExecutor() as executor, output_thread(executor): - # Track futures so we can wait on them and report errors. - futures = [] - # Queue the .nsys-rep for compression - files_to_archive.put( - ( - report_name, - tmp_rep, - compress_deflate, - ) - ) - # Convert .nsys-rep -> .parquet and queue the latter for archival - futures.append( - executor.submit( - run_nsys_recipe, - "nvtx_gpu_proj_trace", - tmp_rep, - tmp_dir, - files_to_archive, - ) - ) - # Copy /opt/jax_nsys into the archive - futures.append( - executor.submit(copy_jax_nsys_files, "/opt/jax_nsys", files_to_archive) - ) - # Gather the list of .proto files - proto_future = executor.submit(copy_proto_files_to_tmp, tmp_dir) - # Gather the list of .pb[txt] files - pb_future = executor.submit(find_pb_files_in_tmp, tmp_dir) - futures.append(pb_future) - futures.append(executor.submit(process_pb_files, pb_future)) - # Wait on pb_future and proto_future and submit dependent work - futures.append( - executor.submit( - process_pb_and_proto_files, - pb_future, - proto_future, - files_to_archive, - futures, - ) - ) - futures.append( - executor.submit( - run_nsys_stats_report, - "nvtx_pushpop_trace", - tmp_rep, - tmp_dir, - files_to_archive, - ) - ) - # Do some custom post-processing of the .sqlite export generated by gpu_proj_future - futures.append( - executor.submit( - save_device_stream_thread_names, - tmp_dir, - tmp_rep, - files_to_archive, - ) - ) - # Wait for errors/completion of `futures`; note that this does not include - # the output thread, which is signaled to upon exiting from this block. - # Also note that the list of futures can still grow at this point. - retired = 0 - while True: - results = wait(futures, return_when=FIRST_EXCEPTION, timeout=30) - # Check if we exited early because of an exception and, if so, print it - # immediately. Do not abort, so even in case of errors a valid archive - # containing as much useful information as possible will be written. - retired += len(results.done) - for future in results.done: - futures.remove(future) - if future.exception() is not None: - exit_code = 1 - traceback.print_exception(future.exception()) - pending = len(futures) - if pending == 0: - break - print(f"{archive_name}: {pending}/{len(futures) + retired} pending") -if exit_code: - print(f"{archive_name}: exiting with code {exit_code} due to errors") -sys.exit(exit_code) diff --git a/.github/container/nsys-jax-combine b/.github/container/nsys-jax-combine deleted file mode 100755 index 36ef782af..000000000 --- a/.github/container/nsys-jax-combine +++ /dev/null @@ -1,200 +0,0 @@ -#!/usr/bin/env python -import argparse -from collections import defaultdict -import copy -import os -import pathlib -import shlex -import shutil -import subprocess -import sys -import tempfile -import zipfile - -parser = argparse.ArgumentParser( - description=( - "`nsys-jax-combine` facilitates distributed profiling of JAX applications " - "using the `nsys-jax` wrapper. It aggregates multiple .zip outputs from " - "different `nsys-jax` processes that profiled the same distributed execution " - "of an application, checking consistency and removing duplicated data." - ), -) -parser.add_argument( - "--analysis", - action="append", - help=( - "Post-processing analysis script to execute after merging. This can be the " - "name of a recipe bundled in the inpit files, or the path to a Python script. " - "The script will be passed any arguments specified via --analysis-arg, " - "followed by a single positional argument, which is the path to a directory " - "of the same structure as the extracted output archive." - ), - type=lambda x: ("script", x), -) -parser.add_argument( - "--analysis-arg", - action="append", - dest="analysis", - help="Extra arguments to pass to analysis scripts specified via --analysis", - type=lambda x: ("arg", x), -) - - -def shuffle_analysis_arg(analysis): - if analysis is None: - return [] - # [Script(A), Arg(A1), Arg(A2), Script(B), Arg(B1)] becomes [[A, A1, A2], [B, B1]] - out, current = [], [] - for t, x in analysis: - if t == "script": - if len(current): - out.append(current) - current = [x] - else: - assert t == "arg" and len(current) - current.append(x) - if len(current): - out.append(current) - return out - - -parser.add_argument( - "-f", - "--force-overwrite", - action="store_true", - help="Overwrite the output file if it exists.", -) -parser.add_argument( - "input", - type=pathlib.Path, - nargs="+", - help="Input .zip archives produced by `nsys-jax`", -) - - -def check_keep_nsys_rep(raw): - assert raw in {"all", "first", "none"} - return raw - - -parser.add_argument( - "--keep-nsys-rep", - default="first", - type=check_keep_nsys_rep, - help=( - "How many .nsys-rep files from the input to copy to the output. Supported " - "values are 'all', 'first' and 'none'." - ), -) -parser.add_argument( - "-o", - "--output", - help="Output file name", - required=True, - type=pathlib.Path, -) -# TODO: derive a default output path from the input paths -args = parser.parse_args() -args.analysis = shuffle_analysis_arg(args.analysis) -if args.output.suffix != ".zip": - args.output = args.output.with_suffix(".zip") -if os.path.exists(args.output) and not args.force_overwrite: - raise Exception( - f"Output path {args.output} already exists and -f/--force-overwrite was not passed" - ) - -hashes = defaultdict(set) -for input in args.input: - with zipfile.ZipFile(input) as ifile: - for member in ifile.infolist(): - hashes[member.filename].add(member.CRC) - -mirror_dir = pathlib.Path(tempfile.mkdtemp()) if len(args.analysis) else None -with zipfile.ZipFile(args.output, "w") as ofile: - for n_input, input in enumerate(args.input): - first_input = n_input == 0 - keep_this_nsys_rep = args.keep_nsys_rep == "all" or ( - args.keep_nsys_rep == "first" and first_input - ) - with zipfile.ZipFile(input) as ifile: - for member in ifile.infolist(): - if member.is_dir(): - continue - filename = member.filename - assert filename in hashes - seen_hashes = hashes[filename] - - def write(dst_info): - assert dst_info.filename not in set(ofile.namelist()) - with ifile.open(member) as src: - with ofile.open(dst_info, "w") as dst: - shutil.copyfileobj(src, dst) - if mirror_dir is not None: - dst_path = mirror_dir / dst_info.filename - os.makedirs(dst_path.parent, exist_ok=True) - src.seek(0) - with open(dst_path, "wb") as dst: - shutil.copyfileobj(src, dst) - - if filename.endswith(".nsys-rep"): - assert len(seen_hashes) == 1 - if filename == input.stem + ".nsys-rep" and keep_this_nsys_rep: - # `filename`` is the .nsys-rep from `input`` - write(member) - else: - if len(seen_hashes) == 1: - # This file was the same in all inputs: copy it once. - if first_input: - write(member) - else: - # This file was not the same in all inputs: copy it to a - # modified destination. An input file A/B in reportN.zip will - # be saved as A/B/reportN in the output, i.e. A/B will be a - # directory instead of a file. TODO: in future instead of using - # input.stem use a standardised format showing the device - # numbers that were profiled in reportN.zip. - dst_info = copy.copy(member) - dst_info.filename = filename + "/" + input.stem - write(dst_info) - if len(args.analysis): - assert mirror_dir is not None - used_slugs = set() - for analysis in args.analysis: - # Execute post-processing recipes and add any outputs to `ofile` - script, script_args = analysis[0], analysis[1:] - # If --analysis is the name of a bundled analysis script, use that. Otherwise it should be a file that exists. - search = [ - mirror_dir / "python" / "jax_nsys_analysis" / (script + ".py"), - pathlib.Path(script), - ] - candidates = list(filter(lambda p: p.exists(), search)) - assert len(candidates), f"Could not find analysis script, tried {search}" - analysis_command = ( - [sys.executable, candidates[0]] + script_args + [mirror_dir] - ) - # Derive a unique name slug from the analysis script name - slug = os.path.basename(candidates[0]).removesuffix(".py") - n, suffix = 1, "" - while slug + suffix in used_slugs: - suffix = f"-{n}" - n += 1 - slug += suffix - used_slugs.add(slug) - working_dir = mirror_dir / "analysis" / slug - os.makedirs(working_dir, exist_ok=True) - print( - f"Running analysis script: {shlex.join(map(str, analysis_command))} in {working_dir}" - ) - subprocess.run( - analysis_command, - check=True, - cwd=working_dir, - ) - # Gather output files of the scrpt - for path in working_dir.rglob("*"): - with ( - open(working_dir / path, "rb") as src, - ofile.open(str(path.relative_to(mirror_dir)), "w") as dst, - ): - # https://github.com/python/mypy/issues/15031 ? - shutil.copyfileobj(src, dst) # type: ignore diff --git a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/__init__.py b/.github/container/nsys_jax/nsys_jax/__init__.py similarity index 100% rename from .github/container/jax_nsys/python/jax_nsys/jax_nsys/__init__.py rename to .github/container/nsys_jax/nsys_jax/__init__.py diff --git a/.github/container/jax_nsys/Analysis.ipynb b/.github/container/nsys_jax/nsys_jax/analyses/Analysis.ipynb similarity index 99% rename from .github/container/jax_nsys/Analysis.ipynb rename to .github/container/nsys_jax/nsys_jax/analyses/Analysis.ipynb index 3224c940b..bb7dced40 100644 --- a/.github/container/jax_nsys/Analysis.ipynb +++ b/.github/container/nsys_jax/nsys_jax/analyses/Analysis.ipynb @@ -9,7 +9,7 @@ "source": [ "from collections import defaultdict\n", "import functools\n", - "from jax_nsys import (\n", + "from nsys_jax import (\n", " align_profiler_data_timestamps,\n", " apply_warmup_heuristics,\n", " display_flamegraph,\n", diff --git a/.github/container/jax_nsys/python/jax_nsys_analysis/communication.py b/.github/container/nsys_jax/nsys_jax/analyses/communication.py similarity index 99% rename from .github/container/jax_nsys/python/jax_nsys_analysis/communication.py rename to .github/container/nsys_jax/nsys_jax/analyses/communication.py index ef02f5c1b..5388a1f84 100644 --- a/.github/container/jax_nsys/python/jax_nsys_analysis/communication.py +++ b/.github/container/nsys_jax/nsys_jax/analyses/communication.py @@ -1,7 +1,7 @@ #!/usr/bin/env python import argparse from collections import defaultdict -from jax_nsys import ( +from nsys_jax import ( align_profiler_data_timestamps, apply_warmup_heuristics, ensure_compiled_protos_are_importable, diff --git a/.github/container/jax_nsys/python/jax_nsys_analysis/summary.py b/.github/container/nsys_jax/nsys_jax/analyses/summary.py similarity index 99% rename from .github/container/jax_nsys/python/jax_nsys_analysis/summary.py rename to .github/container/nsys_jax/nsys_jax/analyses/summary.py index 978c041fa..33c825837 100644 --- a/.github/container/jax_nsys/python/jax_nsys_analysis/summary.py +++ b/.github/container/nsys_jax/nsys_jax/analyses/summary.py @@ -1,6 +1,6 @@ #!/usr/bin/env python import argparse -from jax_nsys import ( +from nsys_jax import ( apply_warmup_heuristics, ensure_compiled_protos_are_importable, generate_compilation_statistics, diff --git a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py b/.github/container/nsys_jax/nsys_jax/analysis.py similarity index 100% rename from .github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py rename to .github/container/nsys_jax/nsys_jax/analysis.py diff --git a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/data_loaders.py b/.github/container/nsys_jax/nsys_jax/data_loaders.py similarity index 99% rename from .github/container/jax_nsys/python/jax_nsys/jax_nsys/data_loaders.py rename to .github/container/nsys_jax/nsys_jax/data_loaders.py index d6e4464bd..a3e848dc8 100644 --- a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/data_loaders.py +++ b/.github/container/nsys_jax/nsys_jax/data_loaders.py @@ -11,7 +11,7 @@ from .analysis import calculate_collective_metrics from .protobuf import xla_module_metadata -from .utils import make_child_mask, ProfilerData +from .utils import default_data_prefix, make_child_mask, ProfilerData pd.options.mode.copy_on_write = True @@ -629,7 +629,7 @@ def _load_nvtx_pushpop_trace(prefix: pathlib.Path, frames: set[str]) -> pd.DataF def load_profiler_data( - prefix: pathlib.Path = pathlib.Path("."), + prefix: pathlib.Path = default_data_prefix(), frames: set[str] = {"communication", "compile", "module", "thunk"}, ) -> ProfilerData: """ diff --git a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/protobuf.py b/.github/container/nsys_jax/nsys_jax/protobuf.py similarity index 98% rename from .github/container/jax_nsys/python/jax_nsys/jax_nsys/protobuf.py rename to .github/container/nsys_jax/nsys_jax/protobuf.py index 4feae6038..a43160f19 100644 --- a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/protobuf.py +++ b/.github/container/nsys_jax/nsys_jax/protobuf.py @@ -3,6 +3,8 @@ import pathlib import typing +from .utils import default_data_prefix + def _host_memory_space(inst): return inst.shape.layout.memory_space == 5 @@ -239,7 +241,7 @@ def unique_result(self, callable): def xla_module_metadata( program_id: int, policy: typing.Literal["consistent"], - prefix: pathlib.Path = pathlib.Path("."), + prefix: pathlib.Path = default_data_prefix(), ) -> HloProto: ... @@ -247,7 +249,7 @@ def xla_module_metadata( def xla_module_metadata( program_id: int, policy: typing.Literal["all"], - prefix: pathlib.Path = pathlib.Path("."), + prefix: pathlib.Path = default_data_prefix(), ) -> HloProtoSet: ... @@ -255,7 +257,7 @@ def xla_module_metadata( def xla_module_metadata( program_id: int, policy: str = "consistent", - prefix: pathlib.Path = pathlib.Path("."), + prefix: pathlib.Path = default_data_prefix(), ) -> typing.Union[HloProto, HloProtoSet]: """ Load the protobuf metadata for module `program_id`. If given, `prefix` is the diff --git a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/protobuf_utils.py b/.github/container/nsys_jax/nsys_jax/protobuf_utils.py similarity index 94% rename from .github/container/jax_nsys/python/jax_nsys/jax_nsys/protobuf_utils.py rename to .github/container/nsys_jax/nsys_jax/protobuf_utils.py index 03b1b4816..6873e06de 100644 --- a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/protobuf_utils.py +++ b/.github/container/nsys_jax/nsys_jax/protobuf_utils.py @@ -9,6 +9,8 @@ import tempfile from typing import Optional +from .utils import default_data_prefix + def which(executable: str) -> pathlib.Path: """ @@ -55,7 +57,9 @@ def compile_protos( subprocess.run(args, check=True) -def ensure_compiled_protos_are_importable(*, prefix: pathlib.Path = pathlib.Path(".")): +def ensure_compiled_protos_are_importable( + *, prefix: pathlib.Path = default_data_prefix() +): """ See if the Python bindings generated from .proto are importable, and if not then generate them in a temporary directory and prepend it to sys.path. diff --git a/.github/container/nsys_jax/nsys_jax/scripts/install_flamegraph.py b/.github/container/nsys_jax/nsys_jax/scripts/install_flamegraph.py new file mode 100644 index 000000000..af81fe712 --- /dev/null +++ b/.github/container/nsys_jax/nsys_jax/scripts/install_flamegraph.py @@ -0,0 +1,31 @@ +import argparse +import os +import requests + + +def main(): + """ + Ideally flamegraph.pl could just be declared as a dependency in pyproject.toml, but + it's not packaged for that. It could probably be worked around, but for now we just + distribute this script to install it. + """ + # TODO: add a default to (with confirmation) install in the same prefix as this script is installed to + parser = argparse.ArgumentParser("Fetch the flamegraph.pl script") + parser.add_argument( + "prefix", help="Output prefix under which to install flamegraph.pl", type=str + ) + args = parser.parse_args() + install_dir = os.path.join(args.prefix, "bin") + install_path = os.path.join(install_dir, "flamegraph.pl") + assert not os.path.exists(install_path), f"{install_path} already exists" + os.makedirs(install_dir, exist_ok=True) + + s = requests.Session() + s.mount("https://", requests.adapters.HTTPAdapter(max_retries=5)) + r = s.get( + "https://raw.githubusercontent.com/brendangregg/FlameGraph/master/flamegraph.pl" + ) + r.raise_for_status() + with open(install_path, "w") as ofile: + ofile.write(r.text) + os.chmod(install_path, 0o755) diff --git a/.github/container/nsys_jax/nsys_jax/scripts/install_protoc.py b/.github/container/nsys_jax/nsys_jax/scripts/install_protoc.py new file mode 100644 index 000000000..e524cb6ab --- /dev/null +++ b/.github/container/nsys_jax/nsys_jax/scripts/install_protoc.py @@ -0,0 +1,67 @@ +import argparse +import google.protobuf +import io +import os +import platform +import requests +import zipfile + + +def main(): + # TODO: add a default to (with confirmation) install in the same prefix as this script is installed to + parser = argparse.ArgumentParser( + "Install a version of the protoc compiler that is compatible with the google.protobuf runtime" + ) + parser.add_argument( + "prefix", help="Output prefix under which to install protoc", type=str + ) + args = parser.parse_args() + + s = requests.Session() + s.mount("https://", requests.adapters.HTTPAdapter(max_retries=5)) + + # protobuf versioning is complicated, see protocolbuffers/protobuf#11123 for more + # discussion. For older versions, when the versioning scheme was aligned, try and + # install a protoc with the same version as google.protobuf. For newer versions, given + # google.protobuf version X.Y.Z install protoc version Y.Z as described in + # https://protobuf.dev/support/version-support + runtime_version = tuple(map(int, google.protobuf.__version__.split("."))) + if runtime_version < (3, 21): + # old versioning scheme, try and install a matching protoc version + protoc_version = runtime_version + else: + # new versioning scheme, runtime minor.patch should be the protoc version + protoc_version = runtime_version[1:] + + # Install the given protobuf version + ver = ".".join(map(str, protoc_version)) + system = platform.system().lower() + machine = platform.machine() + system = {"darwin": "osx"}.get(system, system) + machine = { + "aarch64": "aarch_64", + "arm64": "aarch_64", + }.get(machine, machine) + # Apple Silicon can handle universal and x86_64 if it needs to. + machines = { + ("osx", "aarch_64"): ["aarch_64", "universal_binary", "x86_64"], + }.get((system, machine), [machine]) + for machine in machines: + r = s.get( + f"https://github.com/protocolbuffers/protobuf/releases/download/v{ver}/protoc-{ver}-{system}-{machine}.zip" + ) + if r.status_code == 404: + # assume this means the architecture is not available + continue + else: + r.raise_for_status() + + with zipfile.ZipFile(io.BytesIO(r.content)) as z: + for name in z.namelist(): + if ".." in name: + continue + if name.startswith("bin/") or name.startswith("include/"): + z.extract(name, path=args.prefix) + + # Make sure the protoc binary is executable + os.chmod(os.path.join(args.prefix, "bin", "protoc"), 0o755) diff --git a/.github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py b/.github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py new file mode 100644 index 000000000..acc8ad19d --- /dev/null +++ b/.github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py @@ -0,0 +1,801 @@ +import argparse +from concurrent.futures import FIRST_EXCEPTION, ThreadPoolExecutor, wait +from contextlib import contextmanager +from glob import glob, iglob +import lzma +import os +import os.path as osp +import pandas as pd # type: ignore +import pathlib +import queue +import re +import shlex +import shutil +import sqlite3 +import subprocess +import sys +import tempfile +import time +import traceback +import zipfile + +from .utils import execute_analysis_script, shuffle_analysis_arg +from ..version import __sha__ as jax_toolbox_sha_with_prefix + + +# Expand %q{ENV_VAR} if the variable is defined. +def expand(string: str, skip_missing=True) -> str: + missing = set() + + def rep(x): + if len(x.group(1)) % 2 == 0: + return x.group(0) + if x.group(2) not in os.environ: + missing.add(x.group(2)) + return x.group(0) + return x.group(1)[:-1] + os.environ[x.group(2)] + + expanded = re.sub(r"([%]+)q\{(.*?)\}", rep, string).replace("%%", "%") + if not skip_missing and missing: + raise Exception(f"{missing} not defined when expanding '{string}'") + return expanded + + +# Use deflate compression +COMPRESS_DEFLATE = {"compress_type": zipfile.ZIP_DEFLATED} +# Do not compress (if the file is already compressed) +COMPRESS_NONE: dict[str, int] = {} + +install_script_template = r"""#!/bin/bash +# +# Usage: ./install.sh [optional arguments to virtualenv] +# +# If it doesn't already exist, this creates a virtual environment named +# `nsys_jax_env` in the current directory and installs Jupyter Lab and the +# dependencies of the Analysis.ipynb notebook that is shipped alongside this +# script inside the output archives of the `nsys-jax` wrapper. +# +# The expectation is that those archives will be copied and extracted on a +# laptop or workstation, and this installation script will be run there, while +# the `nsys-jax` wrapper is executed on a remote GPU cluster. +set -ex +SCRIPT_DIR=$(cd -- "$( dirname -- "${{BASH_SOURCE[0]}}" )" &> /dev/null && pwd) +VIRTUALENV="${{SCRIPT_DIR}}/nsys_jax_venv" +BIN="${{VIRTUALENV}}/bin" +if [[ ! -d "${{VIRTUALENV}}" ]]; then + # Let `virtualenv` find/choose a Python. Currently >=3.10 is supported. + virtualenv -p 3.13 -p 3.12 -p 3.11 -p 3.10 "$@" "${{VIRTUALENV}}" + "${{BIN}}/pip" install -U pip + "${{BIN}}/pip" install 'nsys-jax[jupyter] @ git+https://github.com/NVIDIA/JAX-Toolbox.git@{jax_toolbox_commit}#subdirectory=.github/container/nsys_jax' + "${{BIN}}/install-flamegraph" "${{VIRTUALENV}}" + "${{BIN}}/install-protoc" "${{VIRTUALENV}}" +else + echo "Virtual environment already exists, not installing anything..." +fi +# Pick up the current profile data by default +export NSYS_JAX_DEFAULT_PREFIX="${{PWD}}" +# https://setuptools.pypa.io/en/latest/userguide/datafiles.html#accessing-data-files-at-runtime +NOTEBOOK=$("${{BIN}}/python" -c 'from importlib.resources import files; print(files("nsys_jax") / "analyses" / "Analysis.ipynb")') +if [ -z ${{NSYS_JAX_IPYTHON_NOT_JUPYTER_LAB+x}} ]; then + CMD="${{BIN}}/jupyter-lab" +else + CMD="${{BIN}}/ipython" +fi +echo "Launching: cd ${{SCRIPT_DIR}} && ${{CMD}} ${{NOTEBOOK}}" +cd "${{SCRIPT_DIR}}" && "${{CMD}}" "${{NOTEBOOK}}" +""" + + +def create_install_script(output_queue): + """ + Write an install.sh to the output archive that installs nsys-jax at the same + version/commit that the current execution is using. + """ + # setuptools_scm produces a shortened sha with a `g` prefix (for git) + jax_toolbox_sha = jax_toolbox_sha_with_prefix[1:] + install_script = install_script_template.format(jax_toolbox_commit=jax_toolbox_sha) + output_queue.put(("install.sh", install_script.encode(), COMPRESS_DEFLATE)) + + +def main() -> None: + """ + Entrypoint for nsys-jax + """ + # Wrapper-specific arguments. This also handles -h and --help. + parser = argparse.ArgumentParser( + allow_abbrev=False, + usage=( + "nsys-jax [-h] [--nsys-jax-condition EXPRESSION] [--nsys-jax-analysis A1 " + "[--nsys-jax-analysis-arg=A1_ARG1 [--nsys-jax-analysis-arg=A1_ARG2 ...]] " + "[--nsys-jax-analysis A2 [--nsys-jax-analysis-arg=A2_ARG1 ...]] [-o OUTPUT | " + "--output OUTPUT] [-f | --force-overwrite] [nsys profile arguments ...] [--] " + "executable [executable arguments ...]" + ), + description=( + "`nsys-jax` is a wrapper for `nsys profile` that collects additional metadata " + "that are specific to JAX and XLA, post-processes the profile data, and " + "produces a compressed .zip archive containing the relevant files." + ), + epilog=( + "NOTE: if the executable arguments include a literal `--` then the optional " + "`--` shown in the usage message MUST be passed to disambiguate. This is also " + "required when extra nsys profile arguments are passed." + ), + ) + parser.add_argument( + "--nsys-jax-analysis", + action="append", + dest="analysis", + help=( + "Post-processing analysis script to execute after report collection. This can " + "be the name of a bundled recipe, or the path to a Python script. The script " + "will be passed any arguments specified via --nsys-jax-analysis-arg, followed " + "by a single positional argument, which is the path to a directory of the " + "same structure as the extracted output archive." + ), + type=lambda x: ("script", x), + ) + parser.add_argument( + "--nsys-jax-analysis-arg", + action="append", + dest="analysis", + help="Extra arguments to pass to analysis scripts specified via --nsys-jax-analysis", + type=lambda x: ("arg", x), + ) + parser.add_argument( + "--nsys-jax-condition", + help=( + "Bash expression that will be expanded to determine if this instance " + "of nsys-jax should actually launch nsys. Example: " + "--nsys-jax-condition='$SLURM_LOCALID == 0' to only profile the first " + "process on each node. The expression is evaluated inside [[ ... ]]." + ), + ) + parser.add_argument( + "-f", + "--force-overwrite", + action="store_true", + help="This must be passed for nsys-jax to overwrite an existing output archive.", + ) + parser.add_argument( + "-o", + "--output", + help=( + "Output filename, if this has an .nsys-rep or .zip suffix it will be removed " + "to yield ROOT, and the output archive will be ROOT.zip, which will contain a " + "ROOT.nsys-rep." + ), + ) + + nsys_jax_flags, unknown_args = parser.parse_known_args(sys.argv) + nsys_jax_flags.analysis = shuffle_analysis_arg(nsys_jax_flags.analysis) + # Remove the name of the nsys-jax wrapper + nsys_flags_and_cmd = unknown_args[1:] + # This can have two forms: + # exe [exe args ...] + # [nsys args ...] -- exe [exe args ...] + # where the second one must be used if `exe args` contains `--`, even if no nsys args + # are passed. + try: + limit = nsys_flags_and_cmd.index("--") + nsys_flags = nsys_flags_and_cmd[:limit] + application = nsys_flags_and_cmd[limit + 1 :] + except ValueError: + # No --, everything is the application + nsys_flags = [] + application = nsys_flags_and_cmd + + if len(application) == 0: + parser.print_help() + raise Exception("No application to profile") + + if shutil.which(application[0]) is None: + parser.print_help() + raise Exception(f"{application[0]} not found by shutil.which") + + enable_profiling = True + if nsys_jax_flags.nsys_jax_condition is not None: + enable_profiling = ( + subprocess.run( + ["/bin/bash", "-c", f"[[ {nsys_jax_flags.nsys_jax_condition} ]]"], + shell=False, + ).returncode + == 0 + ) + + if nsys_jax_flags.output is None: + # There was not an explicit output location; generate one. There may be + # multiple processes racing to do this. + archive_handle, archive_name = tempfile.mkstemp( + dir=os.getcwd(), prefix="nsys-jax-report-", suffix=".zip" + ) + # Re-open it based on name later, mkstemp is just a way of avoiding races + os.close(archive_handle) + # No -f / --force-overwrite needed in this case + archive_name_can_be_overwritten = True + else: + # Explicit output location was given in `nsys_jax_flags.output`, transform that + # into the .zip-suffixed verison of it. + archive_name = ( + expand(nsys_jax_flags.output.removesuffix(".nsys-rep").removesuffix(".zip")) + + ".zip" + ) + archive_name_can_be_overwritten = nsys_jax_flags.force_overwrite + + # We will write /final/output/path/name.zip, and it will contain name.nsys-rep, + # but we do not instruct nsys to write that to /final/output/path/name.nsys-rep + # so that more of the processing can happen on a faster, more local filesystem. + report_name = osp.basename(archive_name).removesuffix(".zip") + ".nsys-rep" + tmp_dir = tempfile.mkdtemp() + tmp_rep = osp.join(tmp_dir, report_name) + nsys_flags += ["--output", tmp_rep] + + # If --nsys-jax-analysis is used, we also construct a local directory mirroring + # the extracted archive structure. TODO: clean this up + mirror_dir = None if len(nsys_jax_flags.analysis) == 0 else tempfile.mkdtemp() + + def override_nsys_default(arg, value): + if any(x.startswith(f"--{arg}=") for x in nsys_flags): + return + nsys_flags.append(f"--{arg}={value}") + + # Override some Nsight Systems defaults, but don't block setting them explicitly. + override_nsys_default("cuda-graph-trace", "node") + override_nsys_default("cpuctxsw", "none") + override_nsys_default("python-sampling", "true") + # TODO: consider dropping osrt from here + override_nsys_default("trace", "cublas,cuda,cudnn,cusolver,nvtx,osrt") + + # Modified environment in which to run the application + env = os.environ.copy() + + # Stop stack traces from being truncated in the metadata passed to XLA unless + # the option was explicitly set. + if "JAX_TRACEBACK_IN_LOCATIONS_LIMIT" not in env: + env["JAX_TRACEBACK_IN_LOCATIONS_LIMIT"] = "-1" + + # Disable the compilation cache so that we get the full set of .pb files + if "JAX_ENABLE_COMPILATION_CACHE" not in env: + env["JAX_ENABLE_COMPILATION_CACHE"] = "false" + + # Get the existing XLA_FLAGS and parse them into a dictionary. + xla_flag_list = shlex.split(env.get("XLA_FLAGS", "")) + xla_flags = {} + for flag in xla_flag_list: + assert flag.startswith("--") + bits = flag[2:].split("=", maxsplit=1) + name, value = bits[0], bits[1] if len(bits) > 1 else None + assert name not in xla_flags + xla_flags[name] = value + + def as_list(flags): + return [f"--{n}" if v is None else f"--{n}={v}" for n, v in flags.items()] + + assert xla_flag_list == as_list(xla_flags) + + def as_bool(s): + """String -> bool conversion following XLA's semantics.""" + if s.lower() == "true" or s == "1": + return True + if s.lower() == "false" or s == "0": + return False + raise Exception("Could not convert '{}' to bool".format(s)) + + # Enable dumping protobufs unless it was explicitly disabled + if "xla_dump_hlo_as_proto" not in xla_flags: + xla_flags["xla_dump_hlo_as_proto"] = "true" + + proto_dump_enabled = as_bool(xla_flags["xla_dump_hlo_as_proto"]) + + # For simplicity, impose our directory structure on the dump from XLA + if proto_dump_enabled: + if "xla_dump_to" in xla_flags: + print(f"WARNING: --xla_dump_to={xla_flags['xla_dump_to']} being overriden") + xla_flags["xla_dump_to"] = osp.join(tmp_dir, "dump") + else: + print("WARNING: protobuf dump explicitly disabled, things will break") + + # Serialise the modified XLA flags. shlex.join is tempting, but doesn't seem to + # get the right result for --xla_dump_hlo_pass_re=.*, as it adds extra quotes. + env["XLA_FLAGS"] = " ".join(as_list(xla_flags)) + + # Run the application in nsys + # TODO: consider being more fault-tolerant? + # The Nsight Systems command prefix + nsys = [ + "nsys", + "profile", + ] + nsys_flags + subprocess.run( + (nsys if enable_profiling else []) + application, check=True, env=env + ) + + # If we skipped profiling the application, there is nothing more to be done. + if not enable_profiling: + sys.exit(0) + + # Check the output report was written and is new + if not osp.exists(tmp_rep): + raise Exception(f"Could not find output file: {tmp_rep}") + + def copy_proto_files_to_tmp( + tmp_dir, xla_dir=os.environ.get("SRC_PATH_XLA", "/opt/xla") + ): + """ + Copy .proto files from XLA into a temporary directory under `tmp_dir`. + + TODO: install .proto files as part of `jaxlib`, so this can work without + the XLA sources being available under `xla_dir` e.g. as part of a + generic `pip` installation of JAX. + + Returns: (name of temporary directory, list of relative .proto paths) + """ + start = time.time() + proto_dir = osp.join(tmp_dir, "protos") + tsl_dir = osp.join(xla_dir, "third_party", "tsl") + proto_files = [] + for p, root in [("tsl/**/*.proto", tsl_dir), ("xla/**/*.proto", xla_dir)]: + for proto in iglob(p, recursive=True, root_dir=root): + proto_files.append(proto) + dst_dir = osp.join(proto_dir, osp.dirname(proto)) + if not osp.isdir(dst_dir): + os.makedirs(dst_dir) + shutil.copy(osp.join(root, proto), osp.join(proto_dir, proto)) + print(f"{archive_name}: gathered .proto files in {time.time()-start:.2f}s") + return proto_dir, proto_files + + def run_nsys_recipe(recipe, report_file, tmp_dir, output_queue): + """ + Post-process a .nsys-rep file into a .parquet file for offline analysis. + This is currently implemented using the given nsys recipe. + """ + start = time.time() + recipe_output = osp.join(tmp_dir, recipe) + subprocess.run( + [ + "nsys", + "recipe", + recipe, + "--input", + report_file, + "--output", + recipe_output, + ], + check=True, + ) + for ofile in iglob(recipe + "/**", recursive=True, root_dir=tmp_dir): + full_path = osp.join(tmp_dir, ofile) + # glob("/does-not-exist/**", recursive=True) == ['/does-not-exist/'] + if osp.isdir(full_path) or not osp.exists(full_path): + continue + output_queue.put((ofile, full_path, COMPRESS_NONE)) + print(f"{archive_name}: post-processing finished in {time.time()-start:.2f}s") + + def compress_and_archive(prefix, file, output_queue): + """ + Read prefix+file, compress it, queue the compressed bytes for archival + without further compression. + """ + with open(osp.join(prefix, file), "rb") as ifile: + output_queue.put((file + ".xz", lzma.compress(ifile.read()), COMPRESS_NONE)) + + def run_nsys_stats_report(report, report_file, tmp_dir, output_queue): + """ + Run a stats recipe on an .nsys-rep file (that has probably already been + exported to .sqlite). + """ + start = time.time() + subprocess.run( + [ + "nsys", + "stats", + "--report", + report, + "--input", + report_file, + # avoid race conditions with other reports/etc. + "--sqlite", + osp.splitext(report_file)[0] + "-" + report + ".sqlite", + "--output", + osp.join(tmp_dir, "report"), + ], + check=True, + ) + for ofile in iglob("report_" + report + ".csv", root_dir=tmp_dir): + compress_and_archive(tmp_dir, ofile, output_queue) + print(f"{archive_name}: post-processing finished in {time.time()-start:.2f}s") + + def save_device_stream_thread_names(tmp_dir, report, output_queue): + """ + Extract extra information from the SQLite dump that is needed to map projected NVTX + ranges to global device IDs. + """ + start = time.time() + assert report.endswith(".nsys-rep"), f"{report} had an unexpected suffix" + db_file = report.removesuffix(".nsys-rep") + "-metadata.sqlite" + subprocess.run( + [ + "nsys", + "export", + "--type", + "sqlite", + "--tables", + "StringIds,TARGET_INFO_GPU,TARGET_INFO_NVTX_CUDA_DEVICE,TARGET_INFO_SYSTEM_ENV,ThreadNames", + "--output", + db_file, + report, + ], + check=True, + ) + assert os.path.exists(db_file) + con = sqlite3.connect(db_file) + cur = con.cursor() + + def table_to_parquet(query, index, filename, columns=None, index_name=None): + res = cur.execute(query) + if columns is None: + columns = [x[0] for x in res.description] + df = pd.DataFrame(res, columns=columns).set_index( + index, verify_integrity=True + ) + if index_name is not None: + df.index.name = index_name + df.to_parquet(osp.join(tmp_dir, filename)) + output_queue.put((filename, osp.join(tmp_dir, filename), COMPRESS_NONE)) + + # Extract {(pid, tid): (name, priority)} map; PID/TID arithmetic comes from + # https://docs.nvidia.com/nsight-systems/UserGuide/index.html#common-sqlite-examples + table_to_parquet( + r""" + SELECT + StringIds.value AS Name, + ThreadNames.priority AS Priority, + ThreadNames.globalTid / 0x1000000 % 0x1000000 AS PID, + ThreadNames.globalTid % 0x1000000 AS TID + FROM ThreadNames + INNER JOIN StringIds ON ThreadNames.nameId=StringIds.id""", + ["PID", "TID"], + "thread-metadata.parquet", + ) + # Extract high level metadata about the profiling session, including the hostname + table_to_parquet( + "SELECT name, nameEnum, value FROM TARGET_INFO_SYSTEM_ENV", + "nameEnum", + "system-metadata.parquet", + ) + + def table_exists(table_name): + return ( + cur.execute( + f"SELECT 1 FROM sqlite_master WHERE type='table' AND name='{table_name}'" + ).fetchall() + != [] + ) + + # Cannot write device-metadata.parquet if no device activity was profiled. + if table_exists("TARGET_INFO_GPU") and table_exists( + "TARGET_INFO_NVTX_CUDA_DEVICE" + ): + # Extract {device_id: metadata_and_name} map, making sure to pick up the name that + # XLA assigns via NVTX + def table_columns(table_name): + return [ + (table_name, x[0]) + for x in cur.execute( + f"SELECT * FROM {table_name} LIMIT 1" + ).description + ] + + table_to_parquet( + """ + SELECT * FROM TARGET_INFO_GPU + INNER JOIN TARGET_INFO_NVTX_CUDA_DEVICE ON TARGET_INFO_GPU.cuDevice = TARGET_INFO_NVTX_CUDA_DEVICE.deviceId""", + ("TARGET_INFO_GPU", "cuDevice"), + "device-metadata.parquet", + columns=pd.MultiIndex.from_tuples( + table_columns("TARGET_INFO_GPU") + + table_columns("TARGET_INFO_NVTX_CUDA_DEVICE") + ), + index_name="cuDevice", + ) + else: + print("WARNING: NOT writing device metadata, no device activity profiled?") + print( + f"{archive_name}: extracted device/thread names in {time.time()-start:.2f}s" + ) + + def find_pb_files_in_tmp(tmp_dir): + """ + Return a prefix + list of relative paths to Protobuf files dumped by XLA. + """ + return tmp_dir, glob("dump/*.pb", root_dir=tmp_dir) + glob( + "dump/*.pbtxt", root_dir=tmp_dir + ) + + def gather_source_files( + proto_dir, proto_files, pb_file_prefix, pb_file_list, output_queue + ): + """ + Given a directory containing the required .proto files (`proto_dir`) and a + prefix (`pb_file_prefix`) and list of relative paths to .pb files + (`pb_file_list`), extract a list of source code files referred to by the + XLA metadata and embed those source code files in the output archive. + """ + start = time.time() + # .hlo.pb are used to gather source code to be embedded + hlo_pb_files = [ + osp.join(pb_file_prefix, x) for x in pb_file_list if x.endswith(".hlo.pb") + ] + with tempfile.TemporaryDirectory() as tmp_dir: + # Compile the .proto files + subprocess.run( + ["protoc", f"-I={proto_dir}", f"--python_out={tmp_dir}"] + proto_files, + check=True, + cwd=proto_dir, + ) + # Collect the set of referenced source files + sys.path.insert(0, tmp_dir) + from xla.service import hlo_pb2 + + hlo = hlo_pb2.HloProto() + src_files = set() + for hlo_pb_file in hlo_pb_files: + with open(hlo_pb_file, "rb") as f: + hlo.ParseFromString(f.read()) + src_files |= set(hlo.hlo_module.stack_frame_index.file_names) + sys.path.remove(tmp_dir) + if len(src_files) == 0: + print("WARNING: no source files were gathered") + # Copy these files into the output archive. + for src_file in src_files: + if src_file == "": + # This can appear due to python -c "...", for example. + continue + assert osp.isabs(src_file), f"{src_file} is not absolute" + output_queue.put(("sources" + src_file, src_file, COMPRESS_DEFLATE)) + print(f"{archive_name}: gathered source code in {time.time()-start:.2f}s") + + def execute_analysis_scripts(mirror_dir, analysis_scripts): + """ + Execute any post-processing scripts passed via --nsys-jax-analysis, + returning a list of output files that should be added to the output + archive. + """ + if len(analysis_scripts) == 0: + return [], 0 + + assert mirror_dir is not None + output = [] + exit_code = 0 + mirror_dir = pathlib.Path(mirror_dir) + for analysis in analysis_scripts: + result, output_prefix = execute_analysis_script( + data=mirror_dir, script=analysis[0], args=analysis[1:] + ) + if result.returncode != 0: + exit_code = result.returncode + # Gather output files of the scrpt + for path in iglob( + "**", recursive=True, root_dir=osp.join(mirror_dir, output_prefix) + ): + output.append( + ( + osp.join(output_prefix, path), + osp.join(mirror_dir, output_prefix, path), + ) + ) + + return output, exit_code + + def write_output_file(to_process, mirror_dir, analysis_scripts): + """ + Write the output archive (`archive_name`) by consuming entries from the + queue until a `None` sentinel value is seen. If `mirror_dir` is not None + then populate it with symlinks/files as necessary to create a structure + equivalent to the output archive. + """ + start = time.time() + with zipfile.ZipFile( + archive_name, "w" if archive_name_can_be_overwritten else "x" + ) as archive: + while True: + timeout = 30 + try: + item = to_process.get(timeout=timeout) + to_process.task_done() + if item is None: + # This is the sentinel value instructing us to exit. + assert to_process.empty() + break + path_in_archive, content, kwargs = item + mirror_path = None + if mirror_dir is not None: + mirror_path = osp.join(mirror_dir, path_in_archive) + os.makedirs(osp.dirname(mirror_path), exist_ok=True) + if isinstance(content, bytes): + archive.writestr(path_in_archive, content, **kwargs) + if mirror_path is not None: + with open(mirror_path, "wb") as mfile: + mfile.write(content) + else: + archive.write(content, arcname=path_in_archive, **kwargs) + if mirror_path is not None: + os.symlink(content, mirror_path) + except queue.Empty: + print(f"{archive_name}: output stalled ({timeout}s heartbeat)") + # Execute analysis scripts so their outputs can be bundled in the archive + # before it is closed + analysis_outputs, exit_code = execute_analysis_scripts( + mirror_dir, analysis_scripts + ) + for path_in_archive, local_path in analysis_outputs: + archive.write(filename=local_path, arcname=path_in_archive) + os.chmod(archive_name, 0o644) + print(f"{archive_name}: wrote in {time.time()-start:.2f}s") + if exit_code != 0: + print("Exiting due to analysis script errors") + sys.exit(exit_code) + + def process_pb_files(pb_future): + """ + Queue .pb and .pbtxt files for inclusion in the output archive. + """ + pb_file_prefix, pb_file_list = pb_future.result() + for pb_file in pb_file_list: + futures.append( + executor.submit( + compress_and_archive, pb_file_prefix, pb_file, files_to_archive + ) + ) + + def process_pb_and_proto_files(pb_future, proto_future, output_queue, futures): + """ + Queue .proto files for inclusion in the output archive and trigger + gathering source code files once .pb/.pbtxt/.proto files are available. + """ + # Block for completion of copy_proto_files_to_tmp + proto_dir, proto_files = proto_future.result() + # Queue them for inclusion in the output archive + for proto_file in proto_files: + output_queue.put( + ( + osp.join("protos", proto_file), + osp.join(proto_dir, proto_file), + COMPRESS_DEFLATE, + ) + ) + # Wait to have pb files too + pb_file_prefix, pb_file_list = pb_future.result() + # Submit work that depends on the proto directory + futures.append( + executor.submit( + gather_source_files, + proto_dir, + proto_files, + pb_file_prefix, + pb_file_list, + files_to_archive, + ) + ) + + # Orchestrate post-processing steps: + # - collect Python source files: + # - collect list of .proto files + # - copy them to a temp dir + # - extract list of Python source files from .pb/.pbtxt files using that dir + # - save those source files to the archive + # - save the .proto files in the temp dir to the archive + # - save .pb/.pbtxt files: + # - gather a list of these + # - compress them individually + # - add the compressed versions to the output archive w/out extra compression + # - save the .nsys-rep file to the output archive with compression + # - post-process the .nsys-rep + # - convert .nsys-rep -> .parquet in the temp dir with nsys recipe + # - save the .parquet file to the output archive w/out extra compression + + # Element format: (path_in_archive, Path or bytes, ZipFile.write* kwargs) + files_to_archive: queue.Queue = queue.Queue() + + @contextmanager + def output_thread(executor: ThreadPoolExecutor): + """ + Launch the output worker on context manager entry, signal that it should + exit on context manager exit. + """ + try: + # Spawn a worker to actually write the output file, consuming entries + # in output_queue. + future = executor.submit( + write_output_file, + files_to_archive, + mirror_dir, + nsys_jax_flags.analysis, + ) + yield future + finally: + # Signal via the output queue that the worker should exit. + files_to_archive.put(None) + # Make sure any errors from the output thread are surfaced + future.result() + + exit_code = 0 + with ThreadPoolExecutor() as executor, output_thread(executor): + # Track futures so we can wait on them and report errors. + futures = [] + # Queue the .nsys-rep for compression + files_to_archive.put( + ( + report_name, + tmp_rep, + COMPRESS_DEFLATE, + ) + ) + # Convert .nsys-rep -> .parquet and queue the latter for archival + futures.append( + executor.submit( + run_nsys_recipe, + "nvtx_gpu_proj_trace", + tmp_rep, + tmp_dir, + files_to_archive, + ) + ) + # Write an installation script into the archive + futures.append(executor.submit(create_install_script, files_to_archive)) + # Gather the list of .proto files + proto_future = executor.submit(copy_proto_files_to_tmp, tmp_dir) + # Gather the list of .pb[txt] files + pb_future = executor.submit(find_pb_files_in_tmp, tmp_dir) + futures.append(pb_future) + futures.append(executor.submit(process_pb_files, pb_future)) + # Wait on pb_future and proto_future and submit dependent work + futures.append( + executor.submit( + process_pb_and_proto_files, + pb_future, + proto_future, + files_to_archive, + futures, + ) + ) + futures.append( + executor.submit( + run_nsys_stats_report, + "nvtx_pushpop_trace", + tmp_rep, + tmp_dir, + files_to_archive, + ) + ) + # Do some custom post-processing of the .sqlite export generated by gpu_proj_future + futures.append( + executor.submit( + save_device_stream_thread_names, + tmp_dir, + tmp_rep, + files_to_archive, + ) + ) + # Wait for errors/completion of `futures`; note that this does not include + # the output thread, which is signaled to upon exiting from this block. + # Also note that the list of futures can still grow at this point. + retired = 0 + while True: + results = wait(futures, return_when=FIRST_EXCEPTION, timeout=30) + # Check if we exited early because of an exception and, if so, print it + # immediately. Do not abort, so even in case of errors a valid archive + # containing as much useful information as possible will be written. + retired += len(results.done) + for future in results.done: + futures.remove(future) + if future.exception() is not None: + exit_code = 1 + traceback.print_exception(future.exception()) + pending = len(futures) + if pending == 0: + break + print(f"{archive_name}: {pending}/{len(futures) + retired} pending") + if exit_code: + print(f"{archive_name}: exiting with code {exit_code} due to errors") + sys.exit(exit_code) diff --git a/.github/container/nsys_jax/nsys_jax/scripts/nsys_jax_combine.py b/.github/container/nsys_jax/nsys_jax/scripts/nsys_jax_combine.py new file mode 100644 index 000000000..5b804de00 --- /dev/null +++ b/.github/container/nsys_jax/nsys_jax/scripts/nsys_jax_combine.py @@ -0,0 +1,155 @@ +import argparse +from collections import defaultdict +import copy +import os +import pathlib +import shutil +import tempfile +import zipfile + +from .utils import execute_analysis_script, shuffle_analysis_arg + + +def main(): + """ + Entrypoint for nsys-jax-combine + """ + parser = argparse.ArgumentParser( + description=( + "`nsys-jax-combine` facilitates distributed profiling of JAX applications " + "using the `nsys-jax` wrapper. It aggregates multiple .zip outputs from " + "different `nsys-jax` processes that profiled the same distributed execution " + "of an application, checking consistency and removing duplicated data." + ), + ) + parser.add_argument( + "--analysis", + action="append", + help=( + "Post-processing analysis script to execute after merging. This can be the " + "name of a recipe bundled in the inpit files, or the path to a Python script. " + "The script will be passed any arguments specified via --analysis-arg, " + "followed by a single positional argument, which is the path to a directory " + "of the same structure as the extracted output archive." + ), + type=lambda x: ("script", x), + ) + parser.add_argument( + "--analysis-arg", + action="append", + dest="analysis", + help="Extra arguments to pass to analysis scripts specified via --analysis", + type=lambda x: ("arg", x), + ) + parser.add_argument( + "-f", + "--force-overwrite", + action="store_true", + help="Overwrite the output file if it exists.", + ) + parser.add_argument( + "input", + type=pathlib.Path, + nargs="+", + help="Input .zip archives produced by `nsys-jax`", + ) + + def check_keep_nsys_rep(raw): + assert raw in {"all", "first", "none"} + return raw + + parser.add_argument( + "--keep-nsys-rep", + default="first", + type=check_keep_nsys_rep, + help=( + "How many .nsys-rep files from the input to copy to the output. Supported " + "values are 'all', 'first' and 'none'." + ), + ) + parser.add_argument( + "-o", + "--output", + help="Output file name", + required=True, + type=pathlib.Path, + ) + # TODO: derive a default output path from the input paths + args = parser.parse_args() + args.analysis = shuffle_analysis_arg(args.analysis) + if args.output.suffix != ".zip": + args.output = args.output.with_suffix(".zip") + if os.path.exists(args.output) and not args.force_overwrite: + raise Exception( + f"Output path {args.output} already exists and -f/--force-overwrite was not passed" + ) + + hashes = defaultdict(set) + for input in args.input: + with zipfile.ZipFile(input) as ifile: + for member in ifile.infolist(): + hashes[member.filename].add(member.CRC) + + mirror_dir = pathlib.Path(tempfile.mkdtemp()) if len(args.analysis) else None + with zipfile.ZipFile(args.output, "w") as ofile: + for n_input, input in enumerate(args.input): + first_input = n_input == 0 + keep_this_nsys_rep = args.keep_nsys_rep == "all" or ( + args.keep_nsys_rep == "first" and first_input + ) + with zipfile.ZipFile(input) as ifile: + for member in ifile.infolist(): + if member.is_dir(): + continue + filename = member.filename + assert filename in hashes + seen_hashes = hashes[filename] + + def write(dst_info): + assert dst_info.filename not in set(ofile.namelist()) + with ifile.open(member) as src: + with ofile.open(dst_info, "w") as dst: + shutil.copyfileobj(src, dst) + if mirror_dir is not None: + dst_path = mirror_dir / dst_info.filename + os.makedirs(dst_path.parent, exist_ok=True) + src.seek(0) + with open(dst_path, "wb") as dst: + shutil.copyfileobj(src, dst) + + if filename.endswith(".nsys-rep"): + assert len(seen_hashes) == 1 + if filename == input.stem + ".nsys-rep" and keep_this_nsys_rep: + # `filename`` is the .nsys-rep from `input`` + write(member) + else: + if len(seen_hashes) == 1: + # This file was the same in all inputs: copy it once. + if first_input: + write(member) + else: + # This file was not the same in all inputs: copy it to a + # modified destination. An input file A/B in reportN.zip will + # be saved as A/B/reportN in the output, i.e. A/B will be a + # directory instead of a file. TODO: in future instead of using + # input.stem use a standardised format showing the device + # numbers that were profiled in reportN.zip. + dst_info = copy.copy(member) + dst_info.filename = filename + "/" + input.stem + write(dst_info) + if len(args.analysis): + assert mirror_dir is not None + # Execute post-processing recipes and add any outputs to `ofile` + for analysis in args.analysis: + result, output_prefix = execute_analysis_script( + data=mirror_dir, script=analysis[0], args=analysis[1:] + ) + result.check_returncode() + # Gather output files of the scrpt + for path in (mirror_dir / output_prefix).rglob("*"): + with ( + open(mirror_dir / output_prefix / path, "rb") as src, + ofile.open(str(path.relative_to(mirror_dir)), "w") as dst, + ): + # https://github.com/python/mypy/issues/15031 ? + shutil.copyfileobj(src, dst) # type: ignore diff --git a/.github/container/nsys-2024.5-tid-export.patch b/.github/container/nsys_jax/nsys_jax/scripts/patch_nsys.py similarity index 50% rename from .github/container/nsys-2024.5-tid-export.patch rename to .github/container/nsys_jax/nsys_jax/scripts/patch_nsys.py index f19d35e27..8b732283c 100644 --- a/.github/container/nsys-2024.5-tid-export.patch +++ b/.github/container/nsys_jax/nsys_jax/scripts/patch_nsys.py @@ -1,4 +1,9 @@ -diff --git a/nsys_recipe/lib/nvtx.py b/nsys_recipe/lib/nvtx.py +import os +import re +import shutil +import subprocess + +patch_content = r"""diff --git a/nsys_recipe/lib/nvtx.py b/nsys_recipe/lib/nvtx.py index 2470043..7abf892 100644 --- a/nsys_recipe/lib/nvtx.py +++ b/nsys_recipe/lib/nvtx.py @@ -22,3 +27,28 @@ "stackLevel": "Stack Level", "childrenCount": "Children Count", "rangeId": "Range ID", +""" + + +def main(): + """ + Entrypoint for nsys-jax-patch-nsys. + """ + nsys = shutil.which("nsys") + assert nsys is not None, "nsys-jax-patch-nsys expects nsys to be installed" + nsys_version = subprocess.check_output([nsys, "--version"], text=True) + m = re.match( + r"^NVIDIA Nsight Systems version (\d+\.\d+\.\d+)\.\d+-\d+v\d+$", nsys_version + ) + assert m is not None, f"Could not parse: {nsys_version}" + if m.group(1) in {"2024.5.1", "2024.6.1"}: + print(f"Patching Nsight Systems version {m.group(1)}") + # e.g. /opt/nvidia/nsight-systems-cli/2024.7.1/target-linux-x64 + tdir = os.path.dirname(os.path.realpath(nsys)) + subprocess.run( + [shutil.which("git"), "apply"], + cwd=os.path.join(tdir, "python", "packages"), + input=patch_content, + check=True, + text=True, + ) diff --git a/.github/container/nsys_jax/nsys_jax/scripts/utils.py b/.github/container/nsys_jax/nsys_jax/scripts/utils.py new file mode 100644 index 000000000..e13845f17 --- /dev/null +++ b/.github/container/nsys_jax/nsys_jax/scripts/utils.py @@ -0,0 +1,77 @@ +import contextlib +import importlib.resources +import os +import pathlib +import shlex +import subprocess +import sys + + +def shuffle_analysis_arg(analysis): + """ + Helper for parsing --nsys-jax-analysis[-arg] (nsys-jax) and --analysis[-arg] + (nsys-jax-combine) command line options. + """ + if analysis is None: + return [] + # [Script(A), Arg(A1), Arg(A2), Script(B), Arg(B1)] becomes [[A, A1, A2], [B, B1]] + out, current = [], [] + for t, x in analysis: + if t == "script": + if len(current): + out.append(current) + current = [x] + else: + assert t == "arg" and len(current) + current.append(x) + if len(current): + out.append(current) + return out + + +def analysis_recipe_path(script): + """ + Return a context manager that yields the path to the analysis script named by + `script`. This can either be the name of a bundled analysis script from the + the installed analyses/ directory, or a filesystem path. + """ + script_file = importlib.resources.files("nsys_jax").joinpath( + "analyses", script + ".py" + ) + if script_file.is_file(): + return script_file + assert os.path.exists( + script + ), f"{script} does not exist and is not the name of a built-in analysis script" + return contextlib.nullcontext(pathlib.Path(script)) + + +def execute_analysis_script( + *, data: pathlib.Path, script: str, args: list[str] +) -> tuple[subprocess.CompletedProcess, pathlib.Path]: + """ + Run the analysis script named by `script` on the profile data in the `data` + directory (structure the same as nsys-jax[-combine] output archives), saving any + output files to subdirectory of `data` named `output_prefix`. + """ + with analysis_recipe_path(script) as script_path: + analysis_command = [sys.executable, str(script_path)] + args + [str(data)] + + # Derive a unique name slug from the analysis script name + def with_suffix(suffix): + return data / "analysis" / (script_path.stem + suffix) + + n, suffix = 1, "" + while with_suffix(suffix).exists(): + suffix = f"-{n}" + n += 1 + working_dir = with_suffix(suffix) + working_dir.mkdir(parents=True) + print( + f"Running analysis script: {shlex.join(analysis_command)} in {working_dir}" + ) + result = subprocess.run( + analysis_command, + cwd=working_dir, + ) + return result, working_dir.relative_to(data) diff --git a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/utils.py b/.github/container/nsys_jax/nsys_jax/utils.py similarity index 90% rename from .github/container/jax_nsys/python/jax_nsys/jax_nsys/utils.py rename to .github/container/nsys_jax/nsys_jax/utils.py index 01b9c9ca0..c9d071a2d 100644 --- a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/utils.py +++ b/.github/container/nsys_jax/nsys_jax/utils.py @@ -1,10 +1,20 @@ from dataclasses import dataclass +import os import pandas as pd # type: ignore +import pathlib from typing import Optional pd.options.mode.copy_on_write = True +def default_data_prefix() -> pathlib.Path: + """ + Default path for profiler data. This is particularly useful for Jupyter notebooks, + which make it awkward to arrange for a sensible default working directory. + """ + return pathlib.Path(os.environ.get("NSYS_JAX_DEFAULT_PREFIX", ".")) + + @dataclass class ProfilerData: """ diff --git a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/visualization.py b/.github/container/nsys_jax/nsys_jax/visualization.py similarity index 100% rename from .github/container/jax_nsys/python/jax_nsys/jax_nsys/visualization.py rename to .github/container/nsys_jax/nsys_jax/visualization.py diff --git a/.github/container/nsys_jax/pyproject.toml b/.github/container/nsys_jax/pyproject.toml new file mode 100644 index 000000000..95bdffd4c --- /dev/null +++ b/.github/container/nsys_jax/pyproject.toml @@ -0,0 +1,44 @@ +[project] +name = "nsys-jax" +dynamic = ["version"] +dependencies = [ + "ipython", + "numpy", + "pandas", + "protobuf", # a compatible version of protoc needs to be installed out-of-band + "pyarrow", + "requests", # for install-protoc + "uncertainties", # communication analysis recipe +] +requires-python = ">= 3.10" + +[build-system] +requires = ["setuptools>=64", "setuptools_scm>=8"] +build-backend = "setuptools.build_meta" + +[project.optional-dependencies] +jupyter = [ + "jupyterlab", + "matplotlib" +] +test = [ + "pytest" +] + +[project.scripts] +install-flamegraph = "nsys_jax.scripts.install_flamegraph:main" +install-protoc = "nsys_jax.scripts.install_protoc:main" +nsys-jax = "nsys_jax.scripts.nsys_jax:main" +nsys-jax-combine = "nsys_jax.scripts.nsys_jax_combine:main" +nsys-jax-patch-nsys = "nsys_jax.scripts.patch_nsys:main" + +[tool.setuptools_scm] +root = "../../.." # .github/container/nsys_jax +# written into the git checkout in case of an editable installation +version_file = "nsys_jax/version.py" +# __sha__ is not written by default +version_file_template = """\ +__version__ = version = {version!r} +__version_tuple__ = version_tuple = {version_tuple!r} +__sha__ = {scm_version.node!r} +""" diff --git a/.github/container/jax_nsys_tests/example_program.py b/.github/container/nsys_jax/tests/example_program.py similarity index 100% rename from .github/container/jax_nsys_tests/example_program.py rename to .github/container/nsys_jax/tests/example_program.py diff --git a/.github/container/jax_nsys_tests/jax_nsys_test_helpers/__init__.py b/.github/container/nsys_jax/tests/nsys_jax_test_helpers/__init__.py similarity index 100% rename from .github/container/jax_nsys_tests/jax_nsys_test_helpers/__init__.py rename to .github/container/nsys_jax/tests/nsys_jax_test_helpers/__init__.py diff --git a/.github/container/jax_nsys_tests/test_basics.py b/.github/container/nsys_jax/tests/test_basics.py similarity index 93% rename from .github/container/jax_nsys_tests/test_basics.py rename to .github/container/nsys_jax/tests/test_basics.py index f9a39d8a0..8704e25d6 100644 --- a/.github/container/jax_nsys_tests/test_basics.py +++ b/.github/container/nsys_jax/tests/test_basics.py @@ -4,10 +4,10 @@ import tempfile import zipfile -helper_dir = os.path.join(os.path.dirname(__file__), "jax_nsys_test_helpers") +helper_dir = os.path.join(os.path.dirname(__file__), "nsys_jax_test_helpers") if helper_dir not in sys.path: sys.path.insert(0, helper_dir) -from jax_nsys_test_helpers import nsys_jax # noqa: E402 +from nsys_jax_test_helpers import nsys_jax # noqa: E402 def test_program_without_gpu_activity(): diff --git a/.github/container/jax_nsys_tests/test_example_program.py b/.github/container/nsys_jax/tests/test_example_program.py similarity index 94% rename from .github/container/jax_nsys_tests/test_example_program.py rename to .github/container/nsys_jax/tests/test_example_program.py index 0461c7c7e..1866e4282 100644 --- a/.github/container/jax_nsys_tests/test_example_program.py +++ b/.github/container/nsys_jax/tests/test_example_program.py @@ -1,4 +1,4 @@ -from jax_nsys import ( +from nsys_jax import ( ensure_compiled_protos_are_importable, load_profiler_data, ) @@ -9,10 +9,10 @@ import tempfile import zipfile -helper_dir = os.path.join(os.path.dirname(__file__), "jax_nsys_test_helpers") +helper_dir = os.path.join(os.path.dirname(__file__), "nsys_jax_test_helpers") if helper_dir not in sys.path: sys.path.insert(0, helper_dir) -from jax_nsys_test_helpers import nsys_jax # noqa: E402 +from nsys_jax_test_helpers import nsys_jax # noqa: E402 @pytest.fixture(scope="module") diff --git a/.github/container/pip-finalize.sh b/.github/container/pip-finalize.sh index 1149d7638..6d8ceac9b 100755 --- a/.github/container/pip-finalize.sh +++ b/.github/container/pip-finalize.sh @@ -51,6 +51,9 @@ pip-sync --pip-args '--no-deps --src /opt' requirements.txt rm -rf ~/.cache/* -# protobuf will be installed at least as a dependency of jax_nsys in the base -# image, but the installed version is likely to be influenced by other packages. -install-protoc /usr/local +# Execute post-install hooks +for post_install in $(ls /opt/pip-tools-post-install.d/*); do + if [[ -x "${post_install}" ]]; then + "${post_install}" + fi +done diff --git a/.github/workflows/_build_base.yaml b/.github/workflows/_build_base.yaml index b575ec14b..4a1ad84d6 100644 --- a/.github/workflows/_build_base.yaml +++ b/.github/workflows/_build_base.yaml @@ -128,10 +128,13 @@ jobs: platforms: linux/${{ inputs.ARCHITECTURE }} tags: ${{ steps.meta.outputs.tags }} labels: ${{ steps.meta.outputs.labels }} + # head_ref is the PR source branch for pull_request pipelines, which avoids + # baking in the SHA of a merge commit than cannot be checked out later build-args: | GIT_USER_NAME=${{ inputs.GIT_USER_NAME }} GIT_USER_EMAIL=${{ inputs.GIT_USER_EMAIL }} BUILD_DATE=${{ inputs.BUILD_DATE }} + JAX_TOOLBOX_REF=${{ github.head_ref || github.sha }} ${{ inputs.BASE_IMAGE != 'latest' && format('BASE_IMAGE={0}', inputs.BASE_IMAGE) || '' }} - name: Generate sitrep diff --git a/.github/workflows/_ci.yaml b/.github/workflows/_ci.yaml index c0d6b89c7..a0c6bcfda 100644 --- a/.github/workflows/_ci.yaml +++ b/.github/workflows/_ci.yaml @@ -321,8 +321,11 @@ jobs: -v $PWD:/opt/output \ ${{ needs.build-jax.outputs.DOCKER_TAG_FINAL }} \ bash <<"EOF" |& tee test-nsys-jax.log - pip install pytest-reportlog -e /opt/jax_nsys/python/jax_nsys[test] - pytest --report-log=/opt/output/pytest-report.jsonl /opt/jax_nsys_tests + # nsys-jax is already installed, this is just adding the test dependencies + pip install pytest-reportlog nsys-jax[test] + # abuse knowledge that nsys-jax is installed editable, so the tests exist + test_path=$(python -c 'import importlib.resources; print(importlib.resources.files("nsys_jax").joinpath("..", "tests").resolve())') + pytest --report-log=/opt/output/pytest-report.jsonl "${test_path}" EOF GPUS_PER_NODE=$(nvidia-smi -L | grep -c '^GPU') for mode in 1-process 2-process process-per-gpu; do @@ -400,6 +403,37 @@ jobs: *-execution-combine.log secrets: inherit + # test-nsys-jax generates several fresh .zip archive outputs by running nsys-jax with real GPU hardware; this test + # runs on a regular GitHub Actions runner and checks that offline post-processing works in an environment that does + # not already have nsys-jax installed + test-nsys-jax-archive: + needs: test-nsys-jax + if: inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a + strategy: + matrix: + os: [ubuntu-22.04, ubuntu-24.04, macOS-latest] + runs-on: ${{ matrix.os }} + steps: + - name: Download nsys-jax output .zip files + uses: actions/download-artifact@v4 + with: + name: nsys-jax-unit-test-A100 + - name: Extract archives and execute install scripts + run: | + pip install virtualenv # for install.sh + for zip in $(ls *.zip); do + ZIP="${PWD}/${zip}" + pushd $(mktemp -d) + unzip "${ZIP}" + ls -l + # TODO: verify this isn't needed, or make sure it isn't needed + chmod 755 install.sh + # Run the notebook with IPython, not Jupyter Lab, so it exits and prints something informative to stdout + # Skip executing Jupyter lab + NSYS_JAX_IPYTHON_NOT_JUPYTER_LAB=1 ./install.sh + popd + done + # test-equinox: # needs: build-equinox # if: inputs.ARCHITECTURE == 'amd64' # arm64 runners n/a diff --git a/.github/workflows/nsys-jax.yaml b/.github/workflows/nsys-jax.yaml index e143334cb..e15cd557f 100644 --- a/.github/workflows/nsys-jax.yaml +++ b/.github/workflows/nsys-jax.yaml @@ -17,12 +17,13 @@ on: branches: - main +defaults: + run: + shell: bash -x -eo pipefail {0} + env: NSYS_JAX_PYTHON_FILES: | - JAX-Toolbox/.github/container/nsys-jax - JAX-Toolbox/.github/container/nsys-jax-combine - JAX-Toolbox/.github/container/jax_nsys - JAX-Toolbox/.github/container/jax_nsys_tests + JAX-Toolbox/.github/container/nsys_jax JAX-Toolbox/.github/container/jax-nccl-test jobs: @@ -40,10 +41,13 @@ jobs: with: python-version: '3.10' # jax is just a CPU-only build of the latest release for type-checking purposes - - name: "Install jax / jax_nsys / mypy" - run: pip install jax -e JAX-Toolbox/.github/container/jax_nsys/python/jax_nsys matplotlib mypy nbconvert types-protobuf + - name: "Install jax / nsys_jax / mypy" + run: pip install jax -e JAX-Toolbox/.github/container/nsys_jax matplotlib mypy nbconvert types-protobuf types-requests - name: "Install protoc" - run: ./JAX-Toolbox/.github/container/jax_nsys/install-protoc local_protoc + # TODO: this could install into the pip prefix as a default + run: | + install-protoc local/ + echo "$PWD/local/bin" >> "${GITHUB_PATH}" - name: "Fetch XLA .proto files" uses: actions/checkout@v4 with: @@ -53,21 +57,18 @@ jobs: *.proto sparse-checkout-cone-mode: false - name: "Compile .proto files" - shell: bash -x -e {0} run: | mkdir compiled_protos compiled_stubs protos mv -v xla/third_party/tsl/tsl protos/ mv -v xla/xla protos/ - PATH=${PWD}/local_protoc/bin:$PATH python -c "from jax_nsys import compile_protos; compile_protos(proto_dir='protos', output_dir='compiled_protos', output_stub_dir='compiled_stubs')" + python -c "from nsys_jax import compile_protos; compile_protos(proto_dir='protos', output_dir='compiled_protos', output_stub_dir='compiled_stubs')" touch compiled_stubs/py.typed - name: "Convert .ipynb to .py" - shell: bash -x -e {0} run: | for notebook in $(find ${NSYS_JAX_PYTHON_FILES} -name '*.ipynb'); do jupyter nbconvert --to script ${notebook} done - name: "Run mypy checks" - shell: bash -x -e {0} run: | export MYPYPATH="${PWD}/compiled_stubs" mypy --scripts-are-modules ${NSYS_JAX_PYTHON_FILES} @@ -88,21 +89,17 @@ jobs: uses: actions/setup-python@v5 with: python-version: '3.12' - # TODO: a modern nsys-jax-combine with old .zip input should probably produce a - # .zip with a modern jax_nsys/ - - name: Add modern jax_nsys/ files to static .zip inputs + - name: Install nsys-jax and dependencies run: | - cd .github/container/jax_nsys - for zip in ../../workflows/nsys-jax/test_data/pax_fsdp4_4proc_proc*.zip; do - zip -ur "${zip}" . - zipinfo "${zip}" - done + # Installs nsys-jax-combine; use an editable install here for better coverage + pip install -e .github/container/nsys_jax[jupyter] + # TODO: this could install into the pip prefix as a default + install-flamegraph local/ + install-protoc local/ + echo "$PWD/local/bin" >> "${GITHUB_PATH}" - name: Use nsys-jax-combine to merge profiles from multiple nsys processes - shell: bash -x -e {0} run: | - pip install -e .github/container/jax_nsys/python/jax_nsys - python .github/container/jax_nsys/install-protoc local_protoc - PATH=${PWD}/local_protoc/bin:$PATH .github/container/nsys-jax-combine \ + nsys-jax-combine \ --analysis summary \ --analysis communication \ -o pax_fsdp4_4proc.zip \ @@ -111,18 +108,13 @@ jobs: run: | mkdir combined/ unzip -d combined/ pax_fsdp4_4proc.zip - - name: Run the install script, but skip launching Jupyter Lab - shell: bash -x -e {0} - run: | - pip install virtualenv - NSYS_JAX_INSTALL_SKIP_LAUNCH=1 ./combined/install.sh - - name: Test the Jupyter Lab installation and execute the notebook - shell: bash -x -e {0} + - name: Execute the notebook run: | - pushd combined/ - ./nsys_jax_venv/bin/python -m jupyterlab --version + NOTEBOOK=$(python -c 'from importlib.resources import files; print(files("nsys_jax") / "analyses" / "Analysis.ipynb")') + # Point to the extracted nsys-jax-combine output + export NSYS_JAX_DEFAULT_PREFIX="${PWD}/combined" # Run with ipython for the sake of getting a clear error message - ./nsys_jax_venv/bin/ipython Analysis.ipynb + ipython "${NOTEBOOK}" # This input file was generated with something like # srun -n 1 --container-name=XXX --container-image=ghcr.io/nvidia/jax:pax-2024-07-06 @@ -140,35 +132,39 @@ jobs: steps: - name: Check out the repository under ${GITHUB_WORKSPACE} uses: actions/checkout@v4 - - name: Mock up the structure of an extracted .zip file - shell: bash -x -e {0} + - name: Extract the post-processed profile data from a real .zip file (no .nsys-rep) run: | # Get the actual test data from a real archive, minus the .nsys-rep file - unzip -d .github/container/jax_nsys/ .github/workflows/nsys-jax/test_data/pax_fsdp4_1proc.zip + mkdir profile_data/ + unzip -d profile_data/ .github/workflows/nsys-jax/test_data/pax_fsdp4_1proc.zip - name: "Setup Python 3.10" uses: actions/setup-python@v5 with: python-version: '3.10' - - name: Run the install script, but skip launching Jupyter Lab - shell: bash -x -e {0} + - name: Install nsys-jax and dependencies run: | - pip install virtualenv - NSYS_JAX_INSTALL_SKIP_LAUNCH=1 ./.github/container/jax_nsys/install.sh - - name: Test the Jupyter Lab installation and execute the notebook - shell: bash -x -e {0} + # Do *not* use an editable install (covered above) for better coverage + pip install .github/container/nsys_jax[jupyter] + # TODO: this could install into the pip prefix as a default + install-flamegraph local/ + install-protoc local/ + echo "$PWD/local/bin" >> "${GITHUB_PATH}" + - name: Execute the notebook + id: exec run: | - pushd .github/container/jax_nsys - ./nsys_jax_venv/bin/python -m jupyterlab --version + NOTEBOOK=$(python -c 'from importlib.resources import files; print(files("nsys_jax") / "analyses" / "Analysis.ipynb")') + # Point to the extracted profile data + export NSYS_JAX_DEFAULT_PREFIX="${PWD}/profile_data" # Run with ipython for the sake of getting a clear error message - ./nsys_jax_venv/bin/ipython Analysis.ipynb + ipython "${NOTEBOOK}" + echo "NOTEBOOK=${NOTEBOOK}" >> $GITHUB_OUTPUT - name: Render the notebook id: render - shell: bash -x -e {0} run: | - pushd .github/container/jax_nsys workdir=$(mktemp -d) - ./nsys_jax_venv/bin/jupyter nbconvert --execute --inplace Analysis.ipynb - cp *.ipynb *.svg "${workdir}" + export NSYS_JAX_DEFAULT_PREFIX="${PWD}/profile_data" + jupyter nbconvert --execute --inplace '${{ steps.exec.outputs.NOTEBOOK }}' + cp '${{ steps.exec.outputs.NOTEBOOK }}' *.svg "${workdir}" echo "WORKDIR=${workdir}" >> $GITHUB_OUTPUT - name: Upload rendered notebook to Gist id: upload @@ -270,7 +266,6 @@ jobs: - name: "Install ruff" run: pip install ruff - name: "Run ruff checks" - shell: bash -x {0} run: | ruff check ${NSYS_JAX_PYTHON_FILES} check_status=$? diff --git a/.github/workflows/nsys-jax/prepare-test-data b/.github/workflows/nsys-jax/prepare-test-data index bb548a180..793cee43c 100755 --- a/.github/workflows/nsys-jax/prepare-test-data +++ b/.github/workflows/nsys-jax/prepare-test-data @@ -13,8 +13,8 @@ parser = argparse.ArgumentParser( parser.add_argument("input", help="Input archive") parser.add_argument("output", help="Output archive") args = parser.parse_args() -jax_nsys = pathlib.Path(__file__).parent.parent.parent / "container" / "jax_nsys" -assert jax_nsys.is_dir(), "Could not find .github/container/jax_nsys" +nsys_jax = pathlib.Path(__file__).parent.parent.parent / "container" / "nsys_jax" +assert nsys_jax.is_dir(), "Could not find .github/container/nsys_jax" with zipfile.ZipFile(args.input) as ifile, zipfile.ZipFile( args.output, "w", @@ -28,7 +28,7 @@ with zipfile.ZipFile(args.input) as ifile, zipfile.ZipFile( continue # Don't include any of the source code that is copied into the archive via # nsys-jax and Dockerfile.base; the CI pipeline uses the HEAD version of it - path = jax_nsys / member.filename + path = nsys_jax / member.filename if path.is_file(): continue # Don't include the output of any --nsys-jax-analysis scripts that were run diff --git a/docs/nsys-jax.md b/docs/nsys-jax.md new file mode 100644 index 000000000..70161d859 --- /dev/null +++ b/docs/nsys-jax.md @@ -0,0 +1,124 @@ +# `nsys-jax` wrapper for Nsight Systems +`nsys-jax` refers to a small ecosystem of Python-based tools for collecting and analysing Nsight Systems profiles of +JAX programs. +There are two command-line tools: +- `nsys-jax`, which wraps `nsys profile` and bundles the resulting profile data with additional JAX/XLA-specific + metadata that allows for richer programmatic analysis of the profile data. +- `nsys-jax-combine`, which combines multiple `nsys-jax` output files (for example, collected from different processes + in the same multi-process/multi-node distributed JAX program) into a single output file, with de-duplication and + consistency checks. + +Behind the scenes, there is a small Python library (`nsys_jax`) for loading and analysing the output of `nsys-jax` and +`nsys-jax-combine`, which allows the use of standard data science packages like `numpy`, `pandas` and `matplotlib` to +explore profile data. + +There are three convenient ways of running profile data analyses: +- `nsys-jax ... --nsys-jax-analysis ANALYSIS ... program.py`: after profile data have been collected, `ANALYSIS` will + immediately be executed; results will be printed to the terminal and output data files will be embedded in the output + archive (*i.e.* execution on the compute node immediately after execution) +- `nsys-jax-combine ... --analysis ANALYSIS ...`: after multiple `nsys-jax` outputs have been combined, `ANALYSIS` will + be executed on the merged output; results will be printed to the terminal and output data files will be embedded in + the output archive (*e.g.* execution somewhere inside the compute cluster that has access to all processes' outputs) +- Manual execution; the output files of both `nsys-jax` and `nsys-jax-combine` include an installation script that sets + up a local Python virtual environment including the `nsys_jax` library, Jupyter Lab, and an example notebook for + Jupyter-based exploration of profile data (*e.g.* run this on your laptop and explore your profile data in a Jupyter + notebook). This installation script depends on the `virtualenv` command. + +## Installation +The containers published from this repository (`ghcr.io/nvidia/jax:XXX`) have `nsys-jax` pre-installed; the recipes for +building these are public and can be used as a point of reference if you want to install `nsys-jax` in your own +containers. + +### Manual installation +The main installation step is simply +```console +$ pip install git+https://github.com/NVIDIA/JAX-Toolbox.git#subdirectory=.github/container/nsys_jax +``` +or, for an editable install from a specific branch/tag/commit: +```console +$ pip install --src /checkout-dir -e 'git+https://github.com/NVIDIA/JAX-Toolbox.git@main#subdirectory=.github/container/nsys_jax&egg=nsys-jax' +``` +You may want to include this in a global `pip-compile`-based dependency resolution (as is done in the containers built +from this repository), rather than running too many ad-hoc `pip install` commands. + +This will install all of the components mentioned so far, but does not currently include the following implicit +dependencies: +- `protoc` must be installed at a version compatible with the `google.protobuf` runtime library; `nsys-jax` includes a + helper script that can be run after `pip install`, e.g. to install `/usr/local/bin/protoc`, run + `install-protoc /usr/local`. +- https://github.com/brendangregg/FlameGraph/blob/master/flamegraph.pl must be installed and executable if you want to + generate flame graph visualisations; `nsys-jax` includes a helper script here too, e.g. `install-flamegraph /usr/local`. +- Nsight Systems's multi-report analysis system is used by `nsys-jax` internally and has some additional dependencies + that are not bundled in the Nsight Systems installation + ([doc](https://docs.nvidia.com/nsight-systems/InstallationGuide/index.html#installing-multi-report-analysis-system)), + these are listed in `/target-linux-x64/python/packages/nsys_recipe/requirements/common.txt` and can + be installed with `pip install -r /path/to/common.txt` or by including it in your global `pip-compile`-based + dependency resolution. +- To interpret metadata dumped from XLA, `nsys-jax` needs `.proto` files from XLA that are not included in the JAX + installation. If the relevant XLA source tree is not checked out at `/opt/xla`, the environment variable + `SRC_PATH_XLA` should be set to point to it. +- A small patch to some Python files included in the installations of Nsight Systems versions 2024.5 and 2024.6 is + needed for compatibility with `nsys-jax`, this can be applied by running the `nsys-jax-patch-nsys` command and will + not be required in other versions of Nsight Systems. + +Only `protoc` is always needed, `flamegraph.pl` is an optional dependency, and the remaining dependencies are only +required when actually collecting profile data with the `nsys-jax` command, but not when merging collected profile data +with `nsys-jax-combine` or running local analyses of profile data. + +## Collecting profile data + +The `nsys-jax` command loosely corresponds to `nsys profile`, as introduced in +[the generic profiling documentation](./profiling.md). +Simply run `nsys-jax python my_program.py`. +If you want to pass additional options to `nsys profile`, the syntax is +`nsys-jax [nsys profile options] -- python my_program.py`; the `--` is compulsory. + +`nsys-jax` collects additional JAX/XLA metadata from the program being profiled and automatically performs some +post-processing of the profile data to faciliate programmatic analysis. + +It is usually a good idea to set the profile names to something meaningful using the `--output` (`-o`) option. +The syntax supported by `nsys-jax` is slightly more restricted than what `nsys` supports; only `%q{ENV_VAR}` expansions +are supported. +An example when using the Slurm job orchestrator is: +`nsys-jax -o /out/job%q{SLURM_JOB_ID}/step%q{SLURM_STEP_ID}/rank%q{SLURM_PROCID} -- python my_program.py` +which will result in an output archive `/out/job42/step7/rank0.zip` that contains `rank0.nsys-rep` and other metadata. + +As well as running `nsys profile`, this automatically sets some configuration variables mentioned above, such as +`JAX_TRACEBACK_IN_LOCATIONS_LIMIT`, and sets XLA flags requesting that metadata be saved in Protobuf format. + +> **Important**: because `nsys-jax` manipulates the `XLA_FLAGS` environment variable, you must make sure that this is +> not overwritten inside the executable that you pass. For example `nsys-jax python my_program.py` is fine, but +> `nsys-jax my_script_to_overwrite_xla_flags_and_run_my_program.sh` may not be. + +The only XLA flag that `nsys-jax` will **overwrite** is `--xla_dump_to`, which sets the output directory for the +Protobuf metadata. `nsys-jax` additionally changes the default value of `--xla_dump_hlo_as_proto` (`true`), but will +not modify this if it has been set explicitly. + +> **Note**: because the Protobuf metadata is written at compilation time, using the JAX persistent compilation cache +> prevents it from being written reliably. Because of this `nsys-jax` sets `JAX_ENABLE_COMPILATION_CACHE` to `false` if +> it is not explicitly set. + +After collecting the Nsight Systems profile, `nsys-jax` triggers two extra processing steps: +- the `.nsys-rep` file is converted into a `.parquet` and a `.csv.xz` file for offline analysis +- the metadata dumped by XLA is scanned for references to Python source code files -- i.e. your JAX program and the + Python libraries on which it depends. Those files are copied to the output archive. + +Finally, a compressed `.zip` archive is generated. The post-processing uses a local, temporary directory. Only the +final archive is written to the given output location, which is likely to be on slower, shared storage. + +## Offline analysis +Copy an `nsys-jax` archive to an interactive system, and extract it. At the top level, there is an `install.sh` script +that will create a Python virtual environment containing Jupyter Lab and the dependencies of the `Analysis.ipynb` +notebook that is also distributed in the archive. Run this and the suggested launch command for Jupyter Lab. + +The included notebook is intended to be a template for programmatic analysis of the profile data in conjunction with +the metadata from XLA. Out of the box it will provide some basic summaries and visualisations: +![Analysis notebook inside Jupyter Lab showing an interactive flame graph of JAX source code](./img/jupyter-flamegraph.png) + +Examples include summaries of compilation time, heap memory usage, and straggler analysis of multi-GPU jobs. + +You can see a rendered example of this notebook, as generated from the `main` branch of this repository, here: +https://gist.github.com/nvjax/e2cd3520201caab6b67385ed36fad3c1#file-analysis-ipynb. + +> **Note**: this code should be considered unstable, the bundled notebook and its input data format may change +> considerably, but it should provide a useful playground in which to experiment with your own profile data. diff --git a/docs/profiling.md b/docs/profiling.md index b375cee78..b6f7c7eb3 100644 --- a/docs/profiling.md +++ b/docs/profiling.md @@ -193,48 +193,7 @@ Loosely this corresponds to `nsys profile` above, i.e. simply run `nsys-jax pyth additional options to `nsys profile`, the syntax is `nsys-jax [nsys profile options] -- python my_program.py`; the `--` is compulsory. -It is usually a good idea to set the profile names to something meaningful using `nsys profile`'s `--output=..` option. -`nsys-jax` will read the value of this option and save extra metadata under the same prefix, with the restriction that -only `%q{ENV_VAR}` expansions are supported. An example when using the Slurm job orchestrator is: -`nsys-jax -o /out/job%q{SLURM_JOB_ID}/step%q{SLURM_STEP_ID}/rank%q{SLURM_PROCID} -- python my_program.py` -which will result in an output archive `/out/job42/step7/rank0.zip` that contains `rank0.nsys-rep` and other metadata. +`nsys-jax` collects additional JAX/XLA metadata from the program being profiled and automatically performs some +post-processing of the profile data to faciliate programmatic analysis. -As well as running `nsys profile`, this automatically sets some configuration variables mentioned above, such as -`JAX_TRACEBACK_IN_LOCATIONS_LIMIT`, and sets XLA flags requesting that metadata be saved in Protobuf format. - -> **Important**: because `nsys-jax` manipulates the `XLA_FLAGS` environment variable, you must make sure that this is -> not overwritten inside the executable that you pass. For example `nsys-jax python my_program.py` is fine, but -> `nsys-jax my_script_to_overwrite_xla_flags_and_run_my_program.sh` may not be. - -The only XLA flag that `nsys-jax` will **overwrite** is `--xla_dump_to`, which sets the output directory for the -Protobuf metadata. `nsys-jax` additionally changes the default value of `--xla_dump_hlo_as_proto` (`true`), but will -not modify this if it has been set explicitly. - -> **Note**: because the Protobuf metadata is written at compilation time, using the JAX persistent compilation cache -> prevents it from being written reliably. Because of this `nsys-jax` sets `JAX_ENABLE_COMPILATION_CACHE` to `false` if -> it is not explicitly set. - -After collecting the Nsight Systems profile, `nsys-jax` triggers two extra processing steps: -- the `.nsys-rep` file is converted into a `.parquet` and a `.csv.xz` file for offline analysis -- the metadata dumped by XLA is scanned for references to Python source code files -- i.e. your JAX program and the - Python libraries on which it depends. Those files are copied to the output archive. - -Finally, a compressed `.zip` archive is generated. The post-processing uses a local, temporary directory. Only the -final archive is written to the given output location, which is likely to be on slower, shared storage. - -### Offline analysis -Copy an `nsys-jax` archive to an interactive system, and extract it. At the top level, there is an `install.sh` script -that will create a Python virtual environment containing Jupyter Lab and the dependencies of the `Analysis.ipynb` -notebook that is also distributed in the archive. Run this and the suggested launch command for Jupyter Lab. - -The included notebook is intended to be a template for programmatic analysis of the profile data in conjunction with -the metadata from XLA. Out of the box it will provide some basic summaries and visualisations: -![Analysis notebook inside Jupyter Lab showing an interactive flame graph of JAX source code](./img/jupyter-flamegraph.png) - -Examples include summaries of compilation time, heap memory usage, and straggler analysis of multi-GPU jobs. - -You can see a rendered example of this notebook, as generated from the `main` branch of this repository, here: -https://gist.github.com/nvjax/e2cd3520201caab6b67385ed36fad3c1#file-analysis-ipynb. - -> **Note**: this code should be considered unstable, the bundled notebook and its input data format may change -> considerably, but it should provide a useful playground in which to experiment with your own profile data. +It is documented in more detail [on this page](./nsys-jax.md).