diff --git a/src/mlagility/analysis/analysis.py b/src/mlagility/analysis/analysis.py index 4f273e0e..8cbd63e6 100644 --- a/src/mlagility/analysis/analysis.py +++ b/src/mlagility/analysis/analysis.py @@ -89,7 +89,11 @@ def explore_invocation( # Update status to "computing" invocation_info.status_message = "Computing..." invocation_info.status_message_color = printing.Colors.OKBLUE - status.update(tracer_args.models_found) + + build_name = filesystem.get_build_name( + tracer_args.script_name, tracer_args.labels, invocation_info.hash + ) + status.update(tracer_args.models_found, build_name, tracer_args.cache_dir) # Get a copy of the keyword arguments args, kwargs = model_inputs @@ -117,10 +121,6 @@ def explore_invocation( inputs[all_args[i]] = args[i] invocation_info.inputs = inputs - build_name = filesystem.get_build_name( - tracer_args.script_name, tracer_args.labels, invocation_info.hash - ) - # Save model labels tracer_args.labels["class"] = [f"{type(model_info.model).__name__}"] labels.save_to_cache(tracer_args.cache_dir, build_name, tracer_args.labels) @@ -188,7 +188,7 @@ def explore_invocation( # Ensure that stdout is not being forwarded before updating status if hasattr(sys.stdout, "terminal"): sys.stdout = sys.stdout.terminal - status.update(tracer_args.models_found) + status.update(tracer_args.models_found, build_name, tracer_args.cache_dir) if tracer_args.device == "groq": import groqflow.common.build as groq_build @@ -510,7 +510,10 @@ def forward_spy(*args, **kwargs): # Ensure that groqit() doesn't interfere with our execution count model_info.executed = 1 - status.update(tracer_args.models_found) + build_name = filesystem.get_build_name( + tracer_args.script_name, tracer_args.labels, invocation_info.hash + ) + status.update(tracer_args.models_found, build_name, tracer_args.cache_dir) # Turn tracing on again after computing the outputs sys.setprofile(tracer) diff --git a/src/mlagility/analysis/status.py b/src/mlagility/analysis/status.py index ec88ca56..3fac0e13 100644 --- a/src/mlagility/analysis/status.py +++ b/src/mlagility/analysis/status.py @@ -4,8 +4,9 @@ import onnxflow.common.build as build from mlagility.analysis.util import ModelInfo - -def update(models_found: Dict[str, ModelInfo]) -> None: +def update(models_found: Dict[str, ModelInfo], + build_name: str, + cache_dir: str) -> None: """ Prints all models and submodels found """ @@ -16,11 +17,13 @@ def update(models_found: Dict[str, ModelInfo]) -> None: "\nModels discovered during profiling:\n", c=printing.Colors.BOLD, ) - recursive_print(models_found, None, None, []) + recursive_print(models_found, build_name, cache_dir, None, None, []) def recursive_print( models_found: Dict[str, ModelInfo], + build_name: str, + cache_dir: str, parent_model_hash: Union[str, None] = None, parent_invocation_hash: Union[str, None] = None, script_names_visited: List[str] = False, @@ -50,6 +53,8 @@ def recursive_print( print_invocation( model_info, + build_name, + cache_dir, invocation_hash, print_file_name, invocation_idx=invocation_idx, @@ -63,6 +68,8 @@ def recursive_print( recursive_print( models_found, + build_name, + cache_dir, parent_model_hash=model_hash, parent_invocation_hash=invocation_hash, script_names_visited=script_names_visited, @@ -71,6 +78,8 @@ def recursive_print( def print_invocation( model_info: ModelInfo, + build_name: str, + cache_dir: str, invocation_hash: Union[str, None], print_file_name: bool = False, invocation_idx: int = 0, @@ -133,6 +142,7 @@ def print_invocation( print(f"{ident}\tInput Shape:\t{input_shape}") print(f"{ident}\tHash:\t\t" + invocation_hash) + print(f"{ident}\tBuild dir:\t" + cache_dir + '/' + build_name) # Print benchit results if benchit was run if unique_invocation.performance: diff --git a/src/mlagility/api/report.py b/src/mlagility/api/report.py index b407849f..9d5ae5cc 100644 --- a/src/mlagility/api/report.py +++ b/src/mlagility/api/report.py @@ -2,23 +2,23 @@ APIs to help with programmatically parsing the MLAgility report """ -from typing import Dict +from typing import Dict, List import pandas as pd - -def get_dict(report_csv: str, column: str) -> Dict[str, str]: +def get_dict(report_csv: str, columns: List[str]) -> Dict[str, Dict[str, str]]: """ - Returns a dictionary where the keys are model names and the values are the values of - a given column in a report.csv file. - + Returns a dictionary where the keys are model names and the values are dictionaries. + Each dictionary represents a model with column names as keys and their corresponding values. args: - report_csv: path to a report.csv file generated by benchit - - column: name of a column in the report.csv file whose values will be used to + - columns: list of column names in the report.csv file whose values will be used to populate the dictionary """ - # Open the MLAgility report as a dataframe + # Load the MLAgility report as a dataframe dataframe = pd.read_csv(report_csv) - # Convert to dictionary - return pd.Series(dataframe[column].values, index=dataframe.model_name).to_dict() + # Create a nested dictionary with model_name as keys and another dictionary of {column: value} pairs as values + result = {row[0]: row[1].to_dict() for row in dataframe.set_index("model_name")[columns].iterrows()} + + return result diff --git a/test/cli.py b/test/cli.py index 97061152..10292f06 100644 --- a/test/cli.py +++ b/test/cli.py @@ -19,7 +19,7 @@ from mlagility.cli.cli import main as benchitcli import mlagility.cli.report as report import mlagility.api.report as report_api -from mlagility.common import filesystem +import mlagility.common.filesystem as filesystem import mlagility.api.ortmodel as ortmodel import mlagility.api.trtmodel as trtmodel import onnxflow.common.build as build @@ -436,11 +436,11 @@ def test_004_cli_report(self): ), f"onnx_converted must be True, got {linear_summary['onnx_converted']}" # Make sure the report.get_dict() API works - result_dict = report_api.get_dict(summary_csv_path, "onnx_exported") - for result in result_dict.values(): + result_dict = report_api.get_dict(summary_csv_path, ["onnx_exported"]) + for model_name, result in result_dict.items(): # All of the models should have exported to ONNX, so the "onnx_exported" value # should be True for all of them - assert result + assert result.get("onnx_exported") is True def test_005_cli_list(self): # NOTE: this is not a unit test, it relies on other command