diff --git a/ivy/functional/frontends/paddle/tensor/tensor.py b/ivy/functional/frontends/paddle/tensor/tensor.py index fba3ef539803c..8e74ef2c5d529 100644 --- a/ivy/functional/frontends/paddle/tensor/tensor.py +++ b/ivy/functional/frontends/paddle/tensor/tensor.py @@ -269,6 +269,10 @@ def add_(self, y, name=None): self.ivy_array = paddle_frontend.add(self, y).ivy_array return self + @with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") + def addmm(self, x, y, beta=1.0, alpha=1.0, name=None): + return paddle_frontend.addmm(self, x, y, beta, alpha) + @with_supported_dtypes( {"2.5.1 and below": ("float16", "float32", "float64", "int32", "int64")}, "paddle", diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_tensor.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_tensor.py index 149ba982cdc92..59720ff9cd2ce 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_tensor.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_tensor.py @@ -16,6 +16,9 @@ from ivy_tests.test_ivy.test_functional.test_core.test_statistical import ( _statistical_dtype_values, ) +from ivy_tests.test_ivy.test_frontends.test_torch.test_blas_and_lapack_ops import ( + _get_dtype_and_3dbatch_matrices, +) CLASS_TREE = "ivy.functional.frontends.paddle.Tensor" @@ -514,6 +517,56 @@ def test_paddle_tensor_add_n( ) +# addmm +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="paddle.to_tensor", + method_name="addmm", + dtype_input_xy=_get_dtype_and_3dbatch_matrices(with_input=True, input_3d=True), + beta=st.floats( + min_value=-5, + max_value=5, + allow_nan=False, + allow_subnormal=False, + allow_infinity=False, + ), + alpha=st.floats( + min_value=-5, + max_value=5, + allow_nan=False, + allow_subnormal=False, + allow_infinity=False, + ), +) +def test_paddle_tensor_addmm( + *, + dtype_input_xy, + beta, + alpha, + frontend_method_data, + init_flags, + method_flags, + frontend, + on_device, + backend_fw, +): + input_dtype, input, x, y = dtype_input_xy + helpers.test_frontend_method( + init_input_dtypes=input_dtype, + backend_to_test=backend_fw, + init_all_as_kwargs_np={ + "data": input[0], + }, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={"x": x[0], "y": y[0], "beta": beta, "alpha": alpha}, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + frontend=frontend, + on_device=on_device, + ) + + # all @handle_frontend_method( class_tree=CLASS_TREE,