From 674efa6751667eaa2a8010e02fe8e46d1acc5017 Mon Sep 17 00:00:00 2001 From: Hadeer Arafa Date: Sat, 2 Sep 2023 17:42:06 +0300 Subject: [PATCH] implement lstsq --- ivy/functional/frontends/tensorflow/linalg.py | 43 ++++++++++++++++++ .../test_tensorflow/test_linalg.py | 44 +++++++++++++++++++ 2 files changed, 87 insertions(+) diff --git a/ivy/functional/frontends/tensorflow/linalg.py b/ivy/functional/frontends/tensorflow/linalg.py index dfd8e20537eb4..07191f1d3d600 100644 --- a/ivy/functional/frontends/tensorflow/linalg.py +++ b/ivy/functional/frontends/tensorflow/linalg.py @@ -177,6 +177,49 @@ def logdet(matrix, name=None): return ivy.det(matrix).log() +@to_ivy_arrays_and_back +@with_supported_dtypes( + {"2.13.0 and below": ("float32", "float64", "complex64")}, "tensorflow" +) +def lstsq( + matrix, + rhs, + l2_regularizer=0.0, + fast=True, +): + matrix_num_dim = matrix.get_num_dims() + rhs_num_dim = rhs.get_num_dims() + if matrix_num_dim < 2: + raise RuntimeError("input must have at least 2 dimensions. ") + if matrix_num_dim - rhs_num_dim <= 1: + for i in range( + matrix_num_dim - 1 + ): # should have the same batch shape and same m shape + if matrix.shape[i] != rhs.shape[i]: + raise RuntimeError(f" input.size({i}) should match other.size({i})") + else: + raise RuntimeError( + "input.dim() must be greater or equal to other.dim() and (input.dim() -" + " other.dim()) <= 1" + ) + if l2_regularizer != 0: + raise NotImplementedError( + "linalg.lstsq is currently disabled for complex128 and l2_regularizer != 0" + " due to poor accuracy." + ) + + matrix_dtype = matrix.dtype + matrix = ivy.astype(matrix, ivy.float64) + rhs = ivy.astype(rhs, ivy.float64) + + q, r = ivy.qr(matrix) + r_inv = ivy.pinv(r) + solution = ivy.matmul(ivy.matmul(r_inv, ivy.matrix_transpose(q)), rhs) + solution = ivy.astype(solution, matrix_dtype) + + return solution + + @to_ivy_arrays_and_back def lu_matrix_inverse(lower_upper, perm, validate_args=False, name=None): return ivy.lu_matrix_inverse( diff --git a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_linalg.py b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_linalg.py index c33001f11444f..134e194fb71b9 100644 --- a/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_linalg.py +++ b/ivy_tests/test_ivy/test_frontends/test_tensorflow/test_linalg.py @@ -93,6 +93,22 @@ def _get_dtype_and_rank_2k_tensors(draw): ) +@st.composite +def _get_dtype_and_same_dim_matrix(draw): + randam_shape = draw(helpers.get_shape(min_num_dims=2, max_num_dims=4)) + dtype_and_values = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes(kind="valid", full=False), + num_arrays=2, + shape=randam_shape, + shared_dtype=True, + min_value=1, + max_value=4, + ) + ) + return dtype_and_values + + @st.composite def _get_dtype_and_sequence_of_arrays(draw): array_dtype = draw(helpers.get_dtypes("float", full=False)) @@ -747,6 +763,34 @@ def test_tensorflow_logdet( ) +@handle_frontend_test( + fn_tree="tensorflow.linalg.lstsq", dtype_x=_get_dtype_and_same_dim_matrix() +) +def test_tensorflow_lstsq( + *, + dtype_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, values = dtype_x + test_flags.num_positional_args = 2 + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + test_flags=test_flags, + frontend=frontend, + fn_tree=fn_tree, + on_device=on_device, + rtol=1e-03, + atol=1e-03, + a=values[0], + b=values[1], + ) + + @handle_frontend_test( fn_tree="tensorflow.linalg.matmul", dtype_x=helpers.dtype_and_values(