Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Raise RuntimeWarning if jax > 0.4.28 is installed #6864

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@

<h3>Improvements 🛠</h3>

* A `RuntimeWarning` is now raised by `QNode` if `interface="jax"` and the installed version of JAX
is greater than `0.4.28`.
[(#6864)](https://github.com/PennyLaneAI/pennylane/pull/6864)

* `QNode` objects now have an `update` method that allows for re-configuring settings like `diff_method`, `mcm_method`, and more. This allows for easier on-the-fly adjustments to workflows. Any arguments not specified will retain their original value.
[(#6803)](https://github.com/PennyLaneAI/pennylane/pull/6803)

Expand Down Expand Up @@ -105,4 +109,5 @@ Diksha Dhawan,
Pietropaolo Frisoni,
Marcus Gisslén,
Christina Lee,
Mudit Pandey,
Andrija Paurevic
9 changes: 8 additions & 1 deletion pennylane/workflow/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from pennylane.workflow.resolution import SupportedDiffMethods

from ._setup_transform_program import _setup_transform_program
from .resolution import _resolve_execution_config, _resolve_interface
from .resolution import _resolve_execution_config, _resolve_interface, _validate_jax_version
from .run import run

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -201,9 +201,16 @@ def cost_fn(params, x):

### Specifying and preprocessing variables ####

old_interface = interface
interface = _resolve_interface(interface, tapes)
# Only need to calculate derivatives with jax when we know it will be executed later.

if old_interface in (None, "auto") and interface in (
Interface.JAX,
Interface.JAX_JIT,
): # pragma: no cover
_validate_jax_version()

config = qml.devices.ExecutionConfig(
interface=interface,
gradient_method=diff_method,
Expand Down
5 changes: 4 additions & 1 deletion pennylane/workflow/qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from pennylane.tape import QuantumScript
from pennylane.transforms.core import TransformContainer, TransformDispatcher, TransformProgram

from .resolution import SupportedDiffMethods
from .resolution import SupportedDiffMethods, _validate_jax_version

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
Expand Down Expand Up @@ -552,6 +552,9 @@ def __init__(
self.func = func
self.device = device
self._interface = get_canonical_interface_name(interface)
if self._interface in (Interface.JAX, Interface.JAX_JIT):
_validate_jax_version()

self.diff_method = diff_method
mcm_config = qml.devices.MCMConfig(mcm_method=mcm_method, postselect_mode=postselect_mode)
cache = (max_diff > 1) if cache == "auto" else cache
Expand Down
21 changes: 21 additions & 0 deletions pennylane/workflow/resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@
"""
from collections.abc import Callable
from dataclasses import replace
from importlib.metadata import version
from importlib.util import find_spec
from typing import Literal, Optional, Union, get_args
from warnings import warn

from packaging.version import Version

import pennylane as qml
from pennylane.logging import debug_logger
Expand Down Expand Up @@ -71,6 +76,22 @@ def _get_jax_interface_name(tapes):
return Interface.JAX


def _validate_jax_version() -> None:
"""Checks if the installed version of JAX is supported. If an unsupported version of
JAX is installed, a ``RuntimeWarning`` is raised."""
if not find_spec("jax"):
return

jax_version = version("jax")

if Version(jax_version) > Version("0.4.28"): # pragma: no cover
warn(
"PennyLane is currently not compatible with versions of JAX > 0.4.28. "
f"You have version {jax_version} installed.",
RuntimeWarning,
)


# pylint: disable=import-outside-toplevel
def _use_tensorflow_autograph():
"""Checks if TensorFlow is in graph mode, allowing Autograph for optimized execution"""
Expand Down
Loading