diff --git a/ivy_tests/test_ivy/helpers/assertions.py b/ivy_tests/test_ivy/helpers/assertions.py index ee1099e0577ad..a08d20bb1f678 100644 --- a/ivy_tests/test_ivy/helpers/assertions.py +++ b/ivy_tests/test_ivy/helpers/assertions.py @@ -84,6 +84,23 @@ def assert_same_type_and_shape(values, this_key_chain=None): ), "returned dtype = {}, ground-truth returned dtype = {}".format(x_d, y_d) +def assert_same_type(ret_from_target, ret_from_gt, backend_to_test, gt_backend): + """ + Assert that the return types from the target and ground truth frameworks are the + same. + + checks with a string comparison because with_backend returns + different objects. Doesn't check recursively. + """ + # ToDo: do this with nested map + assert_msg = ( + f"ground truth backend ({gt_backend}) returned" + f" {type(ret_from_gt)} but target backend ({backend_to_test}) returned" + f" {type(ret_from_target)}" + ) + assert str(type(ret_from_target)) == str(type(ret_from_gt)), assert_msg + + def value_test( *, ret_np_flat, diff --git a/ivy_tests/test_ivy/helpers/function_testing.py b/ivy_tests/test_ivy/helpers/function_testing.py index 42753710bfa3d..fbd3103fc23b0 100644 --- a/ivy_tests/test_ivy/helpers/function_testing.py +++ b/ivy_tests/test_ivy/helpers/function_testing.py @@ -29,6 +29,7 @@ from ivy_tests.test_ivy.helpers.testing_helpers import _create_transpile_report from .assertions import ( value_test, + assert_same_type, check_unsupported_dtype, ) @@ -567,17 +568,6 @@ def test_function( on_device=on_device, ) - assert ret_device == ret_from_gt_device, ( - f"ground truth backend ({test_flags.ground_truth_backend}) returned array on" - f" device {ret_from_gt_device} but target backend ({backend_to_test})" - f" returned array on device {ret_device}" - ) - if ret_device is not None: - assert ret_device == on_device, ( - f"device is set to {on_device}, but ground truth produced array on" - f" {ret_device}" - ) - # assuming value test will be handled manually in the test function if not test_values: if return_flat_np_arrays: @@ -599,6 +589,20 @@ def test_function( backend=backend_to_test, ground_truth_backend=test_flags.ground_truth_backend, ) + assert_same_type( + ret_from_target, ret_from_gt, backend_to_test, test_flags.ground_truth_backend + ) + + assert ret_device == ret_from_gt_device, ( + f"ground truth backend ({test_flags.ground_truth_backend}) returned array on" + f" device {ret_from_gt_device} but target backend ({backend_to_test})" + f" returned array on device {ret_device}" + ) + if ret_device is not None: + assert ret_device == on_device, ( + f"device is set to {on_device}, but ground truth produced array on" + f" {ret_device}" + ) def test_frontend_function(