diff --git a/jax/BUILD b/jax/BUILD index 044c2ab0104f..54997d4276ee 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -120,6 +120,7 @@ py_library( testonly = 1, srcs = [ "_src/test_util.py", + "_src/test_warning_util.py", ], visibility = [ ":internal", diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 7ad4f413c047..eee288d1b48f 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -35,7 +35,6 @@ import time from typing import Any, TextIO import unittest -import warnings import zlib from absl.testing import absltest @@ -49,6 +48,7 @@ from jax._src import dtypes as _dtypes from jax._src import lib as _jaxlib from jax._src import monitoring +from jax._src import test_warning_util from jax._src import xla_bridge from jax._src import util from jax._src import mesh as mesh_lib @@ -118,7 +118,7 @@ ) TEST_NUM_THREADS = config.int_flag( - 'jax_test_num_threads', 0, + 'jax_test_num_threads', int(os.getenv('JAX_TEST_NUM_THREADS', '0')), help='Number of threads to use for running tests. 0 means run everything ' 'in the main thread. Using > 1 thread is experimental.' ) @@ -1076,7 +1076,7 @@ def stopTest(self, test: unittest.TestCase): with self.lock: # We assume test_result is an ABSL _TextAndXMLTestResult, so we can # override how it gets the time. - time_getter = self.test_result.time_getter + time_getter = getattr(self.test_result, "time_getter", None) try: self.test_result.time_getter = lambda: self.start_time self.test_result.startTest(test) @@ -1085,7 +1085,8 @@ def stopTest(self, test: unittest.TestCase): self.test_result.time_getter = lambda: stop_time self.test_result.stopTest(test) finally: - self.test_result.time_getter = time_getter + if time_getter is not None: + self.test_result.time_getter = time_getter def addSuccess(self, test: unittest.TestCase): self.actions.append(lambda: self.test_result.addSuccess(test)) @@ -1120,6 +1121,8 @@ def run(self, result: unittest.TestResult, debug: bool = False) -> unittest.Test if TEST_NUM_THREADS.value <= 0: return super().run(result) + test_warning_util.install_threadsafe_warning_handlers() + executor = ThreadPoolExecutor(TEST_NUM_THREADS.value) lock = threading.Lock() futures = [] @@ -1368,11 +1371,44 @@ def assertMultiLineStrippedEqual(self, expected, what): self.assertMultiLineEqual(expected_clean, what_clean, msg=f"Found\n{what}\nExpecting\n{expected}") + @contextmanager def assertNoWarnings(self): - with warnings.catch_warnings(): - warnings.simplefilter("error") + with test_warning_util.raise_on_warnings(): + yield + + # We replace assertWarns and assertWarnsRegex with functions that use the + # thread-safe warning utilities. Unlike the unittest versions these only + # function as context managers. + @contextmanager + def assertWarns(self, warning, *, msg=None): + with test_warning_util.record_warnings() as ws: + yield + for w in ws: + if not isinstance(w.message, warning): + continue + if msg is not None and msg not in str(w.message): + continue + return + self.fail(f"Expected warning not found {warning}:'{msg}', got " + f"{ws}") + + @contextmanager + def assertWarnsRegex(self, warning, regex): + if regex is not None: + regex = re.compile(regex) + + with test_warning_util.record_warnings() as ws: yield + for w in ws: + if not isinstance(w.message, warning): + continue + if regex is not None and not regex.search(str(w.message)): + continue + return + self.fail(f"Expected warning not found {warning}:'{regex}', got " + f"{ws}") + def _CompileAndCheck(self, fun, args_maker, *, check_dtypes=True, tol=None, rtol=None, atol=None, check_cache_misses=True): @@ -1449,11 +1485,7 @@ def assertNotDeleted(self, x): self.assertFalse(x.is_deleted()) -@contextmanager -def ignore_warning(*, message='', category=Warning, **kw): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message=message, category=category, **kw) - yield +ignore_warning = test_warning_util.ignore_warning # -------------------- Mesh parametrization helpers -------------------- @@ -1768,9 +1800,8 @@ def make_axis_points(size): logtiny = finfo.minexp / prec_dps_ratio axis_points = np.zeros(3 + 2 * size, dtype=finfo.dtype) - with warnings.catch_warnings(): + with ignore_warning(category=RuntimeWarning): # Silence RuntimeWarning: overflow encountered in cast - warnings.simplefilter("ignore") half_neg_line = -np.logspace(logmin, logtiny, size, dtype=finfo.dtype) half_line = -half_neg_line[::-1] axis_points[-size - 1:-1] = half_line diff --git a/jax/_src/test_warning_util.py b/jax/_src/test_warning_util.py new file mode 100644 index 000000000000..b41fe3ce3b15 --- /dev/null +++ b/jax/_src/test_warning_util.py @@ -0,0 +1,132 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Thread-safe utilities for catching and testing for warnings. +# +# The Python warnings module, at least as of Python 3.13, is not thread-safe. +# The catch_warnings() feature is inherently racy, see +# https://py-free-threading.github.io/porting/#the-warnings-module-is-not-thread-safe +# +# This module offers a thread-safe way to catch and record warnings. We install +# a custom showwarning hook with the Python warning module, and then rely on +# the CPython warnings module to call our show warning function. We then use it +# to create our own thread-safe warning filtering utilities. + +import contextlib +import re +import threading +import warnings + + +class _WarningContext(threading.local): + "Thread-local state that contains a list of warning handlers." + + def __init__(self): + self.handlers = [] + + +_context = _WarningContext() + + +# Callback that applies the handlers in reverse order. If no handler matches, +# we raise an error. +def _showwarning(message, category, filename, lineno, file=None, line=None): + for handler in reversed(_context.handlers): + if handler(message, category, filename, lineno, file, line): + return + raise category(message) + + +@contextlib.contextmanager +def raise_on_warnings(): + "Context manager that raises an exception if a warning is raised." + if warnings.showwarning is not _showwarning: + with warnings.catch_warnings(): + warnings.simplefilter("error") + yield + return + + def handler(message, category, filename, lineno, file=None, line=None): + raise category(message) + + _context.handlers.append(handler) + try: + yield + finally: + _context.handlers.pop() + + +@contextlib.contextmanager +def record_warnings(): + "Context manager that yields a list of warnings that are raised." + if warnings.showwarning is not _showwarning: + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + yield w + return + + log = [] + + def handler(message, category, filename, lineno, file=None, line=None): + log.append(warnings.WarningMessage(message, category, filename, lineno, file, line)) + return True + + _context.handlers.append(handler) + try: + yield log + finally: + _context.handlers.pop() + + +@contextlib.contextmanager +def ignore_warning(*, message: str | None = None, category: type = Warning): + "Context manager that ignores any matching warnings." + if warnings.showwarning is not _showwarning: + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message="" if message is None else message, category=category) + yield + return + + if message: + message_re = re.compile(message) + else: + message_re = None + + category_cls = category + + def handler(message, category, filename, lineno, file=None, line=None): + text = str(message) if isinstance(message, Warning) else message + if (message_re is None or message_re.match(text)) and issubclass( + category, category_cls + ): + return True + return False + + _context.handlers.append(handler) + try: + yield + finally: + _context.handlers.pop() + + +def install_threadsafe_warning_handlers(): + # Hook the showwarning method. The warnings module explicitly notes that + # this is a function that users may replace. + warnings.showwarning = _showwarning + + # Set the warnings module to always display warnings. We hook into it by + # overriding the "showwarning" method, so it's important that all warnings + # are "shown" by the usual mechanism. + warnings.simplefilter("always") diff --git a/tests/BUILD b/tests/BUILD index 4868bcf75e2e..126ca33275b8 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1153,6 +1153,14 @@ jax_py_test( ], ) +jax_py_test( + name = "warnings_util_test", + srcs = ["warnings_util_test.py"], + deps = [ + "//jax:test_util", + ] + py_deps("absl/testing"), +) + jax_py_test( name = "xla_bridge_test", srcs = ["xla_bridge_test.py"], diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py index 73d76c1a4938..27ebab88715c 100644 --- a/tests/compilation_cache_test.py +++ b/tests/compilation_cache_test.py @@ -23,7 +23,6 @@ import unittest from unittest import mock from unittest import SkipTest -import warnings from absl.testing import absltest from absl.testing import parameterized @@ -39,6 +38,7 @@ from jax._src import monitoring from jax._src import path as pathlib from jax._src import test_util as jtu +from jax._src import test_warning_util from jax._src import xla_bridge from jax._src.compilation_cache_interface import CacheInterface from jax._src.lib import xla_client as xc @@ -232,21 +232,20 @@ def test_cache_write_warning(self): with ( config.raise_persistent_cache_errors(False), mock.patch.object(cc._get_cache(backend).__class__, "put") as mock_put, - warnings.catch_warnings(record=True) as w, + test_warning_util.record_warnings() as w, ): - warnings.simplefilter("always") mock_put.side_effect = RuntimeError("test error") self.assertEqual(f(2).item(), 4) - if len(w) != 1: - print("Warnings:", [str(w_) for w_ in w], flush=True) - self.assertLen(w, 1) - self.assertIn( - ( - "Error writing persistent compilation cache entry " - "for 'jit__lambda_': RuntimeError: test error" - ), - str(w[0].message), - ) + if len(w) != 1: + print("Warnings:", [str(w_) for w_ in w], flush=True) + self.assertLen(w, 1) + self.assertIn( + ( + "Error writing persistent compilation cache entry " + "for 'jit__lambda_': RuntimeError: test error" + ), + str(w[0].message), + ) def test_cache_read_warning(self): f = jit(lambda x: x * x) @@ -255,23 +254,22 @@ def test_cache_read_warning(self): with ( config.raise_persistent_cache_errors(False), mock.patch.object(cc._get_cache(backend).__class__, "get") as mock_get, - warnings.catch_warnings(record=True) as w, + test_warning_util.record_warnings() as w, ): - warnings.simplefilter("always") mock_get.side_effect = RuntimeError("test error") # Calling assertEqual with the jitted f will generate two PJIT # executables: Equal and the lambda function itself. self.assertEqual(f(2).item(), 4) - if len(w) != 1: - print("Warnings:", [str(w_) for w_ in w], flush=True) - self.assertLen(w, 1) - self.assertIn( - ( - "Error reading persistent compilation cache entry " - "for 'jit__lambda_': RuntimeError: test error" - ), - str(w[0].message), - ) + if len(w) != 1: + print("Warnings:", [str(w_) for w_ in w], flush=True) + self.assertLen(w, 1) + self.assertIn( + ( + "Error reading persistent compilation cache entry " + "for 'jit__lambda_': RuntimeError: test error" + ), + str(w[0].message), + ) def test_min_entry_size(self): with ( diff --git a/tests/deprecation_test.py b/tests/deprecation_test.py index 382aed3c6717..f9313449ac13 100644 --- a/tests/deprecation_test.py +++ b/tests/deprecation_test.py @@ -12,18 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import warnings - from absl.testing import absltest from jax._src import deprecations from jax._src import test_util as jtu +from jax._src import test_warning_util from jax._src.internal_test_util import deprecation_module as m class DeprecationTest(absltest.TestCase): def testModuleDeprecation(self): - with warnings.catch_warnings(): - warnings.simplefilter("error") + with test_warning_util.raise_on_warnings(): self.assertEqual(m.x, 42) with self.assertWarnsRegex(DeprecationWarning, "Please use x"): diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index 60194093e742..f4667df850f4 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -212,7 +212,7 @@ def testReducer(self, name, rng_factory, shape, dtype, out_dtype, rng = rng_factory(self.rng()) @jtu.ignore_warning(category=NumpyComplexWarning) @jtu.ignore_warning(category=RuntimeWarning, - message="mean of empty slice.*") + message="Mean of empty slice.*") @jtu.ignore_warning(category=RuntimeWarning, message="overflow encountered.*") def np_fun(x): diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index d507761abba7..c9b9779f02ea 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -5581,8 +5581,12 @@ def testDisableNumpyRankPromotionBroadcasting(self): jnp.ones(2) + 3 # don't want to raise for scalars with jax.numpy_rank_promotion('warn'): - self.assertWarnsRegex(UserWarning, "Following NumPy automatic rank promotion for add on " - r"shapes \(2,\) \(1, 2\).*", lambda: jnp.ones(2) + jnp.ones((1, 2))) + with self.assertWarnsRegex( + UserWarning, + "Following NumPy automatic rank promotion for add on shapes " + r"\(2,\) \(1, 2\).*" + ): + jnp.ones(2) + jnp.ones((1, 2)) jnp.ones(2) + 3 # don't want to warn for scalars @unittest.skip("Test fails on CI, perhaps due to JIT caching") diff --git a/tests/warnings_util_test.py b/tests/warnings_util_test.py new file mode 100644 index 000000000000..92f23bc784e2 --- /dev/null +++ b/tests/warnings_util_test.py @@ -0,0 +1,86 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings + +from absl.testing import absltest + +from jax._src import config +from jax._src import test_warning_util +from jax._src import test_util as jtu + + +config.parse_flags_with_absl() + +class WarningsUtilTest(jtu.JaxTestCase): + + @test_warning_util.raise_on_warnings() + def test_warning_raises(self): + with self.assertRaises(UserWarning, msg="hello"): + warnings.warn("hello", category=UserWarning) + + with self.assertRaises(DeprecationWarning, msg="hello"): + warnings.warn("hello", category=DeprecationWarning) + + @test_warning_util.raise_on_warnings() + def test_ignore_warning(self): + with test_warning_util.ignore_warning(message="h.*o"): + warnings.warn("hello", category=UserWarning) + + with self.assertRaises(UserWarning, msg="hello"): + with test_warning_util.ignore_warning(message="h.*o"): + warnings.warn("goodbye", category=UserWarning) + + with test_warning_util.ignore_warning(category=UserWarning): + warnings.warn("hello", category=UserWarning) + + with self.assertRaises(UserWarning, msg="hello"): + with test_warning_util.ignore_warning(category=DeprecationWarning): + warnings.warn("goodbye", category=UserWarning) + + def test_record_warning(self): + with test_warning_util.record_warnings() as w: + warnings.warn("hello", category=UserWarning) + warnings.warn("goodbye", category=DeprecationWarning) + self.assertLen(w, 2) + self.assertIs(w[0].category, UserWarning) + self.assertIn("hello", str(w[0].message)) + self.assertIs(w[1].category, DeprecationWarning) + self.assertIn("goodbye", str(w[1].message)) + + def test_record_warning_nested(self): + with test_warning_util.record_warnings() as w: + warnings.warn("aa", category=UserWarning) + with test_warning_util.record_warnings() as v: + warnings.warn("bb", category=UserWarning) + warnings.warn("cc", category=DeprecationWarning) + self.assertLen(w, 2) + self.assertIs(w[0].category, UserWarning) + self.assertIn("aa", str(w[0].message)) + self.assertIs(w[1].category, DeprecationWarning) + self.assertIn("cc", str(w[1].message)) + self.assertLen(v, 1) + self.assertIs(v[0].category, UserWarning) + self.assertIn("bb", str(v[0].message)) + + + def test_raises_warning(self): + with self.assertRaises(UserWarning, msg="hello"): + with test_warning_util.ignore_warning(): + with test_warning_util.raise_on_warnings(): + warnings.warn("hello", category=UserWarning) + + +if __name__ == '__main__': + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/xla_bridge_test.py b/tests/xla_bridge_test.py index 9c64734645d2..509a4244dca7 100644 --- a/tests/xla_bridge_test.py +++ b/tests/xla_bridge_test.py @@ -14,8 +14,6 @@ import os import platform -import time -import warnings from absl import logging from absl.testing import absltest @@ -126,31 +124,6 @@ def test_local_devices(self): with self.assertRaisesRegex(RuntimeError, "Unknown backend foo"): xb.local_devices(backend="foo") - def test_timer_tpu_warning(self): - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - - def _mock_tpu_client_with_options(library_path=None, options=None): - time_to_wait = 5 - start = time.time() - while not w: - if time.time() - start > time_to_wait: - raise ValueError( - "This test should not hang for more than " - f"{time_to_wait} seconds.") - time.sleep(0.1) - - self.assertLen(w, 1) - msg = str(w[-1].message) - self.assertIn("Did you run your code on all TPU hosts?", msg) - - def _mock_tpu_client(library_path=None): - _mock_tpu_client_with_options(library_path=library_path, options=None) - - with mock.patch.object(xc, "make_tpu_client", - side_effect=_mock_tpu_client_with_options): - xb.tpu_client_timer_callback(0.01) - def test_register_plugin(self): with self.assertLogs(level="WARNING") as log_output: with mock.patch.object(xc, "load_pjrt_plugin_dynamically", autospec=True):