diff --git a/ivy/functional/frontends/torch/miscellaneous_ops.py b/ivy/functional/frontends/torch/miscellaneous_ops.py index 0f1d694e32e50..c5892f3b51afa 100644 --- a/ivy/functional/frontends/torch/miscellaneous_ops.py +++ b/ivy/functional/frontends/torch/miscellaneous_ops.py @@ -236,6 +236,22 @@ def gcd(input, other, *, out=None): return ivy.gcd(input, other, out=out) +@to_ivy_arrays_and_back +def histc(input, bins=100, min=0, max=0, *, out=None): + input_ivy = ivy.flatten( + input + ) # torch.histc results in a 1D tensor so I flattened it out. + if min == 0 and max == 0: + min = ivy.min(input_ivy) + max = ivy.max(input_ivy) + if min == max: + range = (min - 1e-2, min + 1e-2) + else: + range = (min, max) + + return ivy.histogram(input_ivy, bins=bins, range=range, axis=0) + + @to_ivy_arrays_and_back def kron(input, other, *, out=None): return ivy.kron(input, other, out=out) @@ -541,3 +557,22 @@ def view_as_real(input): re_part = ivy.real(input) im_part = ivy.imag(input) return ivy.stack((re_part, im_part), axis=-1) + + +@to_ivy_arrays_and_back +@with_supported_dtypes({"2.0.1 and below": ("float32", "float64")}, +def histc(input, bins=100, min=0, max=0, *, out= None): + + input_ivy = ivy.flatten(input) # torch.histc results in a 1D tensor so I flattened it out. + + if min ==0 and max ==0: + min = ivy.min(input_ivy); max = ivy.max(input_ivy) + if min == max: + min = min - 5.000000e-01 + max = max + 5.000000e-01 + + range = (min , max) + if range[0] > range[1]: + raise ivy.exceptions.IvyError("Max must be greater than or equal to min") + + return ivy.histogram(input_ivy, bins = bins, range=range, axis =0, out =out, dtype = input.dtype) diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_miscellaneous_ops.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_miscellaneous_ops.py index 992c6a7689d1e..cbebf32cb7f45 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_miscellaneous_ops.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_miscellaneous_ops.py @@ -1143,6 +1143,34 @@ def test_torch_gcd( ) +# histc +@handle_frontend_test( + fn_tree="torch.histc", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ), +) +def test_torch_histc( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + ) + + # kron @handle_frontend_test( fn_tree="torch.kron", @@ -1813,14 +1841,48 @@ def test_torch_view_as_real( on_device, fn_tree, frontend, + backend_fw, test_flags, ): input_dtype, x = dtype_and_x helpers.test_frontend_function( input_dtypes=input_dtype, frontend=frontend, + backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, input=np.asarray(x[0], dtype=input_dtype[0]), ) +<<<<<<< HEAD +======= + + +# histc +@handle_frontend_test( + fn_tree= "torch.histc", + dtype_and_x = helpers.dtype_and_values( + available_dtypes= helpers.get_dtypes("float"), + ), + ) + +def test_torch_histc( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype , x = dtype_and_x + helpers.test_frontend_function( + input_dtypes= input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags = test_flags, + fn_tree= fn_tree, + on_device=on_device, + input=np.asarray(x[0], dtype=input_dtype[0]), + ) +>>>>>>> 0e45c14f2 (Changed histc for min==max case)