From da892d11e8ebae565ddc56a985715b4e2d825d38 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Mon, 20 Jan 2025 15:25:22 -0500 Subject: [PATCH 1/6] Raise runtime warning if jax > 0.4.28 is installed --- pennylane/__init__.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/pennylane/__init__.py b/pennylane/__init__.py index 6e8a8fe903b..3f1ca33be40 100644 --- a/pennylane/__init__.py +++ b/pennylane/__init__.py @@ -159,6 +159,22 @@ import pennylane.spin +from importlib.util import find_spec +from packaging.version import Version +from warnings import warn + +has_jax = find_spec("jax") + +if has_jax: # pragma: no cover + from jax import __version__ as jax_version + + if Version(jax_version) > Version("0.4.28"): + warn( + "PennyLane is currently not compatible with versions of JAX > 0.4.28. " + f"You have version {jax_version} installed.", + RuntimeWarning, + ) + # Look for an existing configuration file default_config = Configuration("config.toml") From 4ee58aee5ef5b1dc0d9eee564e3208472c99b0b0 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Mon, 20 Jan 2025 16:07:40 -0500 Subject: [PATCH 2/6] Move warning validation from init to qnode initialization/execution --- doc/releases/changelog-dev.md | 5 +++++ pennylane/__init__.py | 16 ---------------- pennylane/workflow/execution.py | 6 +++++- pennylane/workflow/qnode.py | 5 ++++- pennylane/workflow/resolution.py | 21 +++++++++++++++++++++ 5 files changed, 35 insertions(+), 18 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 89e21560387..e327a8e9dda 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -6,6 +6,10 @@

Improvements 🛠

