Skip to content

Commit

Permalink
Correctly handle 0-length inputs in the numpy substrate tf.linalg.lstsq.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 604442726
  • Loading branch information
SiegeLordEx authored and tensorflower-gardener committed Feb 5, 2024
1 parent 230463a commit eb29f6e
Showing 1 changed file with 4 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,10 @@ def _lstsq(matrix, rhs, l2_regularizer=0.0, fast=True, name=None):
if JAX_MODE:
import jax # pylint: disable=g-import-not-at-top
return jax.vmap(functools.partial(_lstsq, fast=False))(matrix, rhs)
return np.array([_lstsq(mat, r, fast=False) for mat, r in zip(matrix, rhs)])
res = np.array([_lstsq(mat, r, fast=False) for mat, r in zip(matrix, rhs)])
if matrix.shape[0] == 0:
res = res.reshape(matrix.shape[:-2] + (matrix.shape[-1], rhs.shape[-1]))
return res
rcond = None
if JAX_MODE and matrix.dtype == np.float32:
rcond = 0. # https://github.com/google/jax/issues/15591
Expand Down

0 comments on commit eb29f6e

Please sign in to comment.