Skip to content

Commit

Permalink
nsys-jax: re-work to be more pip-install-able (#1165)
Browse files Browse the repository at this point in the history
The overarching goal of this PR is to get closer to a world where the
`nsys-jax` tooling is straightforwardly `pip install`-able. While the
diff looks scary, it's mostly re-organisation.

Substantive changes:
- `nsys-jax` no longer bundles Python code in the output archives, the
`install.sh` script provided for users to run on local machines becomes,
loosely, `install 'pip nsys-jax[jupyter] @
git+https://github.com/NVIDIA/JAX-Toolbox.git@COMMIT#subdirectory=.github/container/nsys_jax'`,
where `COMMIT` corresponds to the `nsys-jax` command that produced the
archive. For the `ghcr.io/nvidia/jax` containers, this is the commit of
JAX-Toolbox that triggered the container build.

Changes included:
- Introduce `/opt/pip-tools-post-install.d`, which `pip-finalize.sh`
will execute the contents of *after* installing the `pip`-managed world
- Migrate `install-protoc` to use this, so `pip-finalize.sh` can forget
about that detail.
- Install
https://github.com/brendangregg/FlameGraph/blob/master/flamegraph.pl via
this.
- Patch the `nvtx_gpu_proj_trace` Python code in Nsight Systems 2024.5
and 2024.6 via this.
- Move `nsys-jax` installation (specifically for the containers) into
`install-nsys-jax.sh` and thereby clean up `install-nsight.sh`. The new
script has to be told the git commit hash of JAX-Toolbox that is being
built, because `nsys-jax` bakes this into an installation script in its
output `.zip` archives to ensure the local environment matches the
profile-collection environment.
- The CLI tools like `nsys-jax`, `nsys-jax-combine` and `install-protoc`
are now handled via `[project.scripts]` in `pyproject.toml` instead of
being standalone Python scripts. This is "more standard", and also makes
it easier to share code between `nsys-jax` and `nsys-jax-combine`.
- The Python library is renamed from `jax_nsys` to `nsys_jax` for
consistency.
- It's now possible to set the default data loading path via the
`NSYS_JAX_DEFAULT_PREFIX` environment variable; previously the default
was the current working directory, but that can be inconvenient to steer
in Jupyter environments.
  • Loading branch information
olupton authored Dec 3, 2024
1 parent de72dd8 commit 5c4b687
Show file tree
Hide file tree
Showing 36 changed files with 1,497 additions and 1,283 deletions.
29 changes: 9 additions & 20 deletions .github/container/Dockerfile.base
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ ARG BASE_IMAGE=nvidia/cuda:12.6.2-devel-ubuntu22.04
ARG GIT_USER_NAME="JAX Toolbox"
ARG [email protected]
ARG CLANG_VERSION=18
ARG JAX_TOOLBOX_REF

###############################################################################
## Obtain GCP's NCCL TCPx plugin
Expand Down Expand Up @@ -30,6 +31,7 @@ ARG BASE_IMAGE
ARG GIT_USER_EMAIL
ARG GIT_USER_NAME
ARG CLANG_VERSION
ARG JAX_TOOLBOX_REF
ENV CUDA_BASE_IMAGE=${BASE_IMAGE}

###############################################################################
Expand Down Expand Up @@ -110,7 +112,7 @@ RUN <<"EOF" bash -ex
git config --global user.name "${GIT_USER_NAME}"
git config --global user.email "${GIT_USER_EMAIL}"
EOF
RUN mkdir -p /opt/pip-tools.d
RUN mkdir -p /opt/pip-tools.d /opt/pip-tools-post-install.d
ADD --chmod=777 \
git-clone.sh \
pip-finalize.sh \
Expand Down Expand Up @@ -141,7 +143,6 @@ COPY --from=tcpx-installer /var/lib/tcpx/lib64 ${TCPX_LIBRARY_PATH}
###############################################################################

ADD install-nsight.sh /usr/local/bin
ADD nsys-2024.5-tid-export.patch /opt/nvidia
RUN install-nsight.sh

###############################################################################
Expand Down Expand Up @@ -183,7 +184,7 @@ ENV PATH=/opt/amazon/efa/bin:${PATH}
ADD install-nccl-sanity-check.sh /usr/local/bin
ADD nccl-sanity-check.cu /opt
RUN install-nccl-sanity-check.sh
ADD jax-nccl-test parallel-launch /usr/local/bin
ADD jax-nccl-test parallel-launch /usr/local/bin/

###############################################################################
## Add the systemcheck to the entrypoint.
Expand All @@ -199,23 +200,11 @@ COPY check-shm.sh /opt/nvidia/entrypoint.d/
# COPY gcp-autoconfig.sh /opt/nvidia/entrypoint.d/

###############################################################################
## Add helper scripts for profiling with Nsight Systems
##
## The scripts saved to /opt/jax_nsys are embedded in the output archives
## written by nsys-jax, while the nsys-jax and nsys-jax-combine scripts are
## only used inside the containers.
###############################################################################
ADD nsys-jax nsys-jax-combine /usr/local/bin/
ADD jax_nsys/ /opt/jax_nsys
# The jax_nsys package should be installed inside the containers, so nsys-jax
# can eagerly execute analysis recipes (--nsys-jax-analysis) in the container
# environment, without an extra layer of virtual environment indirection.
RUN echo "-e /opt/jax_nsys/python/jax_nsys" > /opt/pip-tools.d/requirements-nsys-jax.in
# This should be embedded in output archives and be runnable inside containers
RUN ln -s /opt/jax_nsys/install-protoc /usr/local/bin/
# Should be available for execution inside the containers, should not be
# embedded in the output archives.
ADD jax_nsys_tests/ /opt/jax_nsys_tests
## Install the nsys-jax JAX/XLA-aware profiling scripts, patch Nsight Systems
###############################################################################

ADD install-nsys-jax.sh /usr/local/bin
RUN install-nsys-jax.sh ${JAX_TOOLBOX_REF}

###############################################################################
## Copy manifest file to the container
Expand Down
11 changes: 0 additions & 11 deletions .github/container/install-nsight.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,3 @@ apt-get install -y nsight-compute nsight-systems-cli-2024.6.1
apt-get clean

rm -rf /var/lib/apt/lists/*

for NSYS in /opt/nvidia/nsight-systems-cli/2024.5.1 /opt/nvidia/nsight-systems-cli/2024.6.1; do
if [[ -d "${NSYS}" ]]; then
# * can match at least sbsa-armv8 and x86
(cd ${NSYS}/target-linux-*/python/packages && git apply < /opt/nvidia/nsys-2024.5-tid-export.patch)
fi
done

