From 083ebb35b5039fa2ed3e7ce98f285c56ec675c49 Mon Sep 17 00:00:00 2001 From: Ishtiaq Hussain <53497039+Ishticode@users.noreply.github.com> Date: Tue, 19 Dec 2023 19:31:18 +0000 Subject: [PATCH] feat: add torch frontend tensor methods index_put and its inplace version along with tests with some todos regarding dtype specification based on https://pytorch.org/docs/stable/generated/torch.Tensor.index_put_.html --- ivy/functional/frontends/torch/tensor.py | 15 +++ .../test_frontends/test_torch/test_tensor.py | 98 +++++++++++++++++++ 2 files changed, 113 insertions(+) diff --git a/ivy/functional/frontends/torch/tensor.py b/ivy/functional/frontends/torch/tensor.py index 23a45a28f533d..46f95c5ec0641 100644 --- a/ivy/functional/frontends/torch/tensor.py +++ b/ivy/functional/frontends/torch/tensor.py @@ -2213,6 +2213,21 @@ def rad2deg(self, *, out=None): def corrcoef(self): return torch_frontend.corrcoef(self) + def index_put(self, indices, values, accumulate=False): + ret = self.clone() + if accumulate: + ret[indices[0]] += values + else: + ret[indices[0]] = values + return ret + + def index_put_(self, indices, values, accumulate=False): + if accumulate: + self[indices] += values + else: + self[indices] = values + return self + # Method aliases absolute, absolute_ = abs, abs_ clip, clip_ = clamp, clamp_ diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py index ad382206bb170..ab4c97560d58c 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py @@ -7692,6 +7692,104 @@ def test_torch_index_fill( ) +# todo: remove dtype specifications +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="torch.tensor", + method_name="index_put", + x_and_indices=helpers.array_indices_axis( + array_dtypes=st.just(("float32",)), + indices_dtypes=st.just(("int64",)), + ), + values=helpers.dtype_and_values( + available_dtypes=st.just(("float32",)), max_num_dims=1, max_dim_size=1 + ), + accumulate=st.booleans(), +) +def test_torch_index_put( + x_and_indices, + values, + accumulate, + frontend_method_data, + init_flags, + method_flags, + frontend, + on_device, + backend_fw, +): + input_dtype, x, indices, *_ = x_and_indices + values_dtype, values = values + init_dtypes = [input_dtype[0]] + method_dtypes = [input_dtype[1], values_dtype[0]] + helpers.test_frontend_method( + init_input_dtypes=init_dtypes, + backend_to_test=backend_fw, + init_all_as_kwargs_np={ + "data": x, + }, + method_input_dtypes=method_dtypes, + method_all_as_kwargs_np={ + "indices": (indices,), + "values": values[0], + "accumulate": accumulate, + }, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + frontend=frontend, + on_device=on_device, + ) + + +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="torch.tensor", + method_name="index_put_", + x_and_indices=helpers.array_indices_axis( + array_dtypes=st.just(("float32",)), + indices_dtypes=st.just(("int64",)), + ), + values=helpers.dtype_and_values( + available_dtypes=st.just(("float32",)), max_num_dims=1, max_dim_size=1 + ), + accumulate=st.booleans(), + test_inplace=st.just(True), +) +def test_torch_index_put_( + x_and_indices, + values, + accumulate, + frontend_method_data, + init_flags, + method_flags, + frontend, + on_device, + backend_fw, +): + input_dtype, x, indices, *_ = x_and_indices + values_dtype, values = values + init_dtypes = [input_dtype[0]] + method_dtypes = [input_dtype[1], values_dtype[0]] + helpers.test_frontend_method( + init_input_dtypes=init_dtypes, + backend_to_test=backend_fw, + init_all_as_kwargs_np={ + "data": x, + }, + method_input_dtypes=method_dtypes, + method_all_as_kwargs_np={ + "indices": (indices,), + "values": values[0], + "accumulate": accumulate, + }, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + frontend=frontend, + on_device=on_device, + ) + + # index_select @handle_frontend_method( class_tree=CLASS_TREE,