Skip to content

Commit

Permalink
Add cache and build name to status report (#333)
Browse files Browse the repository at this point in the history
  • Loading branch information
ramkrishna2910 authored Jun 23, 2023
1 parent f35e27c commit 59bcef7
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 24 deletions.
17 changes: 10 additions & 7 deletions src/mlagility/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,11 @@ def explore_invocation(
# Update status to "computing"
invocation_info.status_message = "Computing..."
invocation_info.status_message_color = printing.Colors.OKBLUE
status.update(tracer_args.models_found)

build_name = filesystem.get_build_name(
tracer_args.script_name, tracer_args.labels, invocation_info.hash
)
status.update(tracer_args.models_found, build_name, tracer_args.cache_dir)

# Get a copy of the keyword arguments
args, kwargs = model_inputs
Expand Down Expand Up @@ -117,10 +121,6 @@ def explore_invocation(
inputs[all_args[i]] = args[i]
invocation_info.inputs = inputs

build_name = filesystem.get_build_name(
tracer_args.script_name, tracer_args.labels, invocation_info.hash
)

# Save model labels
tracer_args.labels["class"] = [f"{type(model_info.model).__name__}"]
labels.save_to_cache(tracer_args.cache_dir, build_name, tracer_args.labels)
Expand Down Expand Up @@ -188,7 +188,7 @@ def explore_invocation(
# Ensure that stdout is not being forwarded before updating status
if hasattr(sys.stdout, "terminal"):
sys.stdout = sys.stdout.terminal
status.update(tracer_args.models_found)
status.update(tracer_args.models_found, build_name, tracer_args.cache_dir)

if tracer_args.device == "groq":
import groqflow.common.build as groq_build
Expand Down Expand Up @@ -510,7 +510,10 @@ def forward_spy(*args, **kwargs):
# Ensure that groqit() doesn't interfere with our execution count
model_info.executed = 1

status.update(tracer_args.models_found)
build_name = filesystem.get_build_name(
tracer_args.script_name, tracer_args.labels, invocation_info.hash
)
status.update(tracer_args.models_found, build_name, tracer_args.cache_dir)

# Turn tracing on again after computing the outputs
sys.setprofile(tracer)
Expand Down
16 changes: 13 additions & 3 deletions src/mlagility/analysis/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import onnxflow.common.build as build
from mlagility.analysis.util import ModelInfo


def update(models_found: Dict[str, ModelInfo]) -> None:
def update(models_found: Dict[str, ModelInfo],
build_name: str,
cache_dir: str) -> None:
"""
Prints all models and submodels found
"""
Expand All @@ -16,11 +17,13 @@ def update(models_found: Dict[str, ModelInfo]) -> None:
"\nModels discovered during profiling:\n",
c=printing.Colors.BOLD,
)
recursive_print(models_found, None, None, [])
recursive_print(models_found, build_name, cache_dir, None, None, [])


def recursive_print(
models_found: Dict[str, ModelInfo],
build_name: str,
cache_dir: str,
parent_model_hash: Union[str, None] = None,
parent_invocation_hash: Union[str, None] = None,
script_names_visited: List[str] = False,
Expand Down Expand Up @@ -50,6 +53,8 @@ def recursive_print(

print_invocation(
model_info,
build_name,
cache_dir,
invocation_hash,
print_file_name,
invocation_idx=invocation_idx,
Expand All @@ -63,6 +68,8 @@ def recursive_print(

recursive_print(
models_found,
build_name,
cache_dir,
parent_model_hash=model_hash,
parent_invocation_hash=invocation_hash,
script_names_visited=script_names_visited,
Expand All @@ -71,6 +78,8 @@ def recursive_print(

def print_invocation(
model_info: ModelInfo,
build_name: str,
cache_dir: str,
invocation_hash: Union[str, None],
print_file_name: bool = False,
invocation_idx: int = 0,
Expand Down Expand Up @@ -133,6 +142,7 @@ def print_invocation(

print(f"{ident}\tInput Shape:\t{input_shape}")
print(f"{ident}\tHash:\t\t" + invocation_hash)
print(f"{ident}\tBuild dir:\t" + cache_dir + '/' + build_name)

# Print benchit results if benchit was run
if unique_invocation.performance:
Expand Down
20 changes: 10 additions & 10 deletions src/mlagility/api/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,23 @@
APIs to help with programmatically parsing the MLAgility report
"""

from typing import Dict
from typing import Dict, List
import pandas as pd


def get_dict(report_csv: str, column: str) -> Dict[str, str]:
def get_dict(report_csv: str, columns: List[str]) -> Dict[str, Dict[str, str]]:
"""
Returns a dictionary where the keys are model names and the values are the values of
a given column in a report.csv file.
Returns a dictionary where the keys are model names and the values are dictionaries.
Each dictionary represents a model with column names as keys and their corresponding values.
args:
- report_csv: path to a report.csv file generated by benchit
- column: name of a column in the report.csv file whose values will be used to
- columns: list of column names in the report.csv file whose values will be used to
populate the dictionary
"""

# Open the MLAgility report as a dataframe
# Load the MLAgility report as a dataframe
dataframe = pd.read_csv(report_csv)

# Convert to dictionary
return pd.Series(dataframe[column].values, index=dataframe.model_name).to_dict()
# Create a nested dictionary with model_name as keys and another dictionary of {column: value} pairs as values
result = {row[0]: row[1].to_dict() for row in dataframe.set_index("model_name")[columns].iterrows()}

return result
8 changes: 4 additions & 4 deletions test/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from mlagility.cli.cli import main as benchitcli
import mlagility.cli.report as report
import mlagility.api.report as report_api
from mlagility.common import filesystem
import mlagility.common.filesystem as filesystem
import mlagility.api.ortmodel as ortmodel
import mlagility.api.trtmodel as trtmodel
import onnxflow.common.build as build
Expand Down Expand Up @@ -436,11 +436,11 @@ def test_004_cli_report(self):
), f"onnx_converted must be True, got {linear_summary['onnx_converted']}"

# Make sure the report.get_dict() API works
result_dict = report_api.get_dict(summary_csv_path, "onnx_exported")
for result in result_dict.values():
result_dict = report_api.get_dict(summary_csv_path, ["onnx_exported"])
for model_name, result in result_dict.items():
# All of the models should have exported to ONNX, so the "onnx_exported" value
# should be True for all of them
assert result
assert result.get("onnx_exported") is True

def test_005_cli_list(self):
# NOTE: this is not a unit test, it relies on other command
Expand Down

0 comments on commit 59bcef7

Please sign in to comment.