diff --git a/ivy/functional/frontends/torch/creation_ops.py b/ivy/functional/frontends/torch/creation_ops.py index b275cfb9d5db8..f41731483fe24 100644 --- a/ivy/functional/frontends/torch/creation_ops.py +++ b/ivy/functional/frontends/torch/creation_ops.py @@ -4,7 +4,10 @@ to_ivy_arrays_and_back, to_ivy_shape, ) -from ivy.func_wrapper import with_unsupported_dtypes +from ivy.func_wrapper import ( + with_unsupported_dtypes, + with_supported_dtypes, +) import ivy.functional.frontends.torch as torch_frontend @@ -71,6 +74,24 @@ def asarray( return ivy.asarray(obj, copy=copy, dtype=dtype, device=device) +@with_supported_dtypes({"2.0.1 and below": ("float32", "float64")}, "torch") +@to_ivy_arrays_and_back +def complex( + real, + imag, + *, + out=None, +): + assert real.dtype == imag.dtype, ValueError( + "Expected real and imag to have the same dtype, " + f" but got real.dtype = {real.dtype} and imag.dtype = {imag.dtype}." + ) + + complex_dtype = ivy.complex64 if real.dtype != ivy.float64 else ivy.complex128 + complex_array = real + imag * 1j + return complex_array.astype(complex_dtype, out=out) + + @to_ivy_arrays_and_back def empty( *args, @@ -234,6 +255,17 @@ def ones_like_v_0p4p0_and_above( return ret +@with_supported_dtypes({"2.0.1 and below": ("float32", "float64")}, "torch") +@to_ivy_arrays_and_back +def polar( + abs, + angle, + *, + out=None, +): + return complex(abs * angle.cos(), abs * angle.sin(), out=out) + + @to_ivy_arrays_and_back @with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, "torch") def range( diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_creation_ops.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_creation_ops.py index ebcbee64b678d..781881e18e665 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_creation_ops.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_creation_ops.py @@ -150,6 +150,61 @@ def _start_stop_step(draw): # ------------ # +# complex +@handle_frontend_test( + fn_tree="torch.complex", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), +) +def test_complex( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, input = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + real=input[0], + imag=input[0], + ) + + +# polar +@handle_frontend_test( + fn_tree="torch.polar", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), + test_with_out=st.just(False), +) +def test_polar( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, input = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + abs=input[0], + angle=input[0], + ) + + # arange @handle_frontend_test( fn_tree="torch.arange",