Skip to content

Commit

Permalink
Merge pull request #10184 from jakevdp:merge-bcoo-dot-general
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 440178509
  • Loading branch information
jax authors committed Apr 7, 2022
2 parents 96af4d5 + 01e4fa8 commit 28cb44e
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 51 deletions.
2 changes: 0 additions & 2 deletions jax/experimental/sparse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,6 @@
bcoo_multiply_dense as bcoo_multiply_dense,
bcoo_multiply_sparse as bcoo_multiply_sparse,
bcoo_reduce_sum as bcoo_reduce_sum,
bcoo_rdot_general as bcoo_rdot_general,
bcoo_spdot_general as bcoo_spdot_general,
bcoo_spdot_general_p as bcoo_spdot_general_p,
bcoo_todense as bcoo_todense,
bcoo_todense_p as bcoo_todense_p,
Expand Down
60 changes: 19 additions & 41 deletions jax/experimental/sparse/bcoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,17 +622,31 @@ def bcoo_dot_general(lhs, rhs, *, dimension_numbers):
"""A general contraction operation.
Args:
lhs: A BCOO-format array.
rhs: An ndarray.
lhs: An ndarray or BCOO-format sparse array.
rhs: An ndarray or BCOO-format sparse array..
dimension_numbers: a tuple of tuples of the form
`((lhs_contracting_dims, rhs_contracting_dims),
(lhs_batch_dims, rhs_batch_dims))`.
Returns:
An ndarray containing the result.
An ndarray or BCOO-format sparse array containing the result. If both inputs
are sparse, the result will be sparse, of type BCOO. If either input is dense,
the result will be dense, of type ndarray.
"""
return _bcoo_dot_general(*lhs._bufs, rhs, dimension_numbers=dimension_numbers,
lhs_spinfo=lhs._info)
if isinstance(lhs, BCOO) and isinstance(rhs, BCOO):
shape = _dot_general_validated_shape(lhs.shape, rhs.shape, dimension_numbers)
bufs = _bcoo_spdot_general(lhs.data, lhs.indices, rhs.data, rhs.indices,
lhs_spinfo=lhs._info, rhs_spinfo=rhs._info,
dimension_numbers=dimension_numbers)
return BCOO(bufs, shape=shape)
elif isinstance(lhs, BCOO):
return _bcoo_dot_general(*lhs._bufs, rhs, dimension_numbers=dimension_numbers,
lhs_spinfo=lhs._info)
elif isinstance(rhs, BCOO):
return _bcoo_rdot_general(lhs, *rhs._bufs, dimension_numbers=dimension_numbers,
rhs_spinfo=rhs._info)
else:
return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers)

def _bcoo_dot_general(lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_spinfo: BCOOInfo):
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
Expand All @@ -644,23 +658,6 @@ def _bcoo_dot_general(lhs_data, lhs_indices, rhs, *, dimension_numbers, lhs_spin
dimension_numbers=(cdims, bdims),
lhs_spinfo=lhs_spinfo)

def bcoo_rdot_general(lhs, rhs, *, dimension_numbers: DotDimensionNumbers):
"""A general contraction operation.
Args:
lhs: An ndarray.
rhs: A BCOO-format array.
dimension_numbers: a tuple of tuples of the form
`((lhs_contracting_dims, rhs_contracting_dims),
(lhs_batch_dims, rhs_batch_dims))`.
Returns:
An ndarray containing the result.
"""
return _bcoo_rdot_general(lhs, rhs.data, rhs.indices,
dimension_numbers=dimension_numbers,
rhs_spinfo=rhs._info)

def _bcoo_rdot_general(lhs, rhs_data, rhs_indices, *, dimension_numbers: DotDimensionNumbers, rhs_spinfo: BCOOInfo):
# TODO(jakevdp): perhaps this should be part of the bcoo_dot_general primitive?
result = _bcoo_dot_general(rhs_data, rhs_indices, lhs, lhs_spinfo=rhs_spinfo,
Expand Down Expand Up @@ -1017,25 +1014,6 @@ def impl(A, B, indices):
bcoo_spdot_general_p = core.Primitive('bcoo_spdot_general')
bcoo_spdot_general_p.multiple_results = True

def bcoo_spdot_general(lhs, rhs, *, dimension_numbers: DotDimensionNumbers):
"""A general contraction operation.
Args:
lhs: A BCOO-format array.
rhs: A BCOO-format array.
dimension_numbers: a tuple of tuples of the form
`((lhs_contracting_dims, rhs_contracting_dims),
(lhs_batch_dims, rhs_batch_dims))`.
Returns:
A BCOO array containing the result.
"""
shape = _dot_general_validated_shape(lhs.shape, rhs.shape, dimension_numbers)
data, indices = _bcoo_spdot_general(lhs.data, lhs.indices, rhs.data, rhs.indices,
lhs_spinfo=lhs._info, rhs_spinfo=rhs._info,
dimension_numbers=dimension_numbers)
return BCOO((data, indices), shape=shape)

def _bcoo_spdot_general(lhs_data, lhs_indices, rhs_data, rhs_indices, *, lhs_spinfo: BCOOInfo, rhs_spinfo: BCOOInfo, dimension_numbers: DotDimensionNumbers):
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
cdims = (api_util._ensure_index_tuple(lhs_contract),
Expand Down
10 changes: 2 additions & 8 deletions jax/experimental/sparse/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,14 +434,8 @@ def func(spenv, *spvalues, **kwargs):
def _dot_general_sparse(spenv, *spvalues, dimension_numbers, precision, preferred_element_type):
# TODO(jakevdp): pass along these unused configurations?
del precision, preferred_element_type # unused
if spvalues[0].is_sparse() and spvalues[1].is_sparse():
func = sparse.bcoo_spdot_general
elif spvalues[0].is_sparse():
func = sparse.bcoo_dot_general
else:
func = sparse.bcoo_rdot_general
A, B = spvalues_to_arrays(spenv, spvalues)
result = func(A, B, dimension_numbers=dimension_numbers)
result = sparse.bcoo_dot_general(*spvalues_to_arrays(spenv, spvalues),
dimension_numbers=dimension_numbers)
return arrays_to_spvalues(spenv, [result])

sparse_rules[lax.dot_general_p] = _dot_general_sparse
Expand Down
3 changes: 3 additions & 0 deletions tests/sparse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1815,6 +1815,9 @@ def test_bcoo_bad_fillvals(self):
# bcoo_dot_general
self.assertArraysEqual(x_sp @ y_de, x_de @ y_de)

# bcoo_rdot_general
self.assertArraysEqual(x_de @ y_sp, x_de @ y_de)

# bcoo_spdot_general
self.assertArraysEqual((x_sp @ y_sp).todense(), x_de @ y_de)
self.assertArraysEqual((y_sp @ x_sp).todense(), y_de @ x_de)
Expand Down

0 comments on commit 28cb44e

Please sign in to comment.