Skip to content

Commit

Permalink
Switch to a new thread-safe utility for catching warnings.
Browse files Browse the repository at this point in the history
The Python warnings.catch_warnings() functionality is not thread-safe (https://py-free-threading.github.io/porting/#the-warnings-module-is-not-thread-safe), so we cannot use it during tests that use free-threading. This change introduces a private warnings test helper (test_warning_util.py), which hooks the CPython warning infrastructure and uses it to implement thread-safe warnings infrastructure.

This requires a handful of small modifications to tests to remove direct uses of the warnings module. We also sadly have to delete one TPU test that checks for a warning raised on another thread; there's no easy way for us to catch that in a thread-safe way, but that test seems like overkill anyway.
  • Loading branch information
hawkinsp committed Jan 9, 2025
1 parent 640cb00 commit b06779b
Show file tree
Hide file tree
Showing 10 changed files with 303 additions and 72 deletions.
1 change: 1 addition & 0 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ py_library(
testonly = 1,
srcs = [
"_src/test_util.py",
"_src/test_warning_util.py",
],
visibility = [
":internal",
Expand Down
57 changes: 44 additions & 13 deletions jax/_src/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
import time
from typing import Any, TextIO
import unittest
import warnings
import zlib

from absl.testing import absltest
Expand All @@ -49,6 +48,7 @@
from jax._src import dtypes as _dtypes
from jax._src import lib as _jaxlib
from jax._src import monitoring
from jax._src import test_warning_util
from jax._src import xla_bridge
from jax._src import util
from jax._src import mesh as mesh_lib
Expand Down Expand Up @@ -118,7 +118,7 @@
)

TEST_NUM_THREADS = config.int_flag(
'jax_test_num_threads', 0,
'jax_test_num_threads', int(os.getenv('JAX_TEST_NUM_THREADS', '0')),
help='Number of threads to use for running tests. 0 means run everything '
'in the main thread. Using > 1 thread is experimental.'
)
Expand Down Expand Up @@ -1076,7 +1076,7 @@ def stopTest(self, test: unittest.TestCase):
with self.lock:
# We assume test_result is an ABSL _TextAndXMLTestResult, so we can
# override how it gets the time.
time_getter = self.test_result.time_getter
time_getter = getattr(self.test_result, "time_getter", None)
try:
self.test_result.time_getter = lambda: self.start_time
self.test_result.startTest(test)
Expand All @@ -1085,7 +1085,8 @@ def stopTest(self, test: unittest.TestCase):
self.test_result.time_getter = lambda: stop_time
self.test_result.stopTest(test)
finally:
self.test_result.time_getter = time_getter
if time_getter is not None:
self.test_result.time_getter = time_getter

def addSuccess(self, test: unittest.TestCase):
self.actions.append(lambda: self.test_result.addSuccess(test))
Expand Down Expand Up @@ -1120,6 +1121,8 @@ def run(self, result: unittest.TestResult, debug: bool = False) -> unittest.Test
if TEST_NUM_THREADS.value <= 0:
return super().run(result)

test_warning_util.install_threadsafe_warning_handlers()

executor = ThreadPoolExecutor(TEST_NUM_THREADS.value)
lock = threading.Lock()
futures = []
Expand Down Expand Up @@ -1368,11 +1371,44 @@ def assertMultiLineStrippedEqual(self, expected, what):
self.assertMultiLineEqual(expected_clean, what_clean,
msg=f"Found\n{what}\nExpecting\n{expected}")


@contextmanager
def assertNoWarnings(self):
with warnings.catch_warnings():
warnings.simplefilter("error")
with test_warning_util.raise_on_warnings():
yield

# We replace assertWarns and assertWarnsRegex with functions that use the
# thread-safe warning utilities. Unlike the unittest versions these only
# function as context managers.
@contextmanager
def assertWarns(self, warning, *, msg=None):
with test_warning_util.record_warnings() as ws:
yield
for w in ws:
if not isinstance(w.message, warning):
continue
if msg is not None and msg not in str(w.message):
continue
return
self.fail(f"Expected warning not found {warning}:'{msg}', got "
f"{ws}")

@contextmanager
def assertWarnsRegex(self, warning, regex):
if regex is not None:
regex = re.compile(regex)

with test_warning_util.record_warnings() as ws:
yield
for w in ws:
if not isinstance(w.message, warning):
continue
if regex is not None and not regex.search(str(w.message)):
continue
return
self.fail(f"Expected warning not found {warning}:'{regex}', got "
f"{ws}")


def _CompileAndCheck(self, fun, args_maker, *, check_dtypes=True, tol=None,
rtol=None, atol=None, check_cache_misses=True):
Expand Down Expand Up @@ -1449,11 +1485,7 @@ def assertNotDeleted(self, x):
self.assertFalse(x.is_deleted())


@contextmanager
def ignore_warning(*, message='', category=Warning, **kw):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message=message, category=category, **kw)
yield
ignore_warning = test_warning_util.ignore_warning

# -------------------- Mesh parametrization helpers --------------------

Expand Down Expand Up @@ -1768,9 +1800,8 @@ def make_axis_points(size):
logtiny = finfo.minexp / prec_dps_ratio
axis_points = np.zeros(3 + 2 * size, dtype=finfo.dtype)

with warnings.catch_warnings():
with ignore_warning(category=RuntimeWarning):
# Silence RuntimeWarning: overflow encountered in cast
warnings.simplefilter("ignore")
half_neg_line = -np.logspace(logmin, logtiny, size, dtype=finfo.dtype)
half_line = -half_neg_line[::-1]
axis_points[-size - 1:-1] = half_line
Expand Down
132 changes: 132 additions & 0 deletions jax/_src/test_warning_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Thread-safe utilities for catching and testing for warnings.
#
# The Python warnings module, at least as of Python 3.13, is not thread-safe.
# The catch_warnings() feature is inherently racy, see
# https://py-free-threading.github.io/porting/#the-warnings-module-is-not-thread-safe
#
# This module offers a thread-safe way to catch and record warnings. We install
# a custom showwarning hook with the Python warning module, and then rely on
# the CPython warnings module to call our show warning function. We then use it
# to create our own thread-safe warning filtering utilities.

import contextlib
import re
import threading
import warnings


class _WarningContext(threading.local):
"Thread-local state that contains a list of warning handlers."

def __init__(self):
self.handlers = []


_context = _WarningContext()


# Callback that applies the handlers in reverse order. If no handler matches,
# we raise an error.
def _showwarning(message, category, filename, lineno, file=None, line=None):
for handler in reversed(_context.handlers):
if handler(message, category, filename, lineno, file, line):
return
raise category(message)


@contextlib.contextmanager
def raise_on_warnings():
"Context manager that raises an exception if a warning is raised."
if warnings.showwarning is not _showwarning:
with warnings.catch_warnings():
warnings.simplefilter("error")
yield
return

def handler(message, category, filename, lineno, file=None, line=None):
raise category(message)

_context.handlers.append(handler)
try:
yield
finally:
_context.handlers.pop()


@contextlib.contextmanager
def record_warnings():
"Context manager that yields a list of warnings that are raised."
if warnings.showwarning is not _showwarning:
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
yield w
return

log = []

def handler(message, category, filename, lineno, file=None, line=None):
log.append(warnings.WarningMessage(message, category, filename, lineno, file, line))
return True

_context.handlers.append(handler)
try:
yield log
finally:
_context.handlers.pop()


@contextlib.contextmanager
def ignore_warning(*, message: str | None = None, category: type = Warning):
"Context manager that ignores any matching warnings."
if warnings.showwarning is not _showwarning:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", message="" if message is None else message, category=category)
yield
return

if message:
message_re = re.compile(message)
else:
message_re = None

category_cls = category

def handler(message, category, filename, lineno, file=None, line=None):
text = str(message) if isinstance(message, Warning) else message
if (message_re is None or message_re.match(text)) and issubclass(
category, category_cls
):
return True
return False

_context.handlers.append(handler)
try:
yield
finally:
_context.handlers.pop()


def install_threadsafe_warning_handlers():
# Hook the showwarning method. The warnings module explicitly notes that
# this is a function that users may replace.
warnings.showwarning = _showwarning

# Set the warnings module to always display warnings. We hook into it by
# overriding the "showwarning" method, so it's important that all warnings
# are "shown" by the usual mechanism.
warnings.simplefilter("always")
8 changes: 8 additions & 0 deletions tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1153,6 +1153,14 @@ jax_py_test(
],
)

