Hash model inputs instead of parameters #324

Merged: 40 commits, Jun 22, 2023

Changes from 10 commits

Commits (40)
e417eba
Basic input hashing
Jun 14, 2023
5b9b188
Showing workload status correctly
Jun 14, 2023
7997a9b
Showing workload hash rather than model hash
Jun 15, 2023
056b90b
Temporarily modifying docker file to enable CI
Jun 15, 2023
a03f183
Merge branch 'main' into robust_hashing
Jun 15, 2023
ca64739
Merge main into branch
Jun 15, 2023
7d29c5a
Revert fs changes
Jun 15, 2023
3c6cb9e
Robust shape extraction
Jun 15, 2023
d3fece5
Update ci test hash
Jun 15, 2023
9da81c9
Update analysis CI
Jun 15, 2023
8aa05f7
Add test
Jun 15, 2023
52c5e16
Updated dockerfile
Jun 15, 2023
6aaa2f0
Add requirement
Jun 16, 2023
e02df17
Added input shape to print
Jun 16, 2023
cc8f5b6
Merge branch 'main' into robust_hashing
jeremyfowers Jun 16, 2023
8593f8b
Simplify llama code
Jun 16, 2023
46993ed
Merge branch 'robust_hashing' of https://github.com/groq/mlagility in…
Jun 16, 2023
676036f
recursively printing for each model
Jun 16, 2023
f6fd7f1
Keeping track of parent workload hash
Jun 16, 2023
e462373
Correctly printing when max_depth is set
Jun 16, 2023
5f1ba1d
Ensure that hashes are different if they come from different workloads
Jun 16, 2023
417c701
Fix CI
Jun 17, 2023
b2be4bb
Correctly keeping track of last workload executed
Jun 20, 2023
c972930
Better UI
Jun 20, 2023
6abf7ac
Added test
Jun 20, 2023
0d34fa7
Renamed function as suggested
Jun 20, 2023
b4b6371
Change model to workload where appropriate
Jun 21, 2023
60862bd
Fix CI
Jun 21, 2023
cdfef11
Fix slurm CI
Jun 21, 2023
328938f
Revert "Fix slurm CI"
Jun 21, 2023
7c6668d
Revert "Fix CI"
Jun 21, 2023
2651512
Revert "Change model to workload where appropriate"
Jun 21, 2023
5d85530
Suggested changes
Jun 21, 2023
c67f140
Replacing the term workloads by invocations
Jun 21, 2023
336d563
Add documentation for new feature
Jun 21, 2023
6c06115
merge main into branch
Jun 21, 2023
e939b8f
Merge branch 'main' into robust_hashing
jeremyfowers Jun 22, 2023
17f0e69
Fix tutorial typo
jeremyfowers Jun 22, 2023
18aca16
Copy editing the multiple invocation tutorial
jeremyfowers Jun 22, 2023
daf13a8
Note issue in the code
jeremyfowers Jun 22, 2023
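
The gist of the diffs that follow: model hashes are now computed without parameters (hash_params=False), and each distinct invocation of a model is identified by a workload hash built from the model hash plus the input shapes and dtypes. Below is a minimal sketch of that scheme, assuming a simplified local stand-in for onnxflow's build.get_shapes_and_dtypes; the helper names here are illustrative, not the library code.

import hashlib
import numpy as np
import torch


def sketch_shapes_and_dtypes(inputs: dict):
    # Simplified stand-in for onnxflow's build.get_shapes_and_dtypes()
    shapes, dtypes = {}, {}
    for key, value in inputs.items():
        arr = np.array(value.detach()) if torch.is_tensor(value) else np.array(value)
        shapes[key] = arr.shape
        dtypes[key] = arr.dtype.name
    return shapes, dtypes


def sketch_workload_hash(model_hash: str, args: tuple, kwargs: dict) -> str:
    # Positional args are folded into the kwargs dict, as in get_workload_hash()
    merged = {**kwargs, **{f"positional{i + 1}": a for i, a in enumerate(args)}}
    shapes, dtypes = sketch_shapes_and_dtypes(merged)
    return hashlib.sha256(f"{model_hash}{shapes}{dtypes}".encode()).hexdigest()[:8]


# Same model hash, different input shapes -> different workload hashes
print(sketch_workload_hash("aaaa1111", (torch.ones(1, 8),), {}))
print(sketch_workload_hash("aaaa1111", (torch.ones(4, 8),), {}))

Two calls into the same model with different input shapes therefore yield different workload hashes, which is what lets the analysis track and report them separately.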
90 changes: 59 additions & 31 deletions src/mlagility/analysis/analysis.py
@@ -8,7 +8,8 @@
import functools
import dataclasses
import traceback
from typing import Union, List, Dict
import hashlib
from typing import Union, List, Dict, Tuple
from types import FrameType, TracebackType
from enum import Enum
import torch
@@ -63,28 +64,31 @@ def torch_activations(self) -> List[str]:
return act


def _store_traceback(model_info: util.ModelInfo):
def _store_traceback(workload_info: util.WorkloadInfo):
"""
Store the traceback from an exception into model_info so that
Store the traceback from an exception into workload_info so that
we can print it during the status update.
"""

exc_type, exc_value, exc_traceback = sys.exc_info()
model_info.traceback = traceback.format_exception(
workload_info.traceback = traceback.format_exception(
exc_type, exc_value, exc_traceback
)


def call_benchit(
model_inputs: dict, model_info: util.ModelInfo, tracer_args: TracerArgs
model_inputs: dict,
model_info: util.ModelInfo,
workload_info: util.WorkloadInfo,
tracer_args: TracerArgs,
) -> None:
"""
Calls the benchit function from within the model forward function
"""

# Update status to "computing"
model_info.status_message = "Computing..."
model_info.status_message_color = printing.Colors.OKBLUE
workload_info.status_message = "Computing..."
workload_info.status_message_color = printing.Colors.OKBLUE
status.update(tracer_args.models_found)

# Get a copy of the keyword arguments
@@ -111,10 +115,10 @@ def call_benchit(
inputs[all_args[i]] = torch.tensor(args[i].detach().numpy())
else:
inputs[all_args[i]] = args[i]
model_info.inputs = inputs
workload_info.inputs = inputs

build_name = filesystem.get_build_name(
tracer_args.script_name, tracer_args.labels, model_info.hash
tracer_args.script_name, tracer_args.labels, workload_info.hash
)

# Save model labels
@@ -124,12 +128,12 @@ def call_benchit(
perf = None
try:
if model_info.model_type == build.ModelType.PYTORCH_COMPILED:
model_info.status_message = (
workload_info.status_message = (
"Skipping model compiled using torch.compile(). "
"benchit requires models to be in eager mode "
"(regardless of what runtime you have selected)."
)
model_info.status_message_color = printing.Colors.WARNING
workload_info.status_message_color = printing.Colors.WARNING
else:
perf = benchmark_model(
model_info.model,
@@ -150,36 +154,36 @@
onnx_opset=tracer_args.onnx_opset,
)
if Action.BENCHMARK in tracer_args.actions:
model_info.status_message = "Model successfully benchmarked!"
model_info.performance = perf
model_info.status_message_color = printing.Colors.OKGREEN
workload_info.status_message = "Model successfully benchmarked!"
workload_info.performance = perf
workload_info.status_message_color = printing.Colors.OKGREEN
else:
model_info.status_message = "Model successfully built!"
model_info.status_message_color = printing.Colors.OKGREEN
workload_info.status_message = "Model successfully built!"
workload_info.status_message_color = printing.Colors.OKGREEN

except exp.StageError:
build_state = build.load_state(
cache_dir=tracer_args.cache_dir, build_name=build_name
)
model_info.status_message = "Build Error: see log files for details."
model_info.status_message_color = printing.Colors.WARNING
workload_info.status_message = "Build Error: see log files for details."
workload_info.status_message_color = printing.Colors.WARNING

_store_traceback(model_info)
_store_traceback(workload_info)

except exp.Error:
model_info.status_message = "GroqFlowError: see log files for details."
model_info.status_message_color = printing.Colors.WARNING
workload_info.status_message = "GroqFlowError: see log files for details."
workload_info.status_message_color = printing.Colors.WARNING

_store_traceback(model_info)
_store_traceback(workload_info)

# This broad exception is ok since enumerating all exceptions is
# not possible, as the tested software continuously evolves.
except Exception as e: # pylint: disable=broad-except
util.stop_stdout_forward()
model_info.status_message = f"Unknown benchit error: {e}"
model_info.status_message_color = printing.Colors.WARNING
workload_info.status_message = f"Unknown benchit error: {e}"
workload_info.status_message_color = printing.Colors.WARNING

_store_traceback(model_info)
_store_traceback(workload_info)
finally:
# Ensure that stdout is not being forwarded before updating status
if hasattr(sys.stdout, "terminal"):
@@ -247,7 +251,23 @@ def call_benchit(
def get_model_hash(
model: Union[torch.nn.Module, "tf.keras.Model"], model_type: build.ModelType
):
return build.hash_model(model, model_type, hash_params=True)[:8]
return build.hash_model(model, model_type, hash_params=False)[:8]


def get_workload_hash(model_hash: str, args: Tuple, kwargs: Dict) -> str:
"""
Combines the model hash and the input shapes to create the workload hash
"""

# Merge positional and keyword args
args = {"positional{}".format(i + 1): arg for i, arg in enumerate(args)}
kwargs = {**kwargs, **args}

# Get input shapes and types
input_shapes, input_dtypes = build.get_shapes_and_dtypes(kwargs)

hashable_content = f"{model_hash}{input_shapes}{input_dtypes}"
return hashlib.sha256(hashable_content.encode()).hexdigest()[:8]


def store_model_info(
Expand Down Expand Up @@ -292,7 +312,6 @@ def store_model_info(
depth=depth,
hash=model_hash,
parent_hash=parent_hash,
is_target=model_hash in tracer_args.targets or tracer_args.targets == [],
build_model=build_model,
model_type=model_type,
script_name=tracer_args.script_name,
@@ -439,21 +458,30 @@ def forward_spy(*args, **kwargs):
parent_hash,
)
model_hash = get_model_hash(local_var, model_type)
workload_hash = get_workload_hash(model_hash, args, kwargs)
model_info = tracer_args.models_found[model_hash]
model_info.exec_time = model_info.exec_time + end_time - start_time
if workload_hash not in model_info.workloads:
model_info.workloads[workload_hash] = util.WorkloadInfo(
hash=workload_hash,
is_target=workload_hash in tracer_args.targets
or tracer_args.targets == [],
)
workload_info = model_info.workloads[workload_hash]
workload_info.exec_time = workload_info.exec_time + end_time - start_time

model_info.executed = model_info.executed + 1
workload_info.executed = workload_info.executed + 1

# Call groqit if this is the first time the model is being executed
# and this model has been selected by the user
if (
model_info.executed == 1
and model_info.is_target
workload_info.executed == 1
and workload_info.is_target
and (model_info.build_model)
):
call_benchit(
model_inputs=[args, kwargs],
model_info=model_info,
workload_info=workload_info,
tracer_args=tracer_args,
)
# Ensure that groqit() doesn't interfere with our execution count
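
For context on the forward_spy changes above, a hedged sketch of the bookkeeping they introduce: execution counts and timings now live on a per-workload record keyed by workload hash under ModelInfo.workloads, and benchmarking is only triggered the first time a targeted workload executes. The dataclasses below are trimmed, hypothetical stand-ins for util.ModelInfo and util.WorkloadInfo, not the actual classes.

from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class WorkloadSketch:
    hash: str
    is_target: bool = False
    executed: int = 0
    exec_time: float = 0.0


@dataclass
class ModelSketch:
    hash: str
    build_model: bool = True
    workloads: Dict[str, WorkloadSketch] = field(default_factory=dict)


def record_invocation(
    model: ModelSketch, workload_hash: str, elapsed: float, targets: List[str]
) -> bool:
    # Create the record the first time this (model, input signature) pair is seen
    if workload_hash not in model.workloads:
        model.workloads[workload_hash] = WorkloadSketch(
            hash=workload_hash,
            is_target=workload_hash in targets or targets == [],
        )
    workload = model.workloads[workload_hash]
    workload.exec_time += elapsed
    workload.executed += 1
    # benchit() would only be called on the first execution of a targeted workload
    return workload.executed == 1 and workload.is_target and model.build_model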
71 changes: 45 additions & 26 deletions src/mlagility/analysis/status.py
@@ -23,25 +23,45 @@ def recursive_print(
models_found: Dict[str, ModelInfo],
parent_hash: Union[str, None] = None,
script_names_visited: List[str] = False,
depth: int = 0,
) -> None:
script_names_visited = []

for h in models_found.keys():
if parent_hash == models_found[h].parent_hash and models_found[h].executed > 0:
print_file_name = models_found[h].script_name not in script_names_visited
for model_hash in models_found.keys():
workloads_executed = False
for workload_hash in models_found[model_hash].workloads.keys():
workload = models_found[model_hash].workloads[workload_hash]

print_model(models_found[h], h, print_file_name)
if (
parent_hash == models_found[model_hash].parent_hash
and workload.executed > 0
):

if print_file_name:
script_names_visited.append(models_found[h].script_name)
workloads_executed = True
print_file_name = False
if models_found[model_hash].script_name not in script_names_visited:
script_names_visited.append(models_found[model_hash].script_name)
if depth == 0:
print_file_name = True

print_workload(models_found[model_hash], workload_hash, print_file_name)

if print_file_name:
script_names_visited.append(models_found[model_hash].script_name)

if workloads_executed:
recursive_print(
models_found, parent_hash=h, script_names_visited=script_names_visited
models_found,
parent_hash=model_hash,
script_names_visited=script_names_visited,
depth=depth + 1,
)


def print_model(
model_info: ModelInfo, model_hash: Union[str, None], print_file_name: bool = False
def print_workload(
model_info: ModelInfo,
workload_hash: Union[str, None],
print_file_name: bool = False,
) -> None:
"""
Print information about a given model or submodel
@@ -54,12 +74,13 @@ def print_model(
# Show the number of times the model has been executed
# Only show the execution time if we are not running benchit() as this
# impacts time measurement.
if model_info.exec_time == 0 or model_info.build_model:
workload = model_info.workloads[workload_hash]
if workload.exec_time == 0 or model_info.build_model:
exec_time = ""
else:
exec_time = f" - {model_info.exec_time:.2f}s"
exec_time = f" - {workload.exec_time:.2f}s"
printing.logn(
f"(executed {model_info.executed}x{exec_time})",
f"(executed {workload.executed}x{exec_time})",
c=printing.Colors.OKGREEN,
)

@@ -78,34 +99,32 @@ def print_model(
model_size = model_info.params * 2 / (1024 * 1024)
model_size = "{:.1f}".format(model_size) if model_size > 0.1 else "<0.1"
print(f"{ident}\tParameters:\t{'{:,}'.format(model_info.params)} ({model_size} MB)")
print(f"{ident}\tHash:\t\t" + model_hash)
print(f"{ident}\tHash:\t\t" + workload_hash)

# Print benchit results if benchit was run
if model_info.performance:
if workload.performance:
printing.log(f"{ident}\tStatus:\t\t")
printing.logn(
f"Successfully benchmarked on {model_info.performance.device} ({model_info.performance.runtime} v{model_info.performance.runtime_version})",
c=model_info.status_message_color,
f"Successfully benchmarked on {workload.performance.device} ({workload.performance.runtime} v{workload.performance.runtime_version})",
c=workload.status_message_color,
)
printing.logn(
f"{ident}\t\t\tMean Latency:\t{model_info.performance.mean_latency:.3f}"
f"\t{model_info.performance.latency_units}"
f"{ident}\t\t\tMean Latency:\t{workload.performance.mean_latency:.3f}"
f"\t{workload.performance.latency_units}"
)
printing.logn(
f"{ident}\t\t\tThroughput:\t{model_info.performance.throughput:.1f}"
f"\t{model_info.performance.throughput_units}"
f"{ident}\t\t\tThroughput:\t{workload.performance.throughput:.1f}"
f"\t{workload.performance.throughput_units}"
)
print()
else:
if model_info.is_target and model_info.build_model:
if workload.is_target and model_info.build_model:
printing.log(f"{ident}\tStatus:\t\t")
printing.logn(
f"{model_info.status_message}", c=model_info.status_message_color
)
printing.logn(f"{workload.status_message}", c=workload.status_message_color)

if model_info.traceback is not None:
if workload.traceback is not None:
if os.environ.get("MLAGILITY_TRACEBACK") != "False":
for line in model_info.traceback:
for line in workload.traceback:
for subline in line.split("\n")[:-1]:
print(f"{ident}\t{subline}")

27 changes: 18 additions & 9 deletions src/mlagility/analysis/util.py
@@ -2,6 +2,7 @@
from dataclasses import dataclass
from typing import Callable, List, Union, Dict
import inspect
import dataclasses
import torch
import onnx
from onnxflow.common import printing
@@ -15,6 +16,20 @@ class AnalysisException(Exception):
"""


@dataclass
class WorkloadInfo:
hash: Union[str, None] = None
performance: MeasuredPerformance = None
traceback: List[str] = None
inputs: Union[dict, None] = None
executed: int = 0
exec_time: float = 0.0
status_message: str = ""
is_target: bool = False
status_message_color: printing.Colors = printing.Colors.ENDC
traceback_message_color: printing.Colors = printing.Colors.FAIL


@dataclass
class ModelInfo:
model: torch.nn.Module
@@ -26,18 +41,12 @@ class ModelInfo:
depth: int = 0
hash: Union[str, None] = None
parent_hash: Union[str, None] = None
inputs: Union[dict, None] = None
executed: int = 0
exec_time: float = 0.0
old_forward: Union[Callable, None] = None
status_message: str = ""
status_message_color: printing.Colors = printing.Colors.ENDC
traceback_message_color: printing.Colors = printing.Colors.FAIL
is_target: bool = False
workloads: Union[Dict[str, WorkloadInfo], None] = dataclasses.field(
default_factory=dict
)
build_model: bool = False
model_type: build.ModelType = build.ModelType.PYTORCH
performance: MeasuredPerformance = None
traceback: List[str] = None

def __post_init__(self):
self.params = count_parameters(self.model, self.model_type)
2 changes: 1 addition & 1 deletion src/mlagility/api/Dockerfile
@@ -3,4 +3,4 @@ from httpd

RUN apt-get update && apt-get install -y --no-install-recommends python3-dev python3-setuptools python3-wheel python3-pip
ENV PYTHONPATH "${PYTHONPATH}:/usr/bin/python3"
RUN pip install onnxruntime==1.14.1
RUN pip install onnxruntime --break-system-packages
5 changes: 4 additions & 1 deletion src/onnxflow/common/build.py
@@ -168,7 +168,10 @@ def get_shapes_and_dtypes(inputs: dict):
subkey = f"{key}[{i}]"
shapes[subkey] = np.array(v).shape
dtypes[subkey] = np.array(v).dtype.name
elif torch.is_tensor(value) or tf_helpers.is_keras_tensor(value):
elif torch.is_tensor(value):
shapes[key] = np.array(value.detach()).shape
dtypes[key] = np.array(value.detach()).dtype.name
elif tf_helpers.is_keras_tensor(value):
shapes[key] = np.array(value).shape
dtypes[key] = np.array(value).dtype.name
elif isinstance(value, np.ndarray):
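
The get_shapes_and_dtypes change above splits torch tensors into their own branch and detaches them before inspection, presumably because np.array() fails on tensors attached to the autograd graph. A small sketch of that behavior, assuming current PyTorch semantics:

import numpy as np
import torch

x = torch.ones(2, 3, requires_grad=True)

try:
    np.array(x)  # fails for tensors that are part of the autograd graph
except RuntimeError as err:
    print(f"without detach: {err}")

# After detaching, shape and dtype can be read as before
print(np.array(x.detach()).shape, np.array(x.detach()).dtype.name)  # (2, 3) float32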