Skip to content

Commit

Permalink
Fix InterfaceEnum object (#6877)
Browse files Browse the repository at this point in the history
**Context:**

`QNode.interface` returns a `str` type,
```python
    @Property
    def interface(self) -> str:
        """The interface used by the QNode"""
        return self._interface.value
```
Therefore, this equality will never be true,
```python
from pennylane.math.interface_utils import Interface

>>> "jax" == Interface.JAX
False
```

Moreover, `str` cannot be compared with `Enum` objects at all. 🤯 

**Description of the Change:**

- Adjust the `InterfaceEnum` class to have dunder methods to overwrite
comparisons. This way it will error out if you try to compare it with a
string.
- Add test coverage for the private helpers in `qnode.py`

Now we have,
```python
from pennylane.math.interface_utils import Interface

>>> "jax" == Interface.JAX
TypeError: Cannot compare Interface with str
```

**Benefits:** Logic is correct.

**Possible Drawbacks:** None identified.

[sc-83034]

---------

Co-authored-by: Yushao Chen (Jerry) <[email protected]>
  • Loading branch information
andrijapau and JerryChen97 authored Jan 24, 2025
1 parent 666941f commit 721acb2
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 9 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@

<h3>Internal changes ⚙️</h3>

* Improved the `InterfaceEnum` object to prevent direct comparisons to `str` objects.
[(#6877)](https://github.com/PennyLaneAI/pennylane/pull/6877)

* Added a `QmlPrimitive` class that inherits `jax.core.Primitive` to a new `qml.capture.custom_primitives` module.
This class contains a `prim_type` property so that we can differentiate between different sets of PennyLane primitives.
Consequently, `QmlPrimitive` is now used to define all PennyLane primitives.
Expand Down
11 changes: 10 additions & 1 deletion pennylane/math/interface_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,15 @@ def get_like(self):
}
return mapping[self]

def __eq__(self, interface):
if isinstance(interface, str):
raise TypeError("Cannot compare Interface with str")
return super().__eq__(interface)

def __hash__(self):
# pylint: disable=useless-super-delegation
return super().__hash__()


InterfaceLike = Union[str, Interface, None]

Expand Down Expand Up @@ -225,7 +234,7 @@ def get_canonical_interface_name(user_input: InterfaceLike) -> Interface:
Interface: canonical interface
"""

if user_input in SUPPORTED_INTERFACE_NAMES:
if isinstance(user_input, Interface) and user_input in SUPPORTED_INTERFACE_NAMES:
return user_input
try:
return INTERFACE_MAP[user_input]
Expand Down
4 changes: 2 additions & 2 deletions pennylane/workflow/qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,11 @@ def _convert_to_interface(result, interface: Interface):
def _make_execution_config(
circuit: Optional["QNode"], diff_method=None, mcm_config=None
) -> "qml.devices.ExecutionConfig":
circuit_interface = getattr(circuit, "interface", Interface.NUMPY)
circuit_interface = getattr(circuit, "interface", Interface.NUMPY.value)
execute_kwargs = getattr(circuit, "execute_kwargs", {})
gradient_kwargs = getattr(circuit, "gradient_kwargs", {})
grad_on_execution = execute_kwargs.get("grad_on_execution")
if circuit_interface == Interface.JAX:
if circuit_interface in {Interface.JAX.value, Interface.JAX_JIT.value}:
grad_on_execution = False
elif grad_on_execution == "best":
grad_on_execution = None
Expand Down
11 changes: 5 additions & 6 deletions pennylane/workflow/resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,19 +96,18 @@ def _resolve_interface(interface: Union[str, Interface], tapes: QuantumScriptBat
Returns:
Interface: resolved interface
"""

interface = get_canonical_interface_name(interface)

if interface == Interface.AUTO:
params = []
for tape in tapes:
params.extend(tape.get_parameters(trainable_only=False))
interface = get_interface(*params)
if interface != Interface.NUMPY:
try:
interface = get_canonical_interface_name(interface)
except ValueError:
interface = Interface.NUMPY
try:
interface = get_canonical_interface_name(interface)
except ValueError:
# If the interface is not recognized, default to numpy, like networkx
interface = Interface.NUMPY
if interface == Interface.TF and _use_tensorflow_autograph():
interface = Interface.TF_AUTOGRAPH
if interface == Interface.JAX:
Expand Down
12 changes: 12 additions & 0 deletions tests/math/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,6 +1024,18 @@ def test_get_interface(t, interface):
assert res == interface


# pylint: disable=too-few-public-methods
class TestInterfaceEnum:
"""Test the Interface enum class"""

def test_eq(self):
"""Test that an error is raised if comparing to string"""
assert fn.Interface.NUMPY == fn.Interface.NUMPY
with pytest.raises(TypeError, match="Cannot compare Interface with str"):
# pylint: disable=pointless-statement
fn.Interface.NUMPY == "numpy"


@pytest.mark.parametrize("t", test_data)
def test_toarray(t):
"""Test that the toarray method correctly converts the input
Expand Down
50 changes: 50 additions & 0 deletions tests/test_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from pennylane import qnode
from pennylane.tape import QuantumScript, QuantumScriptBatch
from pennylane.typing import PostprocessingFn
from pennylane.workflow.qnode import _make_execution_config


def dummyfunc():
Expand Down Expand Up @@ -2073,3 +2074,52 @@ def circuit(x):
circuit(qml.numpy.array(0.1))

assert circuit.interface == "auto"


class TestPrivateFunctions:
"""Tests for private functions in the QNode class."""

def test_make_execution_config_with_no_qnode(self):
"""Test that the _make_execution_config function correctly creates an execution config."""
diff_method = "best"
mcm_config = qml.devices.MCMConfig(postselect_mode="fill-shots", mcm_method="deferred")
config = _make_execution_config(None, diff_method, mcm_config)

expected_config = qml.devices.ExecutionConfig(
interface="numpy",
gradient_keyword_arguments={},
use_device_jacobian_product=False,
grad_on_execution=None,
gradient_method=diff_method,
mcm_config=mcm_config,
)

assert config == expected_config

@pytest.mark.parametrize("interface", ["autograd", "torch", "tf", "jax", "jax-jit"])
def test_make_execution_config_with_qnode(self, interface):
"""Test that a execution config is made correctly with no QNode."""
if "jax" in interface:
grad_on_execution = False
else:
grad_on_execution = None

@qml.qnode(qml.device("default.qubit"), interface=interface)
def circuit():
qml.H(0)
return qml.probs()

diff_method = "best"
mcm_config = qml.devices.MCMConfig(postselect_mode="fill-shots", mcm_method="deferred")
config = _make_execution_config(circuit, diff_method, mcm_config)

expected_config = qml.devices.ExecutionConfig(
interface=interface,
gradient_keyword_arguments={},
use_device_jacobian_product=False,
grad_on_execution=grad_on_execution,
gradient_method=diff_method,
mcm_config=mcm_config,
)

assert config == expected_config

0 comments on commit 721acb2

Please sign in to comment.