From 878deb27b430f96e0420a96d171210b094efe50e Mon Sep 17 00:00:00 2001 From: roudik Date: Thu, 14 Sep 2023 16:25:55 +0400 Subject: [PATCH] Implemented logical_xor --- ivy/functional/frontends/torch/tensor.py | 14 +++++++ .../test_frontends/test_torch/test_tensor.py | 40 ++++++++++++++++++- 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/ivy/functional/frontends/torch/tensor.py b/ivy/functional/frontends/torch/tensor.py index 20c0c5326a8df..40348d2bf6698 100644 --- a/ivy/functional/frontends/torch/tensor.py +++ b/ivy/functional/frontends/torch/tensor.py @@ -437,6 +437,20 @@ def logical_not_(self): def logical_or(self, other): return torch_frontend.logical_or(self, other) + @with_unsupported_dtypes( + { + "2.0.1 and below": ( + "bfloat16", + "uint64", + "uint32", + "uint16", + ) + }, + "torch", + ) + def logical_xor(self, other): + return torch_frontend.logical_xor(self, other) + def bitwise_not(self): return torch_frontend.bitwise_not(self) diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py index cdd54b5cf2ccd..da5520ddc71fe 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py @@ -8536,7 +8536,7 @@ def test_torch_tensor_logical_not_( init_tree="torch.tensor", method_name="logical_or", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), + available_dtypes=helpers.get_dtypes("bool", "integer"), num_arrays=2, ), ) @@ -8568,6 +8568,44 @@ def test_torch_tensor_logical_or( ) +# logical_xor +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="torch.tensor", + method_name="logical_xor", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=2, + ), +) +def test_torch_tensor_logical_xor( + dtype_and_x, + frontend_method_data, + init_flags, + method_flags, + frontend, + on_device, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_method( + init_input_dtypes=input_dtype, + backend_to_test=backend_fw, + init_all_as_kwargs_np={ + "data": x[0], + }, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={ + "other": x[1], + }, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + frontend=frontend, + on_device=on_device, + ) + + # long @handle_frontend_method( class_tree=CLASS_TREE,