From 016d654f2dc8b61b15457f62eff1719c4881e6d2 Mon Sep 17 00:00:00 2001 From: LuiGiovanni Date: Fri, 4 Dec 2020 18:14:57 -0700 Subject: [PATCH 01/14] Added power to Jax backend and test --- tensornetwork/backends/jax/jax_backend.py | 3 +++ tensornetwork/backends/jax/jax_backend_test.py | 15 +++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/tensornetwork/backends/jax/jax_backend.py b/tensornetwork/backends/jax/jax_backend.py index d7417e8f3..038960ce3 100644 --- a/tensornetwork/backends/jax/jax_backend.py +++ b/tensornetwork/backends/jax/jax_backend.py @@ -874,3 +874,6 @@ def sign(self, tensor: Tensor) -> Tensor: def item(self, tensor): return tensor.item() + + def power(self, a: Tensor, b: Union[Tensor, float]) -> Tensor: + return jnp.power(a,b) \ No newline at end of file diff --git a/tensornetwork/backends/jax/jax_backend_test.py b/tensornetwork/backends/jax/jax_backend_test.py index 13c74247c..cc3246b57 100644 --- a/tensornetwork/backends/jax/jax_backend_test.py +++ b/tensornetwork/backends/jax/jax_backend_test.py @@ -974,6 +974,18 @@ def matvec_jax(vector, matrix): num_krylov_vecs=100, tol=0.0001) +def test_power(dtype): + np.random.seed(10) + backend = jax_backend.JaxBackend() + tensor = np.random.rand(2, 3, 4) + a = backend.convert_to_tensor(tensor) + actual = backend.power(a, axis=(1, 2)) + expected = np.power(a, axis=(1, 2)) + np.testing.assert_allclose(expected, actual) + + actual = backend.power(a, axis=(1, 2), keepdims=True) + expected = np.power(a, axis=(1, 2), keepdims=True) + np.testing.assert_allclose(expected, actual) def test_sum(): np.random.seed(10) @@ -1240,3 +1252,6 @@ def test_item(dtype): backend = jax_backend.JaxBackend() tensor = backend.randn((1,), dtype=dtype, seed=10) assert backend.item(tensor) == tensor.item() + + + From 62a06c0688a04a45fa80945c08f311eace398bda Mon Sep 17 00:00:00 2001 From: LuiGiovanni Date: Fri, 4 Dec 2020 19:18:27 -0700 Subject: [PATCH 02/14] Added power function to the Jax backend and test. --- tensornetwork/backends/jax/jax_backend.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/tensornetwork/backends/jax/jax_backend.py b/tensornetwork/backends/jax/jax_backend.py index 038960ce3..aa488bee8 100644 --- a/tensornetwork/backends/jax/jax_backend.py +++ b/tensornetwork/backends/jax/jax_backend.py @@ -876,4 +876,15 @@ def item(self, tensor): return tensor.item() def power(self, a: Tensor, b: Union[Tensor, float]) -> Tensor: - return jnp.power(a,b) \ No newline at end of file + """ + Returns power of tensor a to the value of b. + In the case b is a tensor, then the power is by element + with a as the base and b as the exponent. + In the case b is a scalar, then the power of each value in a + is raised to the exponent of b. + + Args: + a: The tensor that contains the base. + b: The tensor that contains the exponent or a single scalar. + """ + return jnp.power(a,b) From 43952dde29765a0dfcd73c97b28914a55d9c7ad3 Mon Sep 17 00:00:00 2001 From: LuiGiovanni Date: Fri, 4 Dec 2020 20:13:21 -0700 Subject: [PATCH 03/14] Fixed white space error. --- tensornetwork/backends/jax/jax_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensornetwork/backends/jax/jax_backend.py b/tensornetwork/backends/jax/jax_backend.py index aa488bee8..68c4d9540 100644 --- a/tensornetwork/backends/jax/jax_backend.py +++ b/tensornetwork/backends/jax/jax_backend.py @@ -887,4 +887,4 @@ def power(self, a: Tensor, b: Union[Tensor, float]) -> Tensor: a: The tensor that contains the base. b: The tensor that contains the exponent or a single scalar. """ - return jnp.power(a,b) + return jnp.power(a, b) From b0f579e4634e3636793dd8b3f933ad14eb19b123 Mon Sep 17 00:00:00 2001 From: LuiGiovanni Date: Sat, 5 Dec 2020 00:44:28 -0700 Subject: [PATCH 04/14] Re-made the power test function. --- .../backends/jax/jax_backend_test.py | 29 +++++++++---------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/tensornetwork/backends/jax/jax_backend_test.py b/tensornetwork/backends/jax/jax_backend_test.py index cc3246b57..3c8926f64 100644 --- a/tensornetwork/backends/jax/jax_backend_test.py +++ b/tensornetwork/backends/jax/jax_backend_test.py @@ -974,19 +974,6 @@ def matvec_jax(vector, matrix): num_krylov_vecs=100, tol=0.0001) -def test_power(dtype): - np.random.seed(10) - backend = jax_backend.JaxBackend() - tensor = np.random.rand(2, 3, 4) - a = backend.convert_to_tensor(tensor) - actual = backend.power(a, axis=(1, 2)) - expected = np.power(a, axis=(1, 2)) - np.testing.assert_allclose(expected, actual) - - actual = backend.power(a, axis=(1, 2), keepdims=True) - expected = np.power(a, axis=(1, 2), keepdims=True) - np.testing.assert_allclose(expected, actual) - def test_sum(): np.random.seed(10) backend = jax_backend.JaxBackend() @@ -1253,5 +1240,17 @@ def test_item(dtype): tensor = backend.randn((1,), dtype=dtype, seed=10) assert backend.item(tensor) == tensor.item() - - +@pytest.mark.parametrize("dtype", tf_dtypes) +def test_power(dtype): + shape = (4, 3, 2) + backend = jax_backend.JaxBackend() + base_tensor = backend.randn(shape, dtype=dtype, seed=10) + power_tensor = backend.randn(shape, dtype=dtype, seed=10) + actual = backend.power(base_tensor, power_tensor) + expected = tf.math.pow(base_tensor, power_tensor) + np.testing.assert_allclose(expected, actual) + + power = np.random.rand(1)[0] + actual = backend.power(base_tensor, power) + expected = tf.math.pow(base_tensor, power) + np.testing.assert_allclose(expected, actual) \ No newline at end of file From e9f9869942b9ceb58caf5552e6e7290be0a958a7 Mon Sep 17 00:00:00 2001 From: LuiGiovanni Date: Sat, 5 Dec 2020 00:49:59 -0700 Subject: [PATCH 05/14] Fixed typo --- tensornetwork/backends/jax/jax_backend_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensornetwork/backends/jax/jax_backend_test.py b/tensornetwork/backends/jax/jax_backend_test.py index 3c8926f64..02b41c05e 100644 --- a/tensornetwork/backends/jax/jax_backend_test.py +++ b/tensornetwork/backends/jax/jax_backend_test.py @@ -1240,7 +1240,7 @@ def test_item(dtype): tensor = backend.randn((1,), dtype=dtype, seed=10) assert backend.item(tensor) == tensor.item() -@pytest.mark.parametrize("dtype", tf_dtypes) +@pytest.mark.parametrize("dtype", np_dtypes) def test_power(dtype): shape = (4, 3, 2) backend = jax_backend.JaxBackend() @@ -1249,7 +1249,7 @@ def test_power(dtype): actual = backend.power(base_tensor, power_tensor) expected = tf.math.pow(base_tensor, power_tensor) np.testing.assert_allclose(expected, actual) - + power = np.random.rand(1)[0] actual = backend.power(base_tensor, power) expected = tf.math.pow(base_tensor, power) From 9e16a9faacbfc318a497f8dadf130b4ef0ca32c1 Mon Sep 17 00:00:00 2001 From: LuiGiovanni Date: Sat, 5 Dec 2020 01:30:23 -0700 Subject: [PATCH 06/14] Testing out numpy square. --- tensornetwork/backends/jax/jax_backend.py | 2 +- tensornetwork/backends/jax/jax_backend_test.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tensornetwork/backends/jax/jax_backend.py b/tensornetwork/backends/jax/jax_backend.py index 68c4d9540..f3d7c04a2 100644 --- a/tensornetwork/backends/jax/jax_backend.py +++ b/tensornetwork/backends/jax/jax_backend.py @@ -887,4 +887,4 @@ def power(self, a: Tensor, b: Union[Tensor, float]) -> Tensor: a: The tensor that contains the base. b: The tensor that contains the exponent or a single scalar. """ - return jnp.power(a, b) + return jnp.square(a, b) diff --git a/tensornetwork/backends/jax/jax_backend_test.py b/tensornetwork/backends/jax/jax_backend_test.py index 02b41c05e..6050c6271 100644 --- a/tensornetwork/backends/jax/jax_backend_test.py +++ b/tensornetwork/backends/jax/jax_backend_test.py @@ -1253,4 +1253,5 @@ def test_power(dtype): power = np.random.rand(1)[0] actual = backend.power(base_tensor, power) expected = tf.math.pow(base_tensor, power) - np.testing.assert_allclose(expected, actual) \ No newline at end of file + np.testing.assert_allclose(expected, actual) + \ No newline at end of file From 81b4d8b6faeeb251e3ec8ef040c2ba174d074d4d Mon Sep 17 00:00:00 2001 From: LuiGiovanni Date: Sat, 5 Dec 2020 02:09:54 -0700 Subject: [PATCH 07/14] Fixed issues with the assertion in test, should work now. --- tensornetwork/backends/jax/jax_backend.py | 4 ++-- tensornetwork/backends/jax/jax_backend_test.py | 5 ++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/tensornetwork/backends/jax/jax_backend.py b/tensornetwork/backends/jax/jax_backend.py index f3d7c04a2..2ef092eda 100644 --- a/tensornetwork/backends/jax/jax_backend.py +++ b/tensornetwork/backends/jax/jax_backend.py @@ -875,7 +875,7 @@ def sign(self, tensor: Tensor) -> Tensor: def item(self, tensor): return tensor.item() - def power(self, a: Tensor, b: Union[Tensor, float]) -> Tensor: + def power(self, a: Tensor, b: Union[Tensor, int]) -> Tensor: """ Returns power of tensor a to the value of b. In the case b is a tensor, then the power is by element @@ -887,4 +887,4 @@ def power(self, a: Tensor, b: Union[Tensor, float]) -> Tensor: a: The tensor that contains the base. b: The tensor that contains the exponent or a single scalar. """ - return jnp.square(a, b) + return jnp.power(a, b) diff --git a/tensornetwork/backends/jax/jax_backend_test.py b/tensornetwork/backends/jax/jax_backend_test.py index 6050c6271..87b270e02 100644 --- a/tensornetwork/backends/jax/jax_backend_test.py +++ b/tensornetwork/backends/jax/jax_backend_test.py @@ -1247,11 +1247,10 @@ def test_power(dtype): base_tensor = backend.randn(shape, dtype=dtype, seed=10) power_tensor = backend.randn(shape, dtype=dtype, seed=10) actual = backend.power(base_tensor, power_tensor) - expected = tf.math.pow(base_tensor, power_tensor) + expected = jax.numpy.power(base_tensor, power_tensor) np.testing.assert_allclose(expected, actual) power = np.random.rand(1)[0] actual = backend.power(base_tensor, power) - expected = tf.math.pow(base_tensor, power) + expected = jax.numpy.power(base_tensor, power) np.testing.assert_allclose(expected, actual) - \ No newline at end of file From 195f69198d4cc1a626ca43bba07bab05ece994fd Mon Sep 17 00:00:00 2001 From: LuiGiovanni Date: Sat, 5 Dec 2020 23:57:51 -0700 Subject: [PATCH 08/14] Added NotImplementedError function and it's respective test for Cholskey decomposition. --- tensornetwork/backends/abstract_backend.py | 4 ++++ tensornetwork/backends/backend_test.py | 6 ++++++ 2 files changed, 10 insertions(+) diff --git a/tensornetwork/backends/abstract_backend.py b/tensornetwork/backends/abstract_backend.py index 4e57d366e..e72d5a14f 100644 --- a/tensornetwork/backends/abstract_backend.py +++ b/tensornetwork/backends/abstract_backend.py @@ -1022,3 +1022,7 @@ def item(self, tensor) -> Union[float, int, complex]: The value in tensor. """ raise NotImplementedError("Backend {self.name} has not implemented item") + + def chsky(self, tensor: Tensor, pivot_axis: int = -1, non_negative_diagonal: bool = False) -> Tuple[Tensor, Tensor]: + """Computes the Cholskey decomposition of a tensor.""" + raise NotImplementedError("Backend '{}' has not implemented chsky.".format(self.name)) diff --git a/tensornetwork/backends/backend_test.py b/tensornetwork/backends/backend_test.py index 6230361a1..2aef64732 100644 --- a/tensornetwork/backends/backend_test.py +++ b/tensornetwork/backends/backend_test.py @@ -193,6 +193,12 @@ def test_abstract_backend_qr_decompositon_not_implemented(): backend.qr(np.ones((2, 2)), 0) +def test_abstract_backend_chsky_decompositon_not_implemented(): + backend = AbstractBackend() + with pytest.raises(NotImplementedError): + backend.chsky(np.ones((2, 2)), 0) + + def test_abstract_backend_rq_decompositon_not_implemented(): backend = AbstractBackend() with pytest.raises(NotImplementedError): From 737f07149cbb3cc09399d30b3035f02171954a56 Mon Sep 17 00:00:00 2001 From: LuiGiovanni Date: Sun, 6 Dec 2020 00:23:20 -0700 Subject: [PATCH 09/14] Fixed line too long error --- tensornetwork/backends/abstract_backend.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tensornetwork/backends/abstract_backend.py b/tensornetwork/backends/abstract_backend.py index e72d5a14f..31ff08921 100644 --- a/tensornetwork/backends/abstract_backend.py +++ b/tensornetwork/backends/abstract_backend.py @@ -1023,6 +1023,9 @@ def item(self, tensor) -> Union[float, int, complex]: """ raise NotImplementedError("Backend {self.name} has not implemented item") - def chsky(self, tensor: Tensor, pivot_axis: int = -1, non_negative_diagonal: bool = False) -> Tuple[Tensor, Tensor]: + def chsky(self, tensor: Tensor, + pivot_axis: int = -1, + non_negative_diagonal: bool = False) -> + Tuple[Tensor, Tensor]: """Computes the Cholskey decomposition of a tensor.""" raise NotImplementedError("Backend '{}' has not implemented chsky.".format(self.name)) From 0fbe8fb3cb5a94b219280f17e65abbfebccdb83f Mon Sep 17 00:00:00 2001 From: LuiGiovanni Date: Sun, 6 Dec 2020 00:59:34 -0700 Subject: [PATCH 10/14] Removing changes for different branches. --- tensornetwork/backends/jax/jax_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensornetwork/backends/jax/jax_backend.py b/tensornetwork/backends/jax/jax_backend.py index 2ef092eda..d444653b1 100644 --- a/tensornetwork/backends/jax/jax_backend.py +++ b/tensornetwork/backends/jax/jax_backend.py @@ -877,7 +877,7 @@ def item(self, tensor): def power(self, a: Tensor, b: Union[Tensor, int]) -> Tensor: """ - Returns power of tensor a to the value of b. + Returns the power of tensor a to the value of b. In the case b is a tensor, then the power is by element with a as the base and b as the exponent. In the case b is a scalar, then the power of each value in a From 8c376566cc1ab9e8151bed64edbd2fc538a802b2 Mon Sep 17 00:00:00 2001 From: LuiGiovanni Date: Sun, 6 Dec 2020 01:38:44 -0700 Subject: [PATCH 11/14] Revert "Fixed issues with the assertion in test, should work now." This reverts commit 81b4d8b6faeeb251e3ec8ef040c2ba174d074d4d. --- tensornetwork/backends/jax/jax_backend.py | 4 ++-- tensornetwork/backends/jax/jax_backend_test.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tensornetwork/backends/jax/jax_backend.py b/tensornetwork/backends/jax/jax_backend.py index d444653b1..3b6ed659a 100644 --- a/tensornetwork/backends/jax/jax_backend.py +++ b/tensornetwork/backends/jax/jax_backend.py @@ -875,7 +875,7 @@ def sign(self, tensor: Tensor) -> Tensor: def item(self, tensor): return tensor.item() - def power(self, a: Tensor, b: Union[Tensor, int]) -> Tensor: + def power(self, a: Tensor, b: Union[Tensor, float]) -> Tensor: """ Returns the power of tensor a to the value of b. In the case b is a tensor, then the power is by element @@ -887,4 +887,4 @@ def power(self, a: Tensor, b: Union[Tensor, int]) -> Tensor: a: The tensor that contains the base. b: The tensor that contains the exponent or a single scalar. """ - return jnp.power(a, b) + return jnp.square(a, b) diff --git a/tensornetwork/backends/jax/jax_backend_test.py b/tensornetwork/backends/jax/jax_backend_test.py index 87b270e02..6050c6271 100644 --- a/tensornetwork/backends/jax/jax_backend_test.py +++ b/tensornetwork/backends/jax/jax_backend_test.py @@ -1247,10 +1247,11 @@ def test_power(dtype): base_tensor = backend.randn(shape, dtype=dtype, seed=10) power_tensor = backend.randn(shape, dtype=dtype, seed=10) actual = backend.power(base_tensor, power_tensor) - expected = jax.numpy.power(base_tensor, power_tensor) + expected = tf.math.pow(base_tensor, power_tensor) np.testing.assert_allclose(expected, actual) power = np.random.rand(1)[0] actual = backend.power(base_tensor, power) - expected = jax.numpy.power(base_tensor, power) + expected = tf.math.pow(base_tensor, power) np.testing.assert_allclose(expected, actual) + \ No newline at end of file From 06c579aedb65c9447725533fe6177ab65a84d14c Mon Sep 17 00:00:00 2001 From: LuiGiovanni Date: Sun, 6 Dec 2020 12:36:35 -0700 Subject: [PATCH 12/14] Spilt up cholesky and jax modifications to diferent branches. --- tensornetwork/backends/abstract_backend.py | 7 ------- tensornetwork/backends/backend_test.py | 6 ------ tensornetwork/backends/jax/jax_backend.py | 2 +- 3 files changed, 1 insertion(+), 14 deletions(-) diff --git a/tensornetwork/backends/abstract_backend.py b/tensornetwork/backends/abstract_backend.py index 31ff08921..4e57d366e 100644 --- a/tensornetwork/backends/abstract_backend.py +++ b/tensornetwork/backends/abstract_backend.py @@ -1022,10 +1022,3 @@ def item(self, tensor) -> Union[float, int, complex]: The value in tensor. """ raise NotImplementedError("Backend {self.name} has not implemented item") - - def chsky(self, tensor: Tensor, - pivot_axis: int = -1, - non_negative_diagonal: bool = False) -> - Tuple[Tensor, Tensor]: - """Computes the Cholskey decomposition of a tensor.""" - raise NotImplementedError("Backend '{}' has not implemented chsky.".format(self.name)) diff --git a/tensornetwork/backends/backend_test.py b/tensornetwork/backends/backend_test.py index 2aef64732..6230361a1 100644 --- a/tensornetwork/backends/backend_test.py +++ b/tensornetwork/backends/backend_test.py @@ -193,12 +193,6 @@ def test_abstract_backend_qr_decompositon_not_implemented(): backend.qr(np.ones((2, 2)), 0) -def test_abstract_backend_chsky_decompositon_not_implemented(): - backend = AbstractBackend() - with pytest.raises(NotImplementedError): - backend.chsky(np.ones((2, 2)), 0) - - def test_abstract_backend_rq_decompositon_not_implemented(): backend = AbstractBackend() with pytest.raises(NotImplementedError): diff --git a/tensornetwork/backends/jax/jax_backend.py b/tensornetwork/backends/jax/jax_backend.py index 3b6ed659a..9f5eb9843 100644 --- a/tensornetwork/backends/jax/jax_backend.py +++ b/tensornetwork/backends/jax/jax_backend.py @@ -887,4 +887,4 @@ def power(self, a: Tensor, b: Union[Tensor, float]) -> Tensor: a: The tensor that contains the base. b: The tensor that contains the exponent or a single scalar. """ - return jnp.square(a, b) + return jnp.power(a, b) From 3624ccf270b44bfcaffceab9d421dc341b601dd8 Mon Sep 17 00:00:00 2001 From: LuiGiovanni Date: Sun, 6 Dec 2020 13:18:14 -0700 Subject: [PATCH 13/14] Fixed wrong power function call --- tensornetwork/backends/jax/jax_backend_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensornetwork/backends/jax/jax_backend_test.py b/tensornetwork/backends/jax/jax_backend_test.py index 6050c6271..615bba2e6 100644 --- a/tensornetwork/backends/jax/jax_backend_test.py +++ b/tensornetwork/backends/jax/jax_backend_test.py @@ -1247,11 +1247,11 @@ def test_power(dtype): base_tensor = backend.randn(shape, dtype=dtype, seed=10) power_tensor = backend.randn(shape, dtype=dtype, seed=10) actual = backend.power(base_tensor, power_tensor) - expected = tf.math.pow(base_tensor, power_tensor) + expected = jax.numpy.power(base_tensor, power_tensor) np.testing.assert_allclose(expected, actual) power = np.random.rand(1)[0] actual = backend.power(base_tensor, power) - expected = tf.math.pow(base_tensor, power) + expected = jax.numpy.power(base_tensor, power) np.testing.assert_allclose(expected, actual) \ No newline at end of file From aefd843178342f67f2fd7b3af571681db54b33a6 Mon Sep 17 00:00:00 2001 From: LuiGiovanni Date: Fri, 18 Dec 2020 13:47:03 -0700 Subject: [PATCH 14/14] Added power function for Pytorch and symmetric with their tests --- tensornetwork/backends/pytorch/pytorch_backend.py | 14 ++++++++++++++ .../backends/pytorch/pytorch_backend_test.py | 12 ++++++++++++ .../backends/symmetric/symmetric_backend.py | 14 ++++++++++++++ .../backends/symmetric/symmetric_backend_test.py | 13 +++++++++++++ 4 files changed, 53 insertions(+) diff --git a/tensornetwork/backends/pytorch/pytorch_backend.py b/tensornetwork/backends/pytorch/pytorch_backend.py index 8467fac39..fb3ae238a 100644 --- a/tensornetwork/backends/pytorch/pytorch_backend.py +++ b/tensornetwork/backends/pytorch/pytorch_backend.py @@ -475,3 +475,17 @@ def sign(self, tensor: Tensor) -> Tensor: def item(self, tensor): return tensor.item() + + def power(self, a: Tensor, b: Tensor) -> Tensor: + """ + Returns the power of tensor a to the value of b. + In the case b is a tensor, then the power is by element + with a as the base and b as the exponent. + In the case b is a scalar, then the power of each value in a + is raised to the exponent of b. + + Args: + a: The tensor that contains the base. + b: The tensor that contains the exponent or a single scalar. + """ + return a ** b diff --git a/tensornetwork/backends/pytorch/pytorch_backend_test.py b/tensornetwork/backends/pytorch/pytorch_backend_test.py index f44aed5c6..82a3e6356 100644 --- a/tensornetwork/backends/pytorch/pytorch_backend_test.py +++ b/tensornetwork/backends/pytorch/pytorch_backend_test.py @@ -566,6 +566,18 @@ def test_matmul(): np.testing.assert_allclose(expected, actual) +def test_power(): + np.random.seed(10) + backend = pytorch_backend.PyTorchBackend() + t1 = np.random.rand(10, 2, 3) + t2 = np.random.rand(10, 3, 4) + a = backend.convert_to_tensor(t1) + b = backend.convert_to_tensor(t2) + actual = backend.power(a, b) + expected = np.power(t1, t2) + np.testing.assert_allclose(expected, actual) + + @pytest.mark.parametrize("dtype", torch_randn_dtypes) @pytest.mark.parametrize("offset", range(-2, 2)) @pytest.mark.parametrize("axis1", [-2, 0]) diff --git a/tensornetwork/backends/symmetric/symmetric_backend.py b/tensornetwork/backends/symmetric/symmetric_backend.py index 9e55bc694..1b1c517e3 100644 --- a/tensornetwork/backends/symmetric/symmetric_backend.py +++ b/tensornetwork/backends/symmetric/symmetric_backend.py @@ -689,3 +689,17 @@ def matmul(self, tensor1: Tensor, tensor2: Tensor): if (tensor1.ndim != 2) or (tensor2.ndim != 2): raise ValueError("inputs to `matmul` have to be matrices") return tensor1 @ tensor2 + + def power(self, a: Tensor, b: Tensor) -> Tensor: + """ + Returns the power of tensor a to the value of b. + In the case b is a tensor, then the power is by element + with a as the base and b as the exponent. + In the case b is a scalar, then the power of each value in a + is raised to the exponent of b. + + Args: + a: The tensor that contains the base. + b: The tensor that contains the exponent or a single scalar. + """ + return a ** b diff --git a/tensornetwork/backends/symmetric/symmetric_backend_test.py b/tensornetwork/backends/symmetric/symmetric_backend_test.py index a8b78051e..f38d8bdc5 100644 --- a/tensornetwork/backends/symmetric/symmetric_backend_test.py +++ b/tensornetwork/backends/symmetric/symmetric_backend_test.py @@ -609,6 +609,19 @@ def test_addition_raises(R, dtype, num_charges): backend.addition(a, c) +@pytest.mark.parametrize("dtype", np_dtypes) +@pytest.mark.parametrize("num_charges", [1, 2]) +def test_power(dtype, num_charges): + np.random.seed(10) + R = 4 + backend = symmetric_backend.SymmetricBackend() + a = get_tensor(R, num_charges, dtype) + b = BlockSparseTensor.random(a.sparse_shape) + expected = np.power(a.data, b.data) + actual = backend.power(a.data, b.data) + np.testing.assert_allclose(expected, actual) + + @pytest.mark.parametrize("dtype", np_dtypes) @pytest.mark.parametrize("R", [2, 3, 4, 5]) @pytest.mark.parametrize("num_charges", [1, 2])