From 087743d0c76265757cb28ad0a7e6a9d6a1197430 Mon Sep 17 00:00:00 2001 From: Ryosuke Noro <64354442+RyosukeNORO@users.noreply.github.com> Date: Tue, 9 Jul 2024 11:29:28 -0400 Subject: [PATCH] reduced state fix (#395) * added lines to change the type from np.ndarray to list in `reduced_state` * added an explanation to CHANGELOG.md * added test * added the related PR number to CHANGELOG.md --- .github/CHANGELOG.md | 2 ++ thewalrus/symplectic.py | 3 +++ thewalrus/tests/test_symplectic.py | 9 +++++++++ 3 files changed, 14 insertions(+) diff --git a/.github/CHANGELOG.md b/.github/CHANGELOG.md index 392e5d283..5ee95b623 100644 --- a/.github/CHANGELOG.md +++ b/.github/CHANGELOG.md @@ -17,6 +17,8 @@ * Add the calculation method of `takagi` when the matrix is diagonal. [(#394)](https://github.com/XanaduAI/thewalrus/pull/394) +* Add the lines for avoiding the comparison of np.ndarray and list. [(#395)](https://github.com/XanaduAI/thewalrus/pull/395) + ### Documentation ### Contributors diff --git a/thewalrus/symplectic.py b/thewalrus/symplectic.py index 0321e518b..93d141433 100644 --- a/thewalrus/symplectic.py +++ b/thewalrus/symplectic.py @@ -171,6 +171,9 @@ def reduced_state(mu, cov, modes): """ N = len(mu) // 2 + if type(modes) == np.ndarray: + modes = modes.tolist() + if modes == list(range(N)): # reduced state is full state return mu, cov diff --git a/thewalrus/tests/test_symplectic.py b/thewalrus/tests/test_symplectic.py index 40ff800bf..3f5edcc71 100644 --- a/thewalrus/tests/test_symplectic.py +++ b/thewalrus/tests/test_symplectic.py @@ -367,6 +367,15 @@ def test_tms(self, hbar, tol): assert np.allclose(res[0], expected[0], atol=tol, rtol=0) assert np.allclose(res[1], expected[1], atol=tol, rtol=0) + def test_ndarray(self, hbar, tol): + """Test numpy.ndarray in the third argument of `reduced_state` is converted to list correctly""" + mu, cov = symplectic.vacuum_state(4, hbar=hbar) + res = symplectic.reduced_state(mu, cov, np.array([0, 1, 2, 3])) + expected = np.zeros([8]), np.identity(8) * hbar / 2 + + assert np.allclose(res[0], expected[0], atol=tol, rtol=0) + assert np.allclose(res[1], expected[1], atol=tol, rtol=0) + class TestLossChannel: """Tests for the loss channel"""