From 0e4e2454d06cac5f7f460ce596cb6d36212eb583 Mon Sep 17 00:00:00 2001 From: Olli Lupton Date: Tue, 29 Oct 2024 17:00:17 +0100 Subject: [PATCH] jax-toolbox-triage: minor usability/doc improvements (#1125) - Print the stdout/stderr of the first execution of the test case, which is supposed to fail, at INFO level along with a message encouraging the user to check that it is the correct failure. - Print the path to the DEBUG log file at INFO level and, therefore, to the console. - Expand the documentation. - Add `--passing-container` and `--failing-container` arguments, which allow the container-level search to be skipped and non-dated containers to be triaged. --- .github/triage/jax_toolbox_triage/args.py | 49 +++++++++++++++-- .github/triage/jax_toolbox_triage/logic.py | 30 +++++++++-- .github/triage/jax_toolbox_triage/main.py | 48 ++++++++++------- .github/triage/tests/test_arg_parsing.py | 62 ++++++++++++++++++++++ .github/triage/tests/test_triage_logic.py | 8 +-- docs/triage-tool.md | 62 ++++++++++++++++++++-- 6 files changed, 226 insertions(+), 33 deletions(-) create mode 100644 .github/triage/tests/test_arg_parsing.py diff --git a/.github/triage/jax_toolbox_triage/args.py b/.github/triage/jax_toolbox_triage/args.py index d092e7200..210462a8b 100644 --- a/.github/triage/jax_toolbox_triage/args.py +++ b/.github/triage/jax_toolbox_triage/args.py @@ -6,7 +6,7 @@ import tempfile -def parse_args(): +def parse_args(args=None): parser = argparse.ArgumentParser( description=""" Triage failures in JAX/XLA-related tests. The expectation is that the given @@ -37,7 +37,6 @@ def parse_args(): help=""" Container to use. Example: jax, pax, triton. Used to construct the URLs of nightly containers, like ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD.""", - required=True, ) parser.add_argument( "--output-prefix", @@ -67,6 +66,15 @@ def parse_args(): Command to execute inside the container. 
This should be as targeted as possible.""", ) + container_search_args.add_argument( + "--failing-container", + help=""" + Skip the container-level search and pass this container to the commit-level + search. If this is passed, --passing-container must be too, but --container + is not required. This can be used to apply the commit-level bisection + search to containers not from the ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD + series, although they must have a similar structure.""", + ) container_search_args.add_argument( "--end-date", help=""" @@ -76,6 +84,15 @@ def parse_args(): test case fails on this date.""", type=lambda s: datetime.date.fromisoformat(s), ) + container_search_args.add_argument( + "--passing-container", + help=""" + Skip the container-level search and pass this container to the commit-level + search. If this is passed, --failing-container must be too, but --container is + not required. This can be used to apply the commit-level bisection search + to containers not from the ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD series, + although they must have a similar structure.""", + ) container_search_args.add_argument( "--start-date", help=""" @@ -109,4 +126,30 @@ def parse_args(): significantly speed up the commit-level search. 
By default, uses a temporary directory including the name of the current user.""", ) - return parser.parse_args() + args = parser.parse_args(args=args) + num_explicit_containers = (args.passing_container is not None) + ( + args.failing_container is not None + ) + if num_explicit_containers == 1: + raise Exception( + "--passing-container and --failing-container must both be passed if either is" + ) + if num_explicit_containers == 2: + # Explicit mode: --container, --start-date and --end-date are rejected + if args.container: + raise Exception( + "--container must not be passed if --passing-container and --failing-container are" + ) + if args.start_date: + raise Exception( + "--start-date must not be passed if --passing-container and --failing-container are" + ) + if args.end_date: + raise Exception( + "--end-date must not be passed if --passing-container and --failing-container are" + ) + elif num_explicit_containers == 0 and args.container is None: + raise Exception( + "--container must be passed if --passing-container and --failing-container are not" + ) + return args diff --git a/.github/triage/jax_toolbox_triage/logic.py b/.github/triage/jax_toolbox_triage/logic.py index 1e88f57c5..f9c8b46ae 100644 --- a/.github/triage/jax_toolbox_triage/logic.py +++ b/.github/triage/jax_toolbox_triage/logic.py @@ -1,9 +1,22 @@ +from dataclasses import dataclass import datetime import functools import logging import typing +@dataclass +class TestResult: + """ + Hold the result/stdout/stderr of a test execution + """ + + __test__ = False # stop pytest gathering this + result: bool + stdout: typing.Optional[str] = None + stderr: typing.Optional[str] = None + + def as_datetime(date: datetime.date) -> datetime.datetime: return datetime.datetime.combine(date, datetime.time()) @@ -59,7 +72,7 @@ def adjust_date( def container_search( *, container_exists: typing.Callable[[datetime.date], bool], - container_passes: typing.Callable[[datetime.date], bool], + container_passes: 
typing.Callable[[datetime.date], TestResult], start_date: typing.Optional[datetime.date], end_date: typing.Optional[datetime.date], logger: logging.Logger, @@ -88,8 +101,17 @@ def container_search( logger.info(f"Skipping check for end-of-range failure in {end_date}") else: logger.info(f"Checking end-of-range failure in {end_date}") - if container_passes(end_date): + test_end_date = container_passes(end_date) + logger.info(f"stdout: {test_end_date.stdout}") + logger.info(f"stderr: {test_end_date.stderr}") + if test_end_date.result: raise Exception(f"Could not reproduce failure in {end_date}") + logger.info( + "IMPORTANT: you should check that the test output above shows the " + f"*expected* failure of your test case in the {end_date} container. It is " + "very easy to accidentally provide a test case that fails for the wrong " + "reason, which will not triage the correct issue!" + ) # Start the coarse, container-level, search for a starting point to the bisection range earliest_failure = end_date @@ -127,7 +149,7 @@ def container_search( logger.info(f"Skipping check that the test passes on start_date={start_date}") else: # While condition prints an info message - while not container_passes(search_date): + while not container_passes(search_date).result: # Test failed on `search_date`, go further into the past earliest_failure = search_date new_search_date = adjust( @@ -155,7 +177,7 @@ def container_search( if range_mid is None: # It wasn't possible to refine further. 
break - result = container_passes(range_mid) + result = container_passes(range_mid).result if result: range_start = range_mid else: diff --git a/.github/triage/jax_toolbox_triage/main.py b/.github/triage/jax_toolbox_triage/main.py index 810109f32..925fffacd 100755 --- a/.github/triage/jax_toolbox_triage/main.py +++ b/.github/triage/jax_toolbox_triage/main.py @@ -8,7 +8,7 @@ from .args import parse_args from .docker import DockerContainer -from .logic import commit_search, container_search +from .logic import commit_search, container_search, TestResult from .utils import ( container_exists as container_exists_base, container_url as container_url_base, @@ -21,6 +21,10 @@ def main(): args = parse_args() bazel_cache_mounts = prepare_bazel_cache_mounts(args.bazel_cache) logger = get_logger(args.output_prefix) + logger.info( + "Verbose output, including stdout/err of triage commands, will be written to " + f'{(args.output_prefix / "debug.log").resolve()}' + ) container_url = functools.partial(container_url_base, container=args.container) container_exists = functools.partial( container_exists_base, container=args.container, logger=logger @@ -75,7 +79,7 @@ def get_commit(container: DockerContainer, repo: str) -> typing.Tuple[str, str]: f"Could not extract commit of {repo} from {args.container} container {container}" ) - def check_container(date: datetime.date) -> bool: + def check_container(date: datetime.date) -> TestResult: """ See if the test passes in the given container. """ @@ -100,37 +104,45 @@ def check_container(date: datetime.date) -> bool: "xla": xla_commit, }, ) - return test_pass - - # Search through the published containers, narrowing down to a pair of dates with - # the property that the test passed on `range_start` and fails on `range_end`. 
- range_start, range_end = container_search( - container_exists=container_exists, - container_passes=check_container, - start_date=args.start_date, - end_date=args.end_date, - logger=logger, - skip_precondition_checks=args.skip_precondition_checks, - threshold_days=args.threshold_days, - ) + return TestResult(result=test_pass, stdout=result.stdout, stderr=result.stderr) + + if args.passing_container is not None: + assert args.failing_container is not None + # Skip the container-level search because explicit end points were given + passing_url = args.passing_container + failing_url = args.failing_container + else: + # Search through the published containers, narrowing down to a pair of dates with + # the property that the test passed on `range_start` and fails on `range_end`. + range_start, range_end = container_search( + container_exists=container_exists, + container_passes=check_container, + start_date=args.start_date, + end_date=args.end_date, + logger=logger, + skip_precondition_checks=args.skip_precondition_checks, + threshold_days=args.threshold_days, + ) + passing_url = container_url(range_start) + failing_url = container_url(range_end) # Container-level search is now complete. Triage proceeds inside the `range_end`` # container. First, we check that rewinding JAX and XLA inside the `range_end`` # container to the commits used in the `range_start` container passes, whereas # using the `range_end` commits reproduces the failure. - with Container(container_url(range_start)) as worker: + with Container(passing_url) as worker: start_jax_commit, _ = get_commit(worker, "jax") start_xla_commit, _ = get_commit(worker, "xla") # Fire up the container that will be used for the fine search. 
- with Container(container_url(range_end)) as worker: + with Container(failing_url) as worker: end_jax_commit, jax_dir = get_commit(worker, "jax") end_xla_commit, xla_dir = get_commit(worker, "xla") logger.info( ( f"Bisecting JAX [{start_jax_commit}, {end_jax_commit}] and " - f"XLA [{start_xla_commit}, {end_xla_commit}] using {container_url(range_end)}" + f"XLA [{start_xla_commit}, {end_xla_commit}] using {failing_url}" ) ) diff --git a/.github/triage/tests/test_arg_parsing.py b/.github/triage/tests/test_arg_parsing.py new file mode 100644 index 000000000..aebef4afc --- /dev/null +++ b/.github/triage/tests/test_arg_parsing.py @@ -0,0 +1,62 @@ +import pytest +from jax_toolbox_triage.args import parse_args + +test_command = ["my-test-command"] +valid_start_end_container = [ + "--passing-container", + "passing-url", + "--failing-container", + "failing-url", +] +valid_start_end_date_args = [ + ["--container", "jax"], + ["--container", "jax", "--start-date", "2024-10-02"], + ["--container", "jax", "--end-date", "2024-10-02"], + ["--container", "jax", "--start-date", "2024-10-01", "--end-date", "2024-10-02"], +] + + +@pytest.mark.parametrize( + "good_args", [valid_start_end_container] + valid_start_end_date_args +) +def test_good_container_args(good_args): + args = parse_args(good_args + test_command) + assert args.test_command == test_command + + +@pytest.mark.parametrize("date_args", valid_start_end_date_args) +def test_bad_container_arg_combinations_across_groups(date_args): + # Can't combine --{passing,failing}-container with --container/--{start,end}-date + with pytest.raises(Exception): + parse_args(valid_start_end_container + date_args + test_command) + + +@pytest.mark.parametrize( + "container_args", + [ + # Need --container + [], + ["--start-date", "2024-10-01"], + ["--end-date", "2024-10-02"], + ["--start-date", "2024-10-01", "--end-date", "2024-10-02"], + # Need both if either is passed + ["--passing-container", "passing-url"], + ["--failing-container", 
"failing-url"], + ], +) +def test_bad_container_arg_combinations_within_groups(container_args): + with pytest.raises(Exception): + parse_args(container_args + test_command) + + +@pytest.mark.parametrize( + "container_args", + [ + # Need valid ISO dates + ["--container", "jax", "--start-date", "a-blue-moon-ago"], + ["--container", "jax", "--end-date", "a-year-ago-last-thursday"], + ], +) +def test_unparsable_container_args(container_args): + with pytest.raises(SystemExit): + parse_args(container_args + test_command) diff --git a/.github/triage/tests/test_triage_logic.py b/.github/triage/tests/test_triage_logic.py index 411740ded..81d90701d 100644 --- a/.github/triage/tests/test_triage_logic.py +++ b/.github/triage/tests/test_triage_logic.py @@ -3,7 +3,7 @@ import logging import pytest import random -from jax_toolbox_triage.logic import commit_search, container_search +from jax_toolbox_triage.logic import commit_search, container_search, TestResult def wrap(b): @@ -306,7 +306,7 @@ def test_container_search_limits( with pytest.raises(Exception, match=match_string): container_search( container_exists=lambda dt: dt in dates_that_exist, - container_passes=lambda dt: False, + container_passes=lambda dt: TestResult(result=False), start_date=start_date, end_date=end_date, logger=logger, @@ -353,7 +353,7 @@ def test_container_search_checks( with pytest.raises(Exception, match=match_string): container_search( container_exists=lambda dt: True, - container_passes=lambda dt: dt in dates_that_pass, + container_passes=lambda dt: TestResult(result=dt in dates_that_pass), start_date=start_date, end_date=end_date, logger=logger, @@ -374,7 +374,7 @@ def test_container_search(logger, start_date, days_of_failure, threshold_days): assert start_date is None or threshold_date >= start_date good_date, bad_date = container_search( container_exists=lambda dt: True, - container_passes=lambda dt: dt < threshold_date, + container_passes=lambda dt: TestResult(result=dt < threshold_date), 
start_date=start_date, end_date=end_date, logger=logger, diff --git a/docs/triage-tool.md b/docs/triage-tool.md index 71cd35a33..467044c7e 100644 --- a/docs/triage-tool.md +++ b/docs/triage-tool.md @@ -20,6 +20,14 @@ The tool follows a three-step process: failing, and a reference commit of XLA (JAX) that can be used to reproduce the regression. +The third step can also be used on its own, via the `--passing-container` and +`--failing-container` options, which allows it to be used between private container +tags, without the dependency on the `ghcr.io/nvidia/jax` registry. This assumes that +the given containers are closely related to those from JAX-Toolbox +(`ghcr.io/nvidia/jax:XXX`): +* JAX and XLA sources at `/opt/{jax,xla}[-source]` +* `build-jax.sh` script from JAX-Toolbox available in the container + ## Installation The triage tool can be installed using `pip`: @@ -27,26 +35,50 @@ The triage tool can be installed using `pip`: pip install git+https://github.com/NVIDIA/JAX-Toolbox.git#subdirectory=.github/triage ``` or directly from a checkout of the JAX-Toolbox repository. + +You should make sure `pip` is up to date, for example with `pip install -U pip`. The +versions of `pip` installed on cluster head/compute nodes can be quite old. The +recommended installation method, using `virtualenv`, should take care of this for you. + Because the tool needs to orchestrate running commands in multiple containers, it is most convenient to install it in a virtual environment on the host system, rather than attempting to install it inside a container. +The recommended installation method is to install `virtualenv` natively on the host +system, and then use that to create an isolated environment on the host system for the +triage tool, *i.e.*: +```bash +virtualenv triage-venv +./triage-venv/bin/pip install git+https://github.com/NVIDIA/JAX-Toolbox.git#subdirectory=.github/triage +./triage-venv/bin/jax-toolbox-triage ... 
+``` + The tool should be invoked on a machine with `docker` available and whatever GPUs are needed to execute the test case. ## Usage -To use the tool, there are two compulsory arguments: - * `--container`: which of the `ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD` container - families to execute the test command in. Example: `jax` for a JAX unit test - failure, `maxtext` for a MaxText model execution failure +To use the tool, there are two compulsory inputs: * A test command to triage. + * A specification of which containers to triage in. There are two choices here: + * `--container`: which of the `ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD` container + families to execute the test command in. Example: `jax` for a JAX unit test + failure, `maxtext` for a MaxText model execution failure. The `--start-date` and + `--end-date` options can be combined with `--container` to tune the search; see + below for more details. + * `--passing-container` and `--failing-container`: a pair of URLs to containers to + use in the commit-level search; if these are passed then no container-level + search is performed. The test command will be executed directly in the container, not inside a shell, so be sure not to add excessive quotation marks (*i.e.* run `jax-toolbox-triage --container=jax test-jax.sh foo` not `jax-toolbox-triage --container=jax "test-jax.sh foo"`), and you should aim to make it as fast and targeted as possible. + +If you want to run multiple commands, you might want to use something like +`jax-toolbox-triage --container=jax sh -c "command1 && command2"`. + The expectation is that the test case will be executed successfully several times as part of the triage, so you may want to tune some parameters to reduce the execution time in the successful case. @@ -55,6 +87,28 @@ probably reduce `--steps` to optimise execution time in the successful case. A JSON status file and both info-level and debug-level logfiles are written to the directory given by `--output-prefix`. 
+Info-level output is also written to the console, and includes the path to the debug +log file. + +You should pay attention to the first execution of your test case, to make sure it is +failing for the correct reason. For example: +```console +$ jax-toolbox-triage --container jax command-you-forgot-to-install +``` +will not immediately abort, because the tool is **expecting** the command to fail in +the early stages of the triage: +``` +[INFO] 2024-10-29 01:49:01 Verbose output, including stdout/err of triage commands, will be written to /home/olupton/JAX-Toolbox/triage-2024-10-29-01-49-01/debug.log +[INFO] 2024-10-29 01:49:05 Checking end-of-range failure in 2024-10-27 +[INFO] 2024-10-29 01:49:05 Ran test case in 2024-10-27 in 0.4s, pass=False +[INFO] 2024-10-29 01:49:05 stdout: OCI runtime exec failed: exec failed: unable to start container process: exec: "command-you-forgot-to-install": executable file not found in $PATH: unknown + +[INFO] 2024-10-29 01:49:05 stderr: +[INFO] 2024-10-29 01:49:05 IMPORTANT: you should check that the test output above shows the *expected* failure of your test case in the 2024-10-27 container. It is very easy to accidentally provide a test case that fails for the wrong reason, which will not triage the correct issue! +[INFO] 2024-10-29 01:49:06 Starting coarse search with 2024-10-26 based on end_date=2024-10-27 +[INFO] 2024-10-29 01:49:06 Ran test case in 2024-10-26 in 0.4s, pass=False +``` +where, notably, the triage search is continuing. ### Optimising container-level search performance