Skip to content

Commit

Permalink
jax-toolbox-triage: minor usability/doc improvements (#1125)
Browse files Browse the repository at this point in the history
- Print the stdout/stderr of the first execution of the test case, which
is supposed to fail, at INFO level along with a message encouraging the
user to check that it is the correct failure.
- Print the path to the DEBUG log file at INFO level and, therefore, to
the console.
- Expand the documentation.
- Add `--passing-container` and `--failing-container` arguments, which
allow the container-level search to be skipped and non-dated containers
to be triaged.
  • Loading branch information
olupton authored Oct 29, 2024
1 parent bde47a4 commit 0e4e245
Show file tree
Hide file tree
Showing 6 changed files with 226 additions and 33 deletions.
49 changes: 46 additions & 3 deletions .github/triage/jax_toolbox_triage/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import tempfile


def parse_args():
def parse_args(args=None):
parser = argparse.ArgumentParser(
description="""
Triage failures in JAX/XLA-related tests. The expectation is that the given
Expand Down Expand Up @@ -37,7 +37,6 @@ def parse_args():
help="""
Container to use. Example: jax, pax, triton. Used to construct the URLs of
nightly containers, like ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD.""",
required=True,
)
parser.add_argument(
"--output-prefix",
Expand Down Expand Up @@ -67,6 +66,15 @@ def parse_args():
Command to execute inside the container. This should be as targeted as
possible.""",
)
container_search_args.add_argument(
"--failing-container",
help="""
Skip the container-level search and pass this container to the commit-level
search. If this is passed, --passing-container must be too, but --container
is not required. This can be used to apply the commit-level bisection
search to containers not from the ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD
series, although they must have a similar structure.""",
)
container_search_args.add_argument(
"--end-date",
help="""
Expand All @@ -76,6 +84,15 @@ def parse_args():
test case fails on this date.""",
type=lambda s: datetime.date.fromisoformat(s),
)
container_search_args.add_argument(
"--passing-container",
help="""
Skip the container-level search and pass this container to the commit-level
search. If this is passed, --failing-container must be too, but --container is
not required. This can be used to apply the commit-level bisection search
to containers not from the ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD series,
although they must have a similar structure.""",
)
container_search_args.add_argument(
"--start-date",
help="""
Expand Down Expand Up @@ -109,4 +126,30 @@ def parse_args():
significantly speed up the commit-level search. By default, uses a temporary
directory including the name of the current user.""",
)
return parser.parse_args()
args = parser.parse_args(args=args)
num_explicit_containers = (args.passing_container is not None) + (
args.failing_container is not None
)
if num_explicit_containers == 1:
raise Exception(
"--passing-container and --failing-container must both be passed if either is"
)
if num_explicit_containers == 2:
# Explicit mode: --container, --start-date and --end-date are rejected (not silently ignored)
if args.container:
raise Exception(
"--container must not be passed if --passing-container and --failing-container are"
)
if args.start_date:
raise Exception(
"--start-date must not be passed if --passing-container and --failing-container are"
)
if args.end_date:
raise Exception(
"--end-date must not be passed if --passing-container and --failing-container are"
)
elif num_explicit_containers == 0 and args.container is None:
raise Exception(
"--container must be passed if --passing-container and --failing-container are not"
)
return args
30 changes: 26 additions & 4 deletions .github/triage/jax_toolbox_triage/logic.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,22 @@
from dataclasses import dataclass
import datetime
import functools
import logging
import typing


@dataclass
class TestResult:
    """
    Outcome of one test execution: a pass/fail flag plus any captured output
    """

    # Not a test case, despite the Test* name: keep pytest from collecting it
    __test__ = False
    # True when the test passed, False when it failed
    result: bool
    # Captured standard output of the test, if it was recorded
    stdout: typing.Optional[str] = None
    # Captured standard error of the test, if it was recorded
    stderr: typing.Optional[str] = None


def as_datetime(date: datetime.date) -> datetime.datetime:
    """Return the naive datetime at midnight on the given calendar date."""
    midnight = datetime.time()
    return datetime.datetime.combine(date, midnight)

Expand Down Expand Up @@ -59,7 +72,7 @@ def adjust_date(
def container_search(
*,
container_exists: typing.Callable[[datetime.date], bool],
container_passes: typing.Callable[[datetime.date], bool],
container_passes: typing.Callable[[datetime.date], TestResult],
start_date: typing.Optional[datetime.date],
end_date: typing.Optional[datetime.date],
logger: logging.Logger,
Expand Down Expand Up @@ -88,8 +101,17 @@ def container_search(
logger.info(f"Skipping check for end-of-range failure in {end_date}")
else:
logger.info(f"Checking end-of-range failure in {end_date}")
if container_passes(end_date):
test_end_date = container_passes(end_date)
logger.info(f"stdout: {test_end_date.stdout}")
logger.info(f"stderr: {test_end_date.stderr}")
if test_end_date.result:
raise Exception(f"Could not reproduce failure in {end_date}")
logger.info(
"IMPORTANT: you should check that the test output above shows the "
f"*expected* failure of your test case in the {end_date} container. It is "
"very easy to accidentally provide a test case that fails for the wrong "
"reason, which will not triage the correct issue!"
)

# Start the coarse, container-level, search for a starting point to the bisection range
earliest_failure = end_date
Expand Down Expand Up @@ -127,7 +149,7 @@ def container_search(
logger.info(f"Skipping check that the test passes on start_date={start_date}")
else:
# While condition prints an info message
while not container_passes(search_date):
while not container_passes(search_date).result:
# Test failed on `search_date`, go further into the past
earliest_failure = search_date
new_search_date = adjust(
Expand Down Expand Up @@ -155,7 +177,7 @@ def container_search(
if range_mid is None:
# It wasn't possible to refine further.
break
result = container_passes(range_mid)
result = container_passes(range_mid).result
if result:
range_start = range_mid
else:
Expand Down
48 changes: 30 additions & 18 deletions .github/triage/jax_toolbox_triage/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from .args import parse_args
from .docker import DockerContainer
from .logic import commit_search, container_search
from .logic import commit_search, container_search, TestResult
from .utils import (
container_exists as container_exists_base,
container_url as container_url_base,
Expand All @@ -21,6 +21,10 @@ def main():
args = parse_args()
bazel_cache_mounts = prepare_bazel_cache_mounts(args.bazel_cache)
logger = get_logger(args.output_prefix)
logger.info(
"Verbose output, including stdout/err of triage commands, will be written to "
f'{(args.output_prefix / "debug.log").resolve()}'
)
container_url = functools.partial(container_url_base, container=args.container)
container_exists = functools.partial(
container_exists_base, container=args.container, logger=logger
Expand Down Expand Up @@ -75,7 +79,7 @@ def get_commit(container: DockerContainer, repo: str) -> typing.Tuple[str, str]:
f"Could not extract commit of {repo} from {args.container} container {container}"
)

def check_container(date: datetime.date) -> bool:
def check_container(date: datetime.date) -> TestResult:
"""
See if the test passes in the given container.
"""
Expand All @@ -100,37 +104,45 @@ def check_container(date: datetime.date) -> bool:
"xla": xla_commit,
},
)
return test_pass

# Search through the published containers, narrowing down to a pair of dates with
# the property that the test passed on `range_start` and fails on `range_end`.
range_start, range_end = container_search(
container_exists=container_exists,
container_passes=check_container,
start_date=args.start_date,
end_date=args.end_date,
logger=logger,
skip_precondition_checks=args.skip_precondition_checks,
threshold_days=args.threshold_days,
)
return TestResult(result=test_pass, stdout=result.stdout, stderr=result.stderr)

if args.passing_container is not None:
assert args.failing_container is not None
# Skip the container-level search because explicit end points were given
passing_url = args.passing_container
failing_url = args.failing_container
else:
# Search through the published containers, narrowing down to a pair of dates with
# the property that the test passed on `range_start` and fails on `range_end`.
range_start, range_end = container_search(
container_exists=container_exists,
container_passes=check_container,
start_date=args.start_date,
end_date=args.end_date,
logger=logger,
skip_precondition_checks=args.skip_precondition_checks,
threshold_days=args.threshold_days,
)
passing_url = container_url(range_start)
failing_url = container_url(range_end)

# Container-level search is now complete (or was skipped because explicit
# containers were given). Triage proceeds inside the failing (`range_end`)
# container. First, we check that rewinding JAX and XLA inside the failing
# container to the commits used in the passing (`range_start`) container passes,
# whereas using the failing container's own commits reproduces the failure.

with Container(container_url(range_start)) as worker:
with Container(passing_url) as worker:
start_jax_commit, _ = get_commit(worker, "jax")
start_xla_commit, _ = get_commit(worker, "xla")

# Fire up the container that will be used for the fine search.
with Container(container_url(range_end)) as worker:
with Container(failing_url) as worker:
end_jax_commit, jax_dir = get_commit(worker, "jax")
end_xla_commit, xla_dir = get_commit(worker, "xla")
logger.info(
(
f"Bisecting JAX [{start_jax_commit}, {end_jax_commit}] and "
f"XLA [{start_xla_commit}, {end_xla_commit}] using {container_url(range_end)}"
f"XLA [{start_xla_commit}, {end_xla_commit}] using {failing_url}"
)
)

Expand Down
62 changes: 62 additions & 0 deletions .github/triage/tests/test_arg_parsing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import pytest
from jax_toolbox_triage.args import parse_args

# Positional test command that every test below appends to its argument list
test_command = ["my-test-command"]
# Explicit-container mode: both end points of the commit-level search are given
valid_start_end_container = [
"--passing-container",
"passing-url",
"--failing-container",
"failing-url",
]
# Date-based mode: --container alone, or with either/both of the optional dates
valid_start_end_date_args = [
["--container", "jax"],
["--container", "jax", "--start-date", "2024-10-02"],
["--container", "jax", "--end-date", "2024-10-02"],
["--container", "jax", "--start-date", "2024-10-01", "--end-date", "2024-10-02"],
]


@pytest.mark.parametrize(
    "good_args", [valid_start_end_container, *valid_start_end_date_args]
)
def test_good_container_args(good_args):
    """Each valid container/date combination parses and keeps the test command."""
    parsed = parse_args([*good_args, *test_command])
    assert parsed.test_command == test_command


@pytest.mark.parametrize("date_args", valid_start_end_date_args)
def test_bad_container_arg_combinations_across_groups(date_args):
    """Explicit containers must not be mixed with the date-based search options."""
    # --{passing,failing}-container exclude --container/--{start,end}-date
    combined = [*valid_start_end_container, *date_args, *test_command]
    with pytest.raises(Exception):
        parse_args(combined)


@pytest.mark.parametrize(
    "container_args",
    [
        # Date-based mode is incomplete without --container
        [],
        ["--start-date", "2024-10-01"],
        ["--end-date", "2024-10-02"],
        ["--start-date", "2024-10-01", "--end-date", "2024-10-02"],
        # Explicit mode needs both end points, never just one of them
        ["--passing-container", "passing-url"],
        ["--failing-container", "failing-url"],
    ],
)
def test_bad_container_arg_combinations_within_groups(container_args):
    """Incomplete argument sets within a single mode are rejected."""
    incomplete = [*container_args, *test_command]
    with pytest.raises(Exception):
        parse_args(incomplete)


@pytest.mark.parametrize(
    "container_args",
    [
        # Dates must be valid ISO-format strings
        ["--container", "jax", "--start-date", "a-blue-moon-ago"],
        ["--container", "jax", "--end-date", "a-year-ago-last-thursday"],
    ],
)
def test_unparsable_container_args(container_args):
    """Malformed date values make argument parsing exit (SystemExit)."""
    malformed = [*container_args, *test_command]
    with pytest.raises(SystemExit):
        parse_args(malformed)
8 changes: 4 additions & 4 deletions .github/triage/tests/test_triage_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import pytest
import random
from jax_toolbox_triage.logic import commit_search, container_search
from jax_toolbox_triage.logic import commit_search, container_search, TestResult


def wrap(b):
Expand Down Expand Up @@ -306,7 +306,7 @@ def test_container_search_limits(
with pytest.raises(Exception, match=match_string):
container_search(
container_exists=lambda dt: dt in dates_that_exist,
container_passes=lambda dt: False,
container_passes=lambda dt: TestResult(result=False),
start_date=start_date,
end_date=end_date,
logger=logger,
Expand Down Expand Up @@ -353,7 +353,7 @@ def test_container_search_checks(
with pytest.raises(Exception, match=match_string):
container_search(
container_exists=lambda dt: True,
container_passes=lambda dt: dt in dates_that_pass,
container_passes=lambda dt: TestResult(result=dt in dates_that_pass),
start_date=start_date,
end_date=end_date,
logger=logger,
Expand All @@ -374,7 +374,7 @@ def test_container_search(logger, start_date, days_of_failure, threshold_days):
assert start_date is None or threshold_date >= start_date
good_date, bad_date = container_search(
container_exists=lambda dt: True,
container_passes=lambda dt: dt < threshold_date,
container_passes=lambda dt: TestResult(result=dt < threshold_date),
start_date=start_date,
end_date=end_date,
logger=logger,
Expand Down
Loading

0 comments on commit 0e4e245

Please sign in to comment.