diff --git a/ivy_tests/test_ivy/helpers/function_testing.py b/ivy_tests/test_ivy/helpers/function_testing.py index 28bf2ad96d175..ee97da28a12ff 100644 --- a/ivy_tests/test_ivy/helpers/function_testing.py +++ b/ivy_tests/test_ivy/helpers/function_testing.py @@ -954,13 +954,16 @@ def test_frontend_function( frontend_fw_fn = frontend_fw.__dict__[gt_fn_name] frontend_ret = frontend_fw_fn(*args_frontend, **kwargs_frontend) - # ToDo: only traces and does inference on ivy arrays for now - if test_flags.transpile and hasattr(frontend_config, "backend_str"): + if test_flags.transpile: _get_transpiled_data_if_required( frontend_fn, frontend_fw_fn, frontend, + backend_to_test, fn_name=f"{gt_frontend_submods}.{gt_fn_name}", + generate_frontend_arrays=test_flags.generate_frontend_arrays, + args_for_test=args_for_test, + kwargs_for_test=kwargs_for_test, frontend_fw_args=args_frontend, frontend_fw_kwargs=kwargs_frontend, ) @@ -2437,31 +2440,41 @@ def _get_transpiled_data_if_required( frontend_fn, frontend_fw_fn, frontend, + backend, fn_name, + generate_frontend_arrays, + args_for_test, + kwargs_for_test, frontend_fw_args, frontend_fw_kwargs, ): iterations = 1 - - # to trace the frontend function on ivy arrays - with BackendHandler.update_backend(frontend) as ivy_backend: - args, kwargs = ivy_backend.args_to_ivy(*frontend_fw_args, **frontend_fw_kwargs) - + # for backend transpilation + with BackendHandler.update_backend(backend) as ivy_backend: + if generate_frontend_arrays: + args_for_test, kwargs_for_test = ivy.nested_map( + _frontend_array_to_ivy, + (args_for_test, kwargs_for_test), + include_derived={"tuple": True}, + ) + else: + args_for_test, kwargs_for_test = ivy_backend.args_to_ivy( + *args_for_test, **kwargs_for_test + ) traced_fn = traced_if_required( - frontend, + backend, frontend_fn, test_trace=True, - args=args, - kwargs=kwargs, + args=args_for_test, + kwargs=kwargs_for_test, ) - # running inference to get runtime frontend_timings = [] frontend_fw_timings = [] for i in range(0, iterations): # timing the traced_fn start = time.time() - traced_fn(*args, **kwargs) + traced_fn(*args_for_test, **kwargs_for_test) end = time.time() frontend_timings.append(end - start) @@ -2471,12 +2484,11 @@ def _get_transpiled_data_if_required( end = time.time() frontend_fw_timings.append(end - start) - # trace to get ivy nodes - with BackendHandler.update_backend(frontend) as ivy_backend: + # compile to get ivy nodes + with BackendHandler.update_backend(backend) as ivy_backend: traced_fn_to_ivy = ivy_backend.trace_graph( - frontend_fn, to="ivy", args=args, kwargs=kwargs + frontend_fn, to="ivy", args=args_for_test, kwargs=kwargs_for_test ) - frontend_time = np.mean(frontend_timings).item() frontend_fw_time = np.mean(frontend_fw_timings).item() backend_nodes = len(traced_fn._functions) @@ -2485,6 +2497,8 @@ def _get_transpiled_data_if_required( data = { "frontend": frontend, "frontend_func": fn_name, + "args": str(args_for_test), + "kwargs": str(kwargs_for_test), "frontend_time": frontend_time, "frontend_fw_time": frontend_fw_time, "backend_nodes": backend_nodes, @@ -2492,7 +2506,7 @@ def _get_transpiled_data_if_required( } # creating json object and creating a file - _create_transpile_report(data, "report.json") + _create_transpile_report(data, backend, "report.json") def args_to_container(array_args): diff --git a/ivy_tests/test_ivy/helpers/testing_helpers.py b/ivy_tests/test_ivy/helpers/testing_helpers.py index 962318d3084af..41061e0fe3700 100644 --- a/ivy_tests/test_ivy/helpers/testing_helpers.py +++ b/ivy_tests/test_ivy/helpers/testing_helpers.py @@ -901,13 +901,27 @@ def seed(draw): return draw(st.integers(min_value=0, max_value=2**8 - 1)) -def _create_transpile_report(data: dict, file_name: str): - json_object = json.dumps(data, indent=6) +def _create_transpile_report(data: dict, backend: str, file_name: str): if os.path.isfile(file_name): with open(file_name, "r") as outfile: # Load the file's existing data - data = json.load(outfile) - if data["backend_nodes"] > data["backend_nodes"]: + file_data = json.load(outfile) + if file_data["backend_nodes"].get(backend, 0) > data["backend_nodes"]: return + file_data["backend_nodes"][backend] = data["backend_nodes"] + file_data["frontend_time"][backend] = data["frontend_time"] + file_data["args"][backend] = data["args"] + file_data["kwargs"][backend] = data["kwargs"] + file_data["ivy_nodes"] = data["ivy_nodes"] + file_data["frontend_fw_time"] = data["frontend_fw_time"] + json_object = json.dumps(file_data, indent=6) + with open(file_name, "w") as outfile: + outfile.write(json_object) + return + data["backend_nodes"] = {backend: data["backend_nodes"]} + data["frontend_time"] = {backend: data["frontend_time"]} + data["args"] = {backend: data["args"]} + data["kwargs"] = {backend: data["kwargs"]} + json_object = json.dumps(data, indent=6) with open(file_name, "w") as outfile: outfile.write(json_object)