diff --git a/ivy/functional/frontends/torch/tensor.py b/ivy/functional/frontends/torch/tensor.py index 13a056436f6a7..f784dccd8fd95 100644 --- a/ivy/functional/frontends/torch/tensor.py +++ b/ivy/functional/frontends/torch/tensor.py @@ -1893,6 +1893,29 @@ def index_fill(self, dim, index, value): arr = torch_frontend.moveaxis(self, 0, dim) return arr + @with_unsupported_dtypes( + { + "2.0.1 and below": ( + "bfloat16", + "int8", + "uint8", + "uint32", + "uint16", + "uint64", + "int16", + "float16", + "complex128", + "complex64", + "bool", + ) + }, + "torch", + ) + def unique_consecutive(self, return_inverse, return_counts, dim): + return torch_frontend.unique_consecutive( + self, return_inverse, return_counts, dim + ) + @with_unsupported_dtypes( { "2.0.1 and below": ( diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py index 76053d9758076..2fc81319b5ef7 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py @@ -12816,3 +12816,48 @@ def test_torch_unique( frontend=frontend, on_device=on_device, ) + + +# unique_consecutive +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="torch.tensor", + method_name="unique_consecutive", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=2, + min_dim_size=2, + force_int_axis=True, + valid_axis=True, + ), + return_inverse=st.booleans(), + return_counts=st.booleans(), +) +def test_torch_unique_consecutive( + dtype_x_axis, + return_inverse, + return_counts, + frontend, + frontend_method_data, + init_flags, + method_flags, + on_device, + backend_fw, +): + input_dtype, x, axis = dtype_x_axis + helpers.test_frontend_method( + init_input_dtypes=input_dtype, + backend_to_test=backend_fw, + init_all_as_kwargs_np={"data": x[0]}, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={ + "return_inverse": return_inverse, + "return_counts": return_counts, + "dim": axis, + }, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + frontend=frontend, + on_device=on_device, + )