+* A `RuntimeWarning` is now raised by `QNode` if `interface="jax"` and the installed version of JAX + is greater than `0.4.28`. + [(#6864)](https://github.com/PennyLaneAI/pennylane/pull/6864) + * `QNode` objects now have an `update` method that allows for re-configuring settings like `diff_method`, `mcm_method`, and more. This allows for easier on-the-fly adjustments to workflows. Any arguments not specified will retain their original value. [(#6803)](https://github.com/PennyLaneAI/pennylane/pull/6803) @@ -105,4 +109,5 @@ Diksha Dhawan, Pietropaolo Frisoni, Marcus Gisslén, Christina Lee, +Mudit Pandey, Andrija Paurevic diff --git a/pennylane/__init__.py b/pennylane/__init__.py index 3f1ca33be40..6e8a8fe903b 100644 --- a/pennylane/__init__.py +++ b/pennylane/__init__.py @@ -159,22 +159,6 @@ import pennylane.spin -from importlib.util import find_spec -from packaging.version import Version -from warnings import warn - -has_jax = find_spec("jax") - -if has_jax: # pragma: no cover - from jax import __version__ as jax_version - - if Version(jax_version) > Version("0.4.28"): - warn( - "PennyLane is currently not compatible with versions of JAX > 0.4.28. " - f"You have version {jax_version} installed.", - RuntimeWarning, - ) - # Look for an existing configuration file default_config = Configuration("config.toml") diff --git a/pennylane/workflow/execution.py b/pennylane/workflow/execution.py index cfe93f322d2..620581e710d 100644 --- a/pennylane/workflow/execution.py +++ b/pennylane/workflow/execution.py @@ -33,7 +33,7 @@ from pennylane.workflow.resolution import SupportedDiffMethods from ._setup_transform_program import _setup_transform_program -from .resolution import _resolve_execution_config, _resolve_interface +from .resolution import _resolve_execution_config, _resolve_interface, _validate_jax_version from .run import run logger = logging.getLogger(__name__) @@ -201,9 +201,13 @@ def cost_fn(params, x): ### Specifying and preprocessing variables #### + old_interface = interface interface = _resolve_interface(interface, tapes) # Only need to calculate derivatives with jax when we know it will be executed later. + if old_interface in (None, "auto") and interface in (Interface.JAX, Interface.JAX_JIT): + _validate_jax_version() + config = qml.devices.ExecutionConfig( interface=interface, gradient_method=diff_method, diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py index 29430c9a645..939df58c6ae 100644 --- a/pennylane/workflow/qnode.py +++ b/pennylane/workflow/qnode.py @@ -32,7 +32,7 @@ from pennylane.tape import QuantumScript from pennylane.transforms.core import TransformContainer, TransformDispatcher, TransformProgram -from .resolution import SupportedDiffMethods +from .resolution import SupportedDiffMethods, _validate_jax_version logger = logging.getLogger(__name__) logger.addHandler(logging.NullHandler()) @@ -552,6 +552,9 @@ def __init__( self.func = func self.device = device self._interface = get_canonical_interface_name(interface) + if self._interface in (Interface.JAX, Interface.JAX_JIT): + _validate_jax_version() + self.diff_method = diff_method mcm_config = qml.devices.MCMConfig(mcm_method=mcm_method, postselect_mode=postselect_mode) cache = (max_diff > 1) if cache == "auto" else cache diff --git a/pennylane/workflow/resolution.py b/pennylane/workflow/resolution.py index 77d85aae9b3..21c48bb57b9 100644 --- a/pennylane/workflow/resolution.py +++ b/pennylane/workflow/resolution.py @@ -16,7 +16,12 @@ """ from collections.abc import Callable from dataclasses import replace +from importlib.metadata import version +from importlib.util import find_spec from typing import Literal, Optional, Union, get_args +from warnings import warn + +from packaging.version import Version import pennylane as qml from pennylane.logging import debug_logger @@ -71,6 +76,22 @@ def _get_jax_interface_name(tapes): return Interface.JAX +def _validate_jax_version() -> None: + """Checks if the installed version of JAX is supported. If an unsupported version of + JAX is installed, a ``RuntimeWarning`` is raised.""" + if not find_spec("jax"): + return + + jax_version = version("jax") + + if Version(jax_version) > Version("0.4.28"): # pragma: no cover + warn( + "PennyLane is currently not compatible with versions of JAX > 0.4.28. " + f"You have version {jax_version} installed.", + RuntimeWarning, + ) + + # pylint: disable=import-outside-toplevel def _use_tensorflow_autograph(): """Checks if TensorFlow is in graph mode, allowing Autograph for optimized execution""" From 5af09854a9e8775e3ec4c8c4f21b197bfc421d86 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Mon, 20 Jan 2025 16:44:31 -0500 Subject: [PATCH 3/6] Add pragma no cover --- pennylane/workflow/execution.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pennylane/workflow/execution.py b/pennylane/workflow/execution.py index 620581e710d..66765f8a20d 100644 --- a/pennylane/workflow/execution.py +++ b/pennylane/workflow/execution.py @@ -205,7 +205,10 @@ def cost_fn(params, x): interface = _resolve_interface(interface, tapes) # Only need to calculate derivatives with jax when we know it will be executed later. - if old_interface in (None, "auto") and interface in (Interface.JAX, Interface.JAX_JIT): + if old_interface in (None, "auto") and interface in ( + Interface.JAX, + Interface.JAX_JIT, + ): # pragma: no cover _validate_jax_version() config = qml.devices.ExecutionConfig( From 62a97a4d2b58963632514b5b6e05722876d3bd84 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Tue, 21 Jan 2025 12:23:08 -0500 Subject: [PATCH 4/6] Refactor where validation happens --- pennylane/math/interface_utils.py | 41 +++++++++++++++++++++++++++---- pennylane/workflow/execution.py | 9 +------ pennylane/workflow/qnode.py | 7 ++---- pennylane/workflow/resolution.py | 25 ++----------------- 4 files changed, 41 insertions(+), 41 deletions(-) diff --git a/pennylane/math/interface_utils.py b/pennylane/math/interface_utils.py index 4fc401b9fd6..9e5edce1e75 100644 --- a/pennylane/math/interface_utils.py +++ b/pennylane/math/interface_utils.py @@ -15,9 +15,13 @@ import warnings from enum import Enum +from importlib.metadata import version +from importlib.util import find_spec from typing import Literal, Union +from warnings import warn import autoray as ar +from packaging.version import Version class Interface(Enum): @@ -212,23 +216,50 @@ def get_deep_interface(value): return _get_interface_of_single_tensor(itr) -def get_canonical_interface_name(user_input: InterfaceLike) -> Interface: +def _check_supported_jax() -> None: + """Checks if the installed version of JAX is supported. If an unsupported version of + JAX is installed, a ``RuntimeWarning`` is raised.""" + if not find_spec("jax"): + return + + jax_version = version("jax") + if Version(jax_version) > Version("0.4.28"): # pragma: no cover + warn( + "PennyLane is currently not compatible with versions of JAX > 0.4.28. " + f"You have version {jax_version} installed.", + RuntimeWarning, + ) + + +def get_canonical_interface_name( + user_input: InterfaceLike, _validate_jax_version=False +) -> Interface: """Helper function to get the canonical interface name. Args: - interface (str, Interface): reference interface + interface (str, Interface): Reference interface + _validate_jax_version (bool): Whether we should check if a supported version of + JAX is installed. If ``True``, a ``RuntimeWarning`` is raised if the canonical + interface is found to be JAX. ``False`` by default. Raises: - ValueError: key does not exist in the interface map + ValueError: Key does not exist in the interface map Returns: - Interface: canonical interface + Interface: Canonical interface """ if user_input in SUPPORTED_INTERFACE_NAMES: + if _validate_jax_version and user_input in (Interface.JAX, Interface.JAX_JIT): + _check_supported_jax() + return user_input try: - return INTERFACE_MAP[user_input] + out = INTERFACE_MAP[user_input] + + if out in (Interface.JAX, Interface.JAX_JIT): + _check_supported_jax() + return out except KeyError as exc: raise ValueError( f"Unknown interface {user_input}. Interface must be one of {SUPPORTED_INTERFACE_NAMES}." diff --git a/pennylane/workflow/execution.py b/pennylane/workflow/execution.py index 66765f8a20d..cfe93f322d2 100644 --- a/pennylane/workflow/execution.py +++ b/pennylane/workflow/execution.py @@ -33,7 +33,7 @@ from pennylane.workflow.resolution import SupportedDiffMethods from ._setup_transform_program import _setup_transform_program -from .resolution import _resolve_execution_config, _resolve_interface, _validate_jax_version +from .resolution import _resolve_execution_config, _resolve_interface from .run import run logger = logging.getLogger(__name__) @@ -201,16 +201,9 @@ def cost_fn(params, x): ### Specifying and preprocessing variables #### - old_interface = interface interface = _resolve_interface(interface, tapes) # Only need to calculate derivatives with jax when we know it will be executed later. - if old_interface in (None, "auto") and interface in ( - Interface.JAX, - Interface.JAX_JIT, - ): # pragma: no cover - _validate_jax_version() - config = qml.devices.ExecutionConfig( interface=interface, gradient_method=diff_method, diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py index 939df58c6ae..b448db0b557 100644 --- a/pennylane/workflow/qnode.py +++ b/pennylane/workflow/qnode.py @@ -32,7 +32,7 @@ from pennylane.tape import QuantumScript from pennylane.transforms.core import TransformContainer, TransformDispatcher, TransformProgram -from .resolution import SupportedDiffMethods, _validate_jax_version +from .resolution import SupportedDiffMethods logger = logging.getLogger(__name__) logger.addHandler(logging.NullHandler()) @@ -551,10 +551,7 @@ def __init__( # input arguments self.func = func self.device = device - self._interface = get_canonical_interface_name(interface) - if self._interface in (Interface.JAX, Interface.JAX_JIT): - _validate_jax_version() - + self._interface = get_canonical_interface_name(interface, _validate_jax_version=True) self.diff_method = diff_method mcm_config = qml.devices.MCMConfig(mcm_method=mcm_method, postselect_mode=postselect_mode) cache = (max_diff > 1) if cache == "auto" else cache diff --git a/pennylane/workflow/resolution.py b/pennylane/workflow/resolution.py index 21c48bb57b9..1ddfa4366d4 100644 --- a/pennylane/workflow/resolution.py +++ b/pennylane/workflow/resolution.py @@ -16,12 +16,7 @@ """ from collections.abc import Callable from dataclasses import replace -from importlib.metadata import version -from importlib.util import find_spec from typing import Literal, Optional, Union, get_args -from warnings import warn - -from packaging.version import Version import pennylane as qml from pennylane.logging import debug_logger @@ -76,22 +71,6 @@ def _get_jax_interface_name(tapes): return Interface.JAX -def _validate_jax_version() -> None: - """Checks if the installed version of JAX is supported. If an unsupported version of - JAX is installed, a ``RuntimeWarning`` is raised.""" - if not find_spec("jax"): - return - - jax_version = version("jax") - - if Version(jax_version) > Version("0.4.28"): # pragma: no cover - warn( - "PennyLane is currently not compatible with versions of JAX > 0.4.28. " - f"You have version {jax_version} installed.", - RuntimeWarning, - ) - - # pylint: disable=import-outside-toplevel def _use_tensorflow_autograph(): """Checks if TensorFlow is in graph mode, allowing Autograph for optimized execution""" @@ -118,7 +97,7 @@ def _resolve_interface(interface: Union[str, Interface], tapes: QuantumScriptBat Interface: resolved interface """ - interface = get_canonical_interface_name(interface) + interface = get_canonical_interface_name(interface, _validate_jax_version=True) if interface == Interface.AUTO: params = [] @@ -127,7 +106,7 @@ def _resolve_interface(interface: Union[str, Interface], tapes: QuantumScriptBat interface = get_interface(*params) if interface != Interface.NUMPY: try: - interface = get_canonical_interface_name(interface) + interface = get_canonical_interface_name(interface, _validate_jax_version=True) except ValueError: interface = Interface.NUMPY if interface == Interface.TF and _use_tensorflow_autograph(): From 33fc899a596c9325681096bed744c01135432614 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Tue, 21 Jan 2025 12:24:33 -0500 Subject: [PATCH 5/6] Fix changelog --- doc/releases/changelog-dev.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index e327a8e9dda..c36cdb93f80 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -6,7 +6,7 @@

Improvements 🛠

-* A `RuntimeWarning` is now raised by `QNode` if `interface="jax"` and the installed version of JAX +* A `RuntimeWarning` is now raised by `qml.QNode` and `qml.execute` if executing JAX workflows and the installed version of JAX is greater than `0.4.28`. [(#6864)](https://github.com/PennyLaneAI/pennylane/pull/6864) From 16b9c6b77c441c0cd0ee7c51ebe508e29b572a41 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Thu, 30 Jan 2025 17:03:13 -0500 Subject: [PATCH 6/6] Address code review --- pennylane/math/interface_utils.py | 35 ++----------------------------- pennylane/workflow/qnode.py | 7 +++++-- pennylane/workflow/resolution.py | 28 +++++++++++++++++++++++-- 3 files changed, 33 insertions(+), 37 deletions(-) diff --git a/pennylane/math/interface_utils.py b/pennylane/math/interface_utils.py index 9e5edce1e75..ef37bc62b70 100644 --- a/pennylane/math/interface_utils.py +++ b/pennylane/math/interface_utils.py @@ -15,13 +15,9 @@ import warnings from enum import Enum -from importlib.metadata import version -from importlib.util import find_spec from typing import Literal, Union -from warnings import warn import autoray as ar -from packaging.version import Version class Interface(Enum): @@ -216,31 +212,11 @@ def get_deep_interface(value): return _get_interface_of_single_tensor(itr) -def _check_supported_jax() -> None: - """Checks if the installed version of JAX is supported. If an unsupported version of - JAX is installed, a ``RuntimeWarning`` is raised.""" - if not find_spec("jax"): - return - - jax_version = version("jax") - if Version(jax_version) > Version("0.4.28"): # pragma: no cover - warn( - "PennyLane is currently not compatible with versions of JAX > 0.4.28. " - f"You have version {jax_version} installed.", - RuntimeWarning, - ) - - -def get_canonical_interface_name( - user_input: InterfaceLike, _validate_jax_version=False -) -> Interface: +def get_canonical_interface_name(user_input: InterfaceLike) -> Interface: """Helper function to get the canonical interface name. Args: interface (str, Interface): Reference interface - _validate_jax_version (bool): Whether we should check if a supported version of - JAX is installed. If ``True``, a ``RuntimeWarning`` is raised if the canonical - interface is found to be JAX. ``False`` by default. Raises: ValueError: Key does not exist in the interface map @@ -250,16 +226,9 @@ def get_canonical_interface_name( """ if user_input in SUPPORTED_INTERFACE_NAMES: - if _validate_jax_version and user_input in (Interface.JAX, Interface.JAX_JIT): - _check_supported_jax() - return user_input try: - out = INTERFACE_MAP[user_input] - - if out in (Interface.JAX, Interface.JAX_JIT): - _check_supported_jax() - return out + return INTERFACE_MAP[user_input] except KeyError as exc: raise ValueError( f"Unknown interface {user_input}. Interface must be one of {SUPPORTED_INTERFACE_NAMES}." diff --git a/pennylane/workflow/qnode.py b/pennylane/workflow/qnode.py index b448db0b557..939df58c6ae 100644 --- a/pennylane/workflow/qnode.py +++ b/pennylane/workflow/qnode.py @@ -32,7 +32,7 @@ from pennylane.tape import QuantumScript from pennylane.transforms.core import TransformContainer, TransformDispatcher, TransformProgram -from .resolution import SupportedDiffMethods +from .resolution import SupportedDiffMethods, _validate_jax_version logger = logging.getLogger(__name__) logger.addHandler(logging.NullHandler()) @@ -551,7 +551,10 @@ def __init__( # input arguments self.func = func self.device = device - self._interface = get_canonical_interface_name(interface, _validate_jax_version=True) + self._interface = get_canonical_interface_name(interface) + if self._interface in (Interface.JAX, Interface.JAX_JIT): + _validate_jax_version() + self.diff_method = diff_method mcm_config = qml.devices.MCMConfig(mcm_method=mcm_method, postselect_mode=postselect_mode) cache = (max_diff > 1) if cache == "auto" else cache diff --git a/pennylane/workflow/resolution.py b/pennylane/workflow/resolution.py index 1ddfa4366d4..feb7005a1c2 100644 --- a/pennylane/workflow/resolution.py +++ b/pennylane/workflow/resolution.py @@ -16,7 +16,12 @@ """ from collections.abc import Callable from dataclasses import replace +from importlib.metadata import version +from importlib.util import find_spec from typing import Literal, Optional, Union, get_args +from warnings import warn + +from packaging.version import Version import pennylane as qml from pennylane.logging import debug_logger @@ -86,6 +91,21 @@ def _use_tensorflow_autograph(): return not tf.executing_eagerly() +def _validate_jax_version(): + """Checks if the installed version of JAX is supported. If an unsupported version of + JAX is installed, a ``RuntimeWarning`` is raised.""" + if not find_spec("jax"): + return + + jax_version = version("jax") + if Version(jax_version) > Version("0.4.28"): # pragma: no cover + warn( + "PennyLane is currently not compatible with versions of JAX > 0.4.28. " + f"You have version {jax_version} installed.", + RuntimeWarning, + ) + + def _resolve_interface(interface: Union[str, Interface], tapes: QuantumScriptBatch) -> Interface: """Helper function to resolve an interface based on a set of tapes. @@ -97,7 +117,9 @@ def _resolve_interface(interface: Union[str, Interface], tapes: QuantumScriptBat Interface: resolved interface """ - interface = get_canonical_interface_name(interface, _validate_jax_version=True) + interface = get_canonical_interface_name(interface) + if interface in (Interface.JAX, Interface.JAX_JIT): + _validate_jax_version() if interface == Interface.AUTO: params = [] @@ -106,7 +128,9 @@ def _resolve_interface(interface: Union[str, Interface], tapes: QuantumScriptBat interface = get_interface(*params) if interface != Interface.NUMPY: try: - interface = get_canonical_interface_name(interface, _validate_jax_version=True) + interface = get_canonical_interface_name(interface) + if interface in (Interface.JAX, Interface.JAX_JIT): + _validate_jax_version() except ValueError: interface = Interface.NUMPY if interface == Interface.TF and _use_tensorflow_autograph():