diff --git a/ivy/functional/frontends/torch/creation_ops.py b/ivy/functional/frontends/torch/creation_ops.py index 118efd4e5dab3..299619b89b030 100644 --- a/ivy/functional/frontends/torch/creation_ops.py +++ b/ivy/functional/frontends/torch/creation_ops.py @@ -7,7 +7,6 @@ from ivy.func_wrapper import ( with_unsupported_dtypes, with_supported_dtypes, - handle_out_argument, ) @@ -64,6 +63,24 @@ def asarray( return ivy.asarray(obj, copy=copy, dtype=dtype, device=device) +@with_supported_dtypes({"2.0.1 and below": ("float32", "float64")}, "torch") +@to_ivy_arrays_and_back +def complex( + real, + imag, + *, + out=None, +): + assert real.dtype == imag.dtype, ValueError( + "Expected real and imag to have the same dtype, " + f" but got real.dtype = {real.dtype} and imag.dtype = {imag.dtype}." + ) + + complex_dtype = ivy.complex64 if real.dtype != ivy.float64 else ivy.complex128 + complex_array = real + imag * 1j + return complex_array.astype(complex_dtype, out=out) + + @to_ivy_arrays_and_back def empty( *args, @@ -227,6 +244,16 @@ def ones_like_v_0p4p0_and_above( return ret +@with_supported_dtypes({"2.0.1 and below": ("float32", "float64")}, "torch") +def polar( + abs, + angle, + *, + out=None, +): + return complex(abs * angle.cos(), abs * angle.sin(), out=out) + + @to_ivy_arrays_and_back @with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, "torch") def range( @@ -301,32 +328,3 @@ def zeros_like( ): ret = ivy.zeros_like(input, dtype=dtype, device=device) return ret - - -@with_supported_dtypes({"2.0.1 and below": ("float32", "float64")}, "torch") -@to_ivy_arrays_and_back -@handle_out_argument -def complex( - real, - imag, - *, - out=None, -): - assert real.dtype == imag.dtype, ValueError( - f"Expected real and imag to have the same dtype, " - " but got real.dtype = {real.dtype} and imag.dtype = {imag.dtype}." - ) - - complex_dtype = ivy.complex64 if real.dtype != ivy.float64 else ivy.complex128 - complex_array = real + imag * 1j - return complex_array.astype(complex_dtype) - - -@with_supported_dtypes({"2.0.1 and below": ("float32", "float64")}, "torch") -def polar( - abs, - angle, - *, - out=None, -): - return complex(abs * angle.cos(), abs * angle.sin(), out=out) \ No newline at end of file diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_creation_ops.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_creation_ops.py index f4301cc80d321..1dbfabdcdacaf 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_creation_ops.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_creation_ops.py @@ -117,6 +117,60 @@ def _start_stop_step(draw): # ------------ # +# complex +@handle_frontend_test( + fn_tree="torch.complex", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), +) +def test_complex( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, input = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + real=input[0], + imag=input[0], + ) + + +# polar +@handle_frontend_test( + fn_tree="torch.polar", + dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), +) +def test_polar( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, input = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + abs=input[0], + angle=input[0], + ) + + # arange @handle_frontend_test( fn_tree="torch.arange", @@ -789,57 +843,3 @@ def test_torch_zeros_like( dtype=dtype[0], device=on_device, ) - - -# complex -@handle_frontend_test( - fn_tree="torch.complex", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), -) -def test_complex( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, input = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - real=input[0], - imag=input[0], - ) - - -# polar -@handle_frontend_test( - fn_tree="torch.polar", - dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")), -) -def test_polar( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, - backend_fw, -): - input_dtype, input = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - backend_to_test=backend_fw, - frontend=frontend, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - abs=input[0], - angle=input[0], - )