Skip to content

Commit cf6a36f

Browse files
Rename mul -> multiply
1 parent 95acdb3 commit cf6a36f

File tree

2 files changed

+29
-13
lines changed

2 files changed

+29
-13
lines changed

pytensor/sparse/basic.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2491,29 +2491,31 @@ def infer_shape(self, fgraph, node, ins_shapes):
24912491
mul_s_v = MulSV()
24922492

24932493

2494-
def mul(x, y):
2494+
def multiply(
2495+
x: SparseTensorType | TensorType, y: SparseTensorType | TensorType
2496+
) -> SparseVariable:
24952497
"""
24962498
Multiply elementwise two matrices, at least one of which is sparse.
24972499
24982500
This method will provide the right op according to the inputs.
24992501
25002502
Parameters
25012503
----------
2502-
x
2504+
x: SparseVariable
25032505
A matrix variable.
2504-
y
2506+
y: SparseVariable
25052507
A matrix variable.
25062508
25072509
Returns
25082510
-------
2509-
A sparse matrix
2510-
`x` * `y`
2511+
result: SparseVariable
2512+
The elementwise multiplication of `x` and `y`.
25112513
25122514
Notes
25132515
-----
25142516
At least one of `x` and `y` must be a sparse matrix.
2515-
The grad is regular, i.e. not structured.
25162517
2518+
The gradient is regular, i.e. not structured.
25172519
"""
25182520

25192521
x = as_sparse_or_tensor_variable(x)
@@ -2541,6 +2543,20 @@ def mul(x, y):
25412543
raise NotImplementedError()
25422544

25432545

2546+
def mul(x, y):
2547+
warn(
2548+
"pytensor.sparse.mul is deprecated and will be removed in a future version. Use "
2549+
"pytensor.tensor.sparse.multiply instead.",
2550+
category=DeprecationWarning,
2551+
stacklevel=2,
2552+
)
2553+
2554+
return multiply(x, y)
2555+
2556+
2557+
mul.__doc__ = multiply.__doc__
2558+
2559+
25442560
class __ComparisonOpSS(Op):
25452561
"""
25462562
Used as a superclass for all comparisons between two sparses matrices.

tests/sparse/test_basic.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@
6565
gt,
6666
le,
6767
lt,
68-
mul,
6968
mul_s_v,
69+
multiply,
7070
sampling_dot,
7171
sp_ones_like,
7272
square_diagonal,
@@ -724,21 +724,21 @@ def test_AddDS(self):
724724

725725
def test_MulSS(self):
726726
self._testSS(
727-
mul,
727+
multiply,
728728
np.array([[1.0, 0], [3, 0], [0, 6]]),
729729
np.array([[1.0, 2], [3, 0], [0, 6]]),
730730
)
731731

732732
def test_MulSD(self):
733733
self._testSD(
734-
mul,
734+
multiply,
735735
np.array([[1.0, 0], [3, 0], [0, 6]]),
736736
np.array([[1.0, 2], [3, 0], [0, 6]]),
737737
)
738738

739739
def test_MulDS(self):
740740
self._testDS(
741-
mul,
741+
multiply,
742742
np.array([[1.0, 0], [3, 0], [0, 6]]),
743743
np.array([[1.0, 2], [3, 0], [0, 6]]),
744744
)
@@ -783,7 +783,7 @@ def _testSS(
783783
assert np.all(val.todense() == array1 + array2)
784784
if dtype1.startswith("float") and dtype2.startswith("float"):
785785
verify_grad_sparse(op, [a, b], structured=False)
786-
elif op is mul:
786+
elif op is multiply:
787787
assert np.all(val.todense() == array1 * array2)
788788
if dtype1.startswith("float") and dtype2.startswith("float"):
789789
verify_grad_sparse(op, [a, b], structured=False)
@@ -833,7 +833,7 @@ def _testSD(
833833
continue
834834
if dtype1.startswith("float") and dtype2.startswith("float"):
835835
verify_grad_sparse(op, [a, b], structured=True)
836-
elif op is mul:
836+
elif op is multiply:
837837
assert _is_sparse_variable(apb)
838838
assert np.all(val.todense() == b.multiply(array1))
839839
assert np.all(
@@ -887,7 +887,7 @@ def _testDS(
887887
b = b.data
888888
if dtype1.startswith("float") and dtype2.startswith("float"):
889889
verify_grad_sparse(op, [a, b], structured=True)
890-
elif op is mul:
890+
elif op is multiply:
891891
assert _is_sparse_variable(apb)
892892
ans = np.array([[1, 0], [9, 0], [0, 36]])
893893
assert np.all(val.todense() == (a.multiply(array2)))

0 commit comments

Comments
 (0)