Skip to content

Commit

Permalink
feat(testing): transpiling on all frameworks now for generating repor…
Browse files Browse the repository at this point in the history
…t.json(#25867)
  • Loading branch information
sherry30 authored Sep 27, 2023
1 parent dfe5ea0 commit acebe2b
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 21 deletions.
48 changes: 31 additions & 17 deletions ivy_tests/test_ivy/helpers/function_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -954,13 +954,16 @@ def test_frontend_function(
frontend_fw_fn = frontend_fw.__dict__[gt_fn_name]
frontend_ret = frontend_fw_fn(*args_frontend, **kwargs_frontend)

# ToDo: only traces and does inference on ivy arrays for now
if test_flags.transpile and hasattr(frontend_config, "backend_str"):
if test_flags.transpile:
_get_transpiled_data_if_required(
frontend_fn,
frontend_fw_fn,
frontend,
backend_to_test,
fn_name=f"{gt_frontend_submods}.{gt_fn_name}",
generate_frontend_arrays=test_flags.generate_frontend_arrays,
args_for_test=args_for_test,
kwargs_for_test=kwargs_for_test,
frontend_fw_args=args_frontend,
frontend_fw_kwargs=kwargs_frontend,
)
Expand Down Expand Up @@ -2437,31 +2440,41 @@ def _get_transpiled_data_if_required(
frontend_fn,
frontend_fw_fn,
frontend,
backend,
fn_name,
generate_frontend_arrays,
args_for_test,
kwargs_for_test,
frontend_fw_args,
frontend_fw_kwargs,
):
iterations = 1

# to trace the frontend function on ivy arrays
with BackendHandler.update_backend(frontend) as ivy_backend:
args, kwargs = ivy_backend.args_to_ivy(*frontend_fw_args, **frontend_fw_kwargs)

# for backend transpilation
with BackendHandler.update_backend(backend) as ivy_backend:
if generate_frontend_arrays:
args_for_test, kwargs_for_test = ivy.nested_map(
_frontend_array_to_ivy,
(args_for_test, kwargs_for_test),
include_derived={"tuple": True},
)
else:
args_for_test, kwargs_for_test = ivy_backend.args_to_ivy(
*args_for_test, **kwargs_for_test
)
traced_fn = traced_if_required(
frontend,
backend,
frontend_fn,
test_trace=True,
args=args,
kwargs=kwargs,
args=args_for_test,
kwargs=kwargs_for_test,
)

# running inference to get runtime
frontend_timings = []
frontend_fw_timings = []
for i in range(0, iterations):
# timing the traced_fn
start = time.time()
traced_fn(*args, **kwargs)
traced_fn(*args_for_test, **kwargs_for_test)
end = time.time()
frontend_timings.append(end - start)

Expand All @@ -2471,12 +2484,11 @@ def _get_transpiled_data_if_required(
end = time.time()
frontend_fw_timings.append(end - start)

# trace to get ivy nodes
with BackendHandler.update_backend(frontend) as ivy_backend:
# compile to get ivy nodes
with BackendHandler.update_backend(backend) as ivy_backend:
traced_fn_to_ivy = ivy_backend.trace_graph(
frontend_fn, to="ivy", args=args, kwargs=kwargs
frontend_fn, to="ivy", args=args_for_test, kwargs=kwargs_for_test
)

frontend_time = np.mean(frontend_timings).item()
frontend_fw_time = np.mean(frontend_fw_timings).item()
backend_nodes = len(traced_fn._functions)
Expand All @@ -2485,14 +2497,16 @@ def _get_transpiled_data_if_required(
data = {
"frontend": frontend,
"frontend_func": fn_name,
"args": str(args_for_test),
"kwargs": str(kwargs_for_test),
"frontend_time": frontend_time,
"frontend_fw_time": frontend_fw_time,
"backend_nodes": backend_nodes,
"ivy_nodes": ivy_nodes,
}

# creating json object and creating a file
_create_transpile_report(data, "report.json")
_create_transpile_report(data, backend, "report.json")


def args_to_container(array_args):
Expand Down
22 changes: 18 additions & 4 deletions ivy_tests/test_ivy/helpers/testing_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -901,13 +901,27 @@ def seed(draw):
return draw(st.integers(min_value=0, max_value=2**8 - 1))


def _create_transpile_report(data: dict, file_name: str):
json_object = json.dumps(data, indent=6)
def _create_transpile_report(data: dict, backend: str, file_name: str):
if os.path.isfile(file_name):
with open(file_name, "r") as outfile:
# Load the file's existing data
data = json.load(outfile)
if data["backend_nodes"] > data["backend_nodes"]:
file_data = json.load(outfile)
if file_data["backend_nodes"].get(backend, 0) > data["backend_nodes"]:
return
file_data["backend_nodes"][backend] = data["backend_nodes"]
file_data["frontend_time"][backend] = data["frontend_time"]
file_data["args"][backend] = data["args"]
file_data["kwargs"][backend] = data["kwargs"]
file_data["ivy_nodes"] = data["ivy_nodes"]
file_data["frontend_fw_time"] = data["frontend_fw_time"]
json_object = json.dumps(file_data, indent=6)
with open(file_name, "w") as outfile:
outfile.write(json_object)
return
data["backend_nodes"] = {backend: data["backend_nodes"]}
data["frontend_time"] = {backend: data["frontend_time"]}
data["args"] = {backend: data["args"]}
data["kwargs"] = {backend: data["kwargs"]}
json_object = json.dumps(data, indent=6)
with open(file_name, "w") as outfile:
outfile.write(json_object)

0 comments on commit acebe2b

Please sign in to comment.