jax_py_test(
name = "warnings_util_test",
srcs = ["warnings_util_test.py"],
deps = [
"//jax:test_util",
] + py_deps("absl/testing"),
)

jax_py_test(
name = "xla_bridge_test",
srcs = ["xla_bridge_test.py"],
Expand Down
48 changes: 23 additions & 25 deletions tests/compilation_cache_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import unittest
from unittest import mock
from unittest import SkipTest
import warnings

from absl.testing import absltest
from absl.testing import parameterized
Expand All @@ -39,6 +38,7 @@
from jax._src import monitoring
from jax._src import path as pathlib
from jax._src import test_util as jtu
from jax._src import test_warning_util
from jax._src import xla_bridge
from jax._src.compilation_cache_interface import CacheInterface
from jax._src.lib import xla_client as xc
Expand Down Expand Up @@ -232,21 +232,20 @@ def test_cache_write_warning(self):
with (
config.raise_persistent_cache_errors(False),
mock.patch.object(cc._get_cache(backend).__class__, "put") as mock_put,
warnings.catch_warnings(record=True) as w,
test_warning_util.record_warnings() as w,
):
warnings.simplefilter("always")
mock_put.side_effect = RuntimeError("test error")
self.assertEqual(f(2).item(), 4)
if len(w) != 1:
print("Warnings:", [str(w_) for w_ in w], flush=True)
self.assertLen(w, 1)
self.assertIn(
(
"Error writing persistent compilation cache entry "
"for 'jit__lambda_': RuntimeError: test error"
),
str(w[0].message),
)
if len(w) != 1:
print("Warnings:", [str(w_) for w_ in w], flush=True)
self.assertLen(w, 1)
self.assertIn(
(
"Error writing persistent compilation cache entry "
"for 'jit__lambda_': RuntimeError: test error"
),
str(w[0].message),
)

def test_cache_read_warning(self):
f = jit(lambda x: x * x)
Expand All @@ -255,23 +254,22 @@ def test_cache_read_warning(self):
with (
config.raise_persistent_cache_errors(False),
mock.patch.object(cc._get_cache(backend).__class__, "get") as mock_get,
warnings.catch_warnings(record=True) as w,
test_warning_util.record_warnings() as w,
):
warnings.simplefilter("always")
mock_get.side_effect = RuntimeError("test error")
# Calling assertEqual with the jitted f will generate two PJIT
# executables: Equal and the lambda function itself.
self.assertEqual(f(2).item(), 4)
if len(w) != 1:
print("Warnings:", [str(w_) for w_ in w], flush=True)
self.assertLen(w, 1)
self.assertIn(
(
"Error reading persistent compilation cache entry "
"for 'jit__lambda_': RuntimeError: test error"
),
str(w[0].message),
)
if len(w) != 1:
print("Warnings:", [str(w_) for w_ in w], flush=True)
self.assertLen(w, 1)
self.assertIn(
(
"Error reading persistent compilation cache entry "
"for 'jit__lambda_': RuntimeError: test error"
),
str(w[0].message),
)

def test_min_entry_size(self):
with (
Expand Down
6 changes: 2 additions & 4 deletions tests/deprecation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings

from absl.testing import absltest
from jax._src import deprecations
from jax._src import test_util as jtu
from jax._src import test_warning_util
from jax._src.internal_test_util import deprecation_module as m

class DeprecationTest(absltest.TestCase):

def testModuleDeprecation(self):
with warnings.catch_warnings():
warnings.simplefilter("error")
with test_warning_util.raise_on_warnings():
self.assertEqual(m.x, 42)

with self.assertWarnsRegex(DeprecationWarning, "Please use x"):
Expand Down
2 changes: 1 addition & 1 deletion tests/lax_numpy_reducers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def testReducer(self, name, rng_factory, shape, dtype, out_dtype,
rng = rng_factory(self.rng())
@jtu.ignore_warning(category=NumpyComplexWarning)
@jtu.ignore_warning(category=RuntimeWarning,
message="mean of empty slice.*")
message="Mean of empty slice.*")
@jtu.ignore_warning(category=RuntimeWarning,
message="overflow encountered.*")
def np_fun(x):
Expand Down
Loading

0 comments on commit b06779b

Please sign in to comment.