Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cache and build name to status report #333

Merged
merged 5 commits into from
Jun 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions src/mlagility/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,11 @@ def explore_invocation(
# Update status to "computing"
invocation_info.status_message = "Computing..."
invocation_info.status_message_color = printing.Colors.OKBLUE
status.update(tracer_args.models_found)

build_name = filesystem.get_build_name(
tracer_args.script_name, tracer_args.labels, invocation_info.hash
)
status.update(tracer_args.models_found, build_name, tracer_args.cache_dir)

# Get a copy of the keyword arguments
args, kwargs = model_inputs
Expand Down Expand Up @@ -117,10 +121,6 @@ def explore_invocation(
inputs[all_args[i]] = args[i]
invocation_info.inputs = inputs

build_name = filesystem.get_build_name(
tracer_args.script_name, tracer_args.labels, invocation_info.hash
)

# Save model labels
tracer_args.labels["class"] = [f"{type(model_info.model).__name__}"]
labels.save_to_cache(tracer_args.cache_dir, build_name, tracer_args.labels)
Expand Down Expand Up @@ -188,7 +188,7 @@ def explore_invocation(
# Ensure that stdout is not being forwarded before updating status
if hasattr(sys.stdout, "terminal"):
sys.stdout = sys.stdout.terminal
status.update(tracer_args.models_found)
status.update(tracer_args.models_found, build_name, tracer_args.cache_dir)

if tracer_args.device == "groq":
import groqflow.common.build as groq_build
Expand Down Expand Up @@ -510,7 +510,10 @@ def forward_spy(*args, **kwargs):
# Ensure that groqit() doesn't interfere with our execution count
model_info.executed = 1

status.update(tracer_args.models_found)
build_name = filesystem.get_build_name(
tracer_args.script_name, tracer_args.labels, invocation_info.hash
)
status.update(tracer_args.models_found, build_name, tracer_args.cache_dir)

# Turn tracing on again after computing the outputs
sys.setprofile(tracer)
Expand Down
16 changes: 13 additions & 3 deletions src/mlagility/analysis/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import onnxflow.common.build as build
from mlagility.analysis.util import ModelInfo


def update(models_found: Dict[str, ModelInfo]) -> None:
def update(models_found: Dict[str, ModelInfo],
build_name: str,
cache_dir: str) -> None:
"""
Prints all models and submodels found
"""
Expand All @@ -16,11 +17,13 @@ def update(models_found: Dict[str, ModelInfo]) -> None:
"\nModels discovered during profiling:\n",
c=printing.Colors.BOLD,
)
recursive_print(models_found, None, None, [])
recursive_print(models_found, build_name, cache_dir, None, None, [])


def recursive_print(
models_found: Dict[str, ModelInfo],
build_name: str,
cache_dir: str,
parent_model_hash: Union[str, None] = None,
parent_invocation_hash: Union[str, None] = None,
script_names_visited: List[str] = False,
Expand Down Expand Up @@ -50,6 +53,8 @@ def recursive_print(

print_invocation(
model_info,
build_name,
cache_dir,
invocation_hash,
print_file_name,
invocation_idx=invocation_idx,
Expand All @@ -63,6 +68,8 @@ def recursive_print(

recursive_print(
models_found,
build_name,
cache_dir,
parent_model_hash=model_hash,
parent_invocation_hash=invocation_hash,
script_names_visited=script_names_visited,
Expand All @@ -71,6 +78,8 @@ def recursive_print(

def print_invocation(
model_info: ModelInfo,
build_name: str,
ramkrishna2910 marked this conversation as resolved.
Show resolved Hide resolved
cache_dir: str,
invocation_hash: Union[str, None],
print_file_name: bool = False,
invocation_idx: int = 0,
Expand Down Expand Up @@ -133,6 +142,7 @@ def print_invocation(

print(f"{ident}\tInput Shape:\t{input_shape}")
print(f"{ident}\tHash:\t\t" + invocation_hash)
print(f"{ident}\tBuild dir:\t" + cache_dir + '/' + build_name)

# Print benchit results if benchit was run
if unique_invocation.performance:
Expand Down
20 changes: 10 additions & 10 deletions src/mlagility/api/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,23 @@
APIs to help with programmatically parsing the MLAgility report
"""

from typing import Dict
from typing import Dict, List
import pandas as pd


def get_dict(report_csv: str, column: str) -> Dict[str, str]:
def get_dict(report_csv: str, columns: List[str]) -> Dict[str, Dict[str, str]]:
"""
Returns a dictionary where the keys are model names and the values are the values of
a given column in a report.csv file.

Returns a dictionary where the keys are model names and the values are dictionaries.
Each dictionary represents a model with column names as keys and their corresponding values.
args:
- report_csv: path to a report.csv file generated by benchit
- column: name of a column in the report.csv file whose values will be used to
- columns: list of column names in the report.csv file whose values will be used to
populate the dictionary
"""

# Open the MLAgility report as a dataframe
# Load the MLAgility report as a dataframe
dataframe = pd.read_csv(report_csv)

# Convert to dictionary
return pd.Series(dataframe[column].values, index=dataframe.model_name).to_dict()
# Create a nested dictionary with model_name as keys and another dictionary of {column: value} pairs as values
result = {row[0]: row[1].to_dict() for row in dataframe.set_index("model_name")[columns].iterrows()}

return result
8 changes: 4 additions & 4 deletions test/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from mlagility.cli.cli import main as benchitcli
import mlagility.cli.report as report
import mlagility.api.report as report_api
from mlagility.common import filesystem
import mlagility.common.filesystem as filesystem
import mlagility.api.ortmodel as ortmodel
import mlagility.api.trtmodel as trtmodel
import onnxflow.common.build as build
Expand Down Expand Up @@ -436,11 +436,11 @@ def test_004_cli_report(self):
), f"onnx_converted must be True, got {linear_summary['onnx_converted']}"

# Make sure the report.get_dict() API works
result_dict = report_api.get_dict(summary_csv_path, "onnx_exported")
for result in result_dict.values():
result_dict = report_api.get_dict(summary_csv_path, ["onnx_exported"])
for model_name, result in result_dict.items():
# All of the models should have exported to ONNX, so the "onnx_exported" value
# should be True for all of them
assert result
assert result.get("onnx_exported") is True

def test_005_cli_list(self):
# NOTE: this is not a unit test, it relies on other command
Expand Down