# Install extra dependencies needed for `nsys recipe ...` commands. These are
# used by the nsys-jax wrapper script.
ln -s $(dirname $(realpath $(command -v nsys)))/python/packages/nsys_recipe/requirements/common.txt /opt/pip-tools.d/requirements-nsys-recipe.in
32 changes: 32 additions & 0 deletions .github/container/install-nsys-jax.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#!/bin/bash
set -exo pipefail

REF="$1"
if [[ -z "${REF}" ]]; then
echo "$0: <git ref of JAX-Toolbox>"
exit 1
fi

# Install extra dependencies needed for `nsys recipe ...` commands. These are
# used by the nsys-jax wrapper script.
NSYS_DIR=$(dirname $(realpath $(command -v nsys)))
ln -s ${NSYS_DIR}/python/packages/nsys_recipe/requirements/common.txt /opt/pip-tools.d/requirements-nsys-recipe.in

# Install the nsys-jax package, which includes nsys-jax, nsys-jax-combine,
# install-protoc (called from pip-finalize.sh), and nsys-jax-patch-nsys as well as the
# nsys_jax Python library.
URL="git+https://github.com/NVIDIA/JAX-Toolbox.git@${REF}#subdirectory=.github/container/nsys_jax&egg=nsys-jax"
echo "-e '${URL}'" > /opt/pip-tools.d/requirements-nsys-jax.in

# protobuf will be installed at least as a dependency of nsys_jax in the base
# image, but the installed version is likely to be influenced by other packages.
echo "install-protoc /usr/local" > /opt/pip-tools-post-install.d/protoc
chmod 755 /opt/pip-tools-post-install.d/protoc

# Make sure flamegraph.pl is available
echo "install-flamegraph /usr/local" > /opt/pip-tools-post-install.d/flamegraph
chmod 755 /opt/pip-tools-post-install.d/flamegraph

# Make sure Nsight Systems Python patches are installed if needed
echo "nsys-jax-patch-nsys" > /opt/pip-tools-post-install.d/patch-nsys
chmod 755 /opt/pip-tools-post-install.d/patch-nsys
65 changes: 0 additions & 65 deletions .github/container/jax_nsys/install-protoc

This file was deleted.

38 changes: 0 additions & 38 deletions .github/container/jax_nsys/install.sh

This file was deleted.

17 changes: 0 additions & 17 deletions .github/container/jax_nsys/python/jax_nsys/pyproject.toml

This file was deleted.

Loading

0 comments on commit 5c4b687

Please sign in to comment.