-
Notifications
You must be signed in to change notification settings - Fork 359
Added power function to backend to new backends #889
base: master
Are you sure you want to change the base?
Changes from 15 commits
016d654
62a06c0
43952dd
b0f579e
e9f9869
9e16a9f
81b4d8b
195f691
737f071
0fbe8fb
8c37656
06c579a
3624ccf
aefd843
fa93dd6
4f135fe
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -689,3 +689,17 @@ def matmul(self, tensor1: Tensor, tensor2: Tensor): | |
if (tensor1.ndim != 2) or (tensor2.ndim != 2): | ||
raise ValueError("inputs to `matmul` have to be matrices") | ||
return tensor1 @ tensor2 | ||
|
||
def power(self, a: Tensor, b: Tensor) -> Tensor: | ||
""" | ||
Returns the power of tensor a to the value of b. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm a bit confused with this request since this test works fine for me with no failures could you elaborate on what the problem is, please? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The reason the tests passes is due to a bug in the test (I had actually missed that in the earlier reviewa, see below). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. pls add some test ensuring that |
||
In the case b is a tensor, then the power is by element | ||
with a as the base and b as the exponent. | ||
In the case b is a scalar, then the power of each value in a | ||
is raised to the exponent of b. | ||
|
||
Args: | ||
a: The tensor that contains the base. | ||
b: The tensor that contains the exponent or a single scalar. | ||
""" | ||
return a ** b |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -609,6 +609,19 @@ def test_addition_raises(R, dtype, num_charges): | |
backend.addition(a, c) | ||
|
||
|
||
@pytest.mark.parametrize("dtype", np_dtypes) | ||
@pytest.mark.parametrize("num_charges", [1, 2]) | ||
def test_power(dtype, num_charges): | ||
np.random.seed(10) | ||
R = 4 | ||
backend = symmetric_backend.SymmetricBackend() | ||
a = get_tensor(R, num_charges, dtype) | ||
b = BlockSparseTensor.random(a.sparse_shape) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Pls modify the test so that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
expected = np.power(a.data, b.data) | ||
actual = backend.power(a.data, b.data) | ||
np.testing.assert_allclose(expected, actual) | ||
|
||
|
||
@pytest.mark.parametrize("dtype", np_dtypes) | ||
@pytest.mark.parametrize("R", [2, 3, 4, 5]) | ||
@pytest.mark.parametrize("num_charges", [1, 2]) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the wording of this docstring seems somewhat imprecise. Can you fix this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the change I made to it, once we resolve the other request I will push the changes