Commit
Hash model inputs instead of parameters (#324)
Co-authored-by: Jeremy Fowers <[email protected]>
Co-authored-by: jfowers <[email protected]>
3 people authored Jun 22, 2023
1 parent 6f32ec2 commit f35e27c
Showing 12 changed files with 367 additions and 119 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/test_mlagility.yml
@@ -34,7 +34,7 @@ jobs:
python -m pip install --upgrade pip
conda install pylint
if [ -f setup.py ]; then pip install -e .; fi
pip install transformers
pip install transformers timm
python -m pip check
- name: Lint with PyLint
shell: bash -el {0}
@@ -56,6 +56,8 @@ jobs:
rm -rf ~/.cache/mlagility
benchit examples/cli/scripts/hello_world.py
rm -rf ~/.cache/mlagility
benchit examples/cli/scripts/multiple_invocations.py
rm -rf ~/.cache/mlagility
benchit examples/cli/scripts/max_depth.py --max-depth 1
rm -rf ~/.cache/mlagility
benchit examples/cli/scripts/two_models.py
2 changes: 1 addition & 1 deletion examples/cli/discovery.md
@@ -112,7 +112,7 @@ You can see that `hello_world.py`, `two_models.py`, and `max_depth.py` are all e

> See the [Benchmark Multiple Scripts documentation](https://github.com/groq/mlagility/blob/main/docs/tools_user_guide.md#benchmark-multiple-scripts) for more details.
### Maximum Analysis Depth
## Maximum Analysis Depth

PyTorch models (e.g., `torch.nn.Module`) are often built out of a collection of smaller instances. For example, a PyTorch multilayer perceptron (MLP) model may be built out of many `torch.nn.Linear` modules.
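
As a rough, hypothetical sketch (not one of the MLAgility example scripts), such a model might nest two `torch.nn.Linear` submodules inside one parent module; the maximum analysis depth controls whether those inner modules are analyzed individually as well:

```
import torch

# Hypothetical MLP assembled from two torch.nn.Linear submodules; each Linear
# is itself a torch.nn.Module nested one level below the parent model.
class TinyMLP(torch.nn.Module):
    def __init__(self, in_features=8, hidden=16, out_features=2):
        super().__init__()
        self.fc1 = torch.nn.Linear(in_features, hidden)
        self.fc2 = torch.nn.Linear(hidden, out_features)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

model = TinyMLP()
model(torch.rand(1, 8))
```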

41 changes: 41 additions & 0 deletions examples/cli/readme.md
@@ -17,6 +17,7 @@ In this tutorial you will learn things such as:
- [A "hello world" example, which is the easiest way to get started](#hello-world)
- [Benchmarking on Nvidia GPUs](#nvidia-benchmarking)
- [Working with scripts that invoke more than one model](#multiple-models-per-script)
- [Working with scripts that invoke a model multiple times](#multiple-invocations-of-a-model)
- [Benchmarking an ONNX file](#onnx-benchmarking)

# Just Benchmark BERT
@@ -171,6 +172,46 @@ Woohoo! The 'benchmark' command is complete.

You can see that both model instances in `two_models.py`, `pytorch_model` and `another_pytorch_model`, are discovered and benchmarked.

## Multiple Invocations of a Model

A single script may invoke the same model multiple times using different input shapes (e.g. when varying the batch size). When this happens, MLAgility will benchmark and display each of those invocations as sub-results of the same model instance.

> **Note**: multiple invocations of a model with the same input shape will only be benchmarked once.

The `multiple_invocations.py` script instantiates a single model and invokes it three times. The first two invocations use inputs of the same shape (batch size 1), while the third uses a different input shape (batch size 2). As a result, two unique static model invocations are identified.

Run the following command:

```
benchit scripts/multiple_invocations.py
```

To get a result like:
```
Models discovered during profiling:
multiple_invocations.py:
pytorch_model
Model Type: Pytorch (torch.nn.Module)
Class: SmallModel (<class 'multiple_invocations.SmallModel'>)
Location: /net/home/dhnoronha/mlagility/examples/cli/scripts/multiple_invocations.py, line 40
Parameters: 60 (<0.1 MB)
With input shape 1 (executed 2x)
Input Shape: 'x': (1, 11)
Hash: b4aa73ae
Status: Successfully benchmarked on Intel(R) Xeon(R) CPU @ 2.20GHz (ort v1.14.1)
Mean Latency: 0.013 milliseconds (ms)
Throughput: 77909.6 inferences per second (IPS)
With input shape 2 (executed 1x)
Input Shape: 'x': (2, 11)
Hash: cfaa2e2c
Status: Successfully benchmarked on Intel(R) Xeon(R) CPU @ 2.20GHz (ort v1.14.1)
Mean Latency: 0.015 milliseconds (ms)
Throughput: 64938.1 inferences per second (IPS)
```
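
Under the hood, invocations are told apart by hashing the model together with the shapes and dtypes of its inputs, rather than its parameters (see `get_invocation_hash` in `src/mlagility/analysis/analysis.py` further down in this diff). The following is a rough, self-contained sketch of that idea, not the exact MLAgility implementation; `sketch_invocation_hash` is a made-up name for illustration:

```
import hashlib

import torch

def sketch_invocation_hash(model_hash: str, inputs: dict) -> str:
    # Hash the input shapes and dtypes (rather than the model parameters), so
    # repeated calls with identical input shapes map to the same invocation.
    shapes = {name: tuple(t.shape) for name, t in inputs.items()}
    dtypes = {name: str(t.dtype) for name, t in inputs.items()}
    return hashlib.sha256(f"{model_hash}{shapes}{dtypes}".encode()).hexdigest()[:8]

inputs1 = {"x": torch.rand(1, 11)}
inputs2 = {"x": torch.rand(1, 11)}  # same shape as inputs1 -> same hash
inputs3 = {"x": torch.rand(2, 11)}  # larger batch size -> different hash

print(sketch_invocation_hash("b4aa73ae", inputs1))
print(sketch_invocation_hash("b4aa73ae", inputs2))
print(sketch_invocation_hash("b4aa73ae", inputs3))
```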

## ONNX Benchmarking

If you already happen to have an ONNX file, `benchit` can benchmark it for you. We can demonstrate this with the ONNX file in `examples/cli/onnx/sample.onnx`.
47 changes: 47 additions & 0 deletions examples/cli/scripts/multiple_invocations.py
@@ -0,0 +1,47 @@
# labels: name::multiple_invocations
"""
This example demonstrates what happens when your script contains
a model that is invoked multiple times with different input shapes.
To try it, run:
benchit multiple_invocations.py
You should see the two unique invocations being identified.
"""
import torch

torch.manual_seed(1)

# Define model class
class SmallModel(torch.nn.Module):
    def __init__(self, input_features, output_size):
        super(SmallModel, self).__init__()
        self.fc = torch.nn.Linear(input_features, output_size)

    def forward(self, x):
        # x has shape (batch_size, input_features)
        # Set the batch size dimension to -1 to allow for flexibility
        x = x.view(-1, x.size(1))

        output = self.fc(x)

        # Reshape the output to restore the original batch size dimension
        output = output.view(-1, output_size)
        return output


# Instantiate model and generate inputs
input_features = 11
output_size = 5
pytorch_model = SmallModel(input_features, output_size)

# Create 3 sets of inputs
batch_size = 1
inputs1 = {"x": torch.rand(batch_size, input_features)}
inputs2 = {"x": torch.rand(batch_size, input_features)}
inputs3 = {"x": torch.rand(batch_size + 1, input_features)}

pytorch_model(**inputs1)
pytorch_model(**inputs2)
pytorch_model(**inputs3)
9 changes: 1 addition & 8 deletions models/llm_layer/llama_layer_prototype.py
@@ -12,14 +12,7 @@

def call_llama_layer(params="7B", use_cache=False):

    # Use different torch seeds for KV caching vs. not, so that
    # the models end up with different mlagility hashes
    # Remove the if-statement when
    # https://github.com/groq/mlagility/issues/316 is fixed
    if use_cache:
        torch.manual_seed(0)
    else:
        torch.manual_seed(1)
    torch.manual_seed(0)

    # Parsing command-line arguments
    batch_size, max_seq_length = parse(["batch_size", "max_seq_length"])
127 changes: 89 additions & 38 deletions src/mlagility/analysis/analysis.py
@@ -8,7 +8,8 @@
import functools
import dataclasses
import traceback
from typing import Union, List, Dict
import hashlib
from typing import Union, List, Dict, Tuple
from types import FrameType, TracebackType
from enum import Enum
import torch
@@ -63,28 +64,31 @@ def torch_activations(self) -> List[str]:
return act


def _store_traceback(model_info: util.ModelInfo):
def _store_traceback(invocation_info: util.UniqueInvocationInfo):
"""
Store the traceback from an exception into model_info so that
Store the traceback from an exception into invocation_info so that
we can print it during the status update.
"""

exc_type, exc_value, exc_traceback = sys.exc_info()
model_info.traceback = traceback.format_exception(
invocation_info.traceback = traceback.format_exception(
exc_type, exc_value, exc_traceback
)


def call_benchit(
model_inputs: dict, model_info: util.ModelInfo, tracer_args: TracerArgs
def explore_invocation(
model_inputs: dict,
model_info: util.ModelInfo,
invocation_info: util.UniqueInvocationInfo,
tracer_args: TracerArgs,
) -> None:
"""
Calls the benchit function from within the model forward function
"""

# Update status to "computing"
model_info.status_message = "Computing..."
model_info.status_message_color = printing.Colors.OKBLUE
invocation_info.status_message = "Computing..."
invocation_info.status_message_color = printing.Colors.OKBLUE
status.update(tracer_args.models_found)

# Get a copy of the keyword arguments
@@ -111,10 +115,10 @@ def call_benchit(
inputs[all_args[i]] = torch.tensor(args[i].detach().numpy())
else:
inputs[all_args[i]] = args[i]
model_info.inputs = inputs
invocation_info.inputs = inputs

build_name = filesystem.get_build_name(
tracer_args.script_name, tracer_args.labels, model_info.hash
tracer_args.script_name, tracer_args.labels, invocation_info.hash
)

# Save model labels
@@ -124,12 +128,12 @@ def call_benchit(
perf = None
try:
if model_info.model_type == build.ModelType.PYTORCH_COMPILED:
model_info.status_message = (
invocation_info.status_message = (
"Skipping model compiled using torch.compile(). "
"benchit requires models to be in eager mode "
"(regardless of what runtime you have selected)."
)
model_info.status_message_color = printing.Colors.WARNING
invocation_info.status_message_color = printing.Colors.WARNING
else:
perf = benchmark_model(
model_info.model,
@@ -150,36 +154,36 @@ def call_benchit(
onnx_opset=tracer_args.onnx_opset,
)
if Action.BENCHMARK in tracer_args.actions:
model_info.status_message = "Model successfully benchmarked!"
model_info.performance = perf
model_info.status_message_color = printing.Colors.OKGREEN
invocation_info.status_message = "Model successfully benchmarked!"
invocation_info.performance = perf
invocation_info.status_message_color = printing.Colors.OKGREEN
else:
model_info.status_message = "Model successfully built!"
model_info.status_message_color = printing.Colors.OKGREEN
invocation_info.status_message = "Model successfully built!"
invocation_info.status_message_color = printing.Colors.OKGREEN

except exp.StageError:
build_state = build.load_state(
cache_dir=tracer_args.cache_dir, build_name=build_name
)
model_info.status_message = "Build Error: see log files for details."
model_info.status_message_color = printing.Colors.WARNING
invocation_info.status_message = "Build Error: see log files for details."
invocation_info.status_message_color = printing.Colors.WARNING

_store_traceback(model_info)
_store_traceback(invocation_info)

except exp.Error:
model_info.status_message = "GroqFlowError: see log files for details."
model_info.status_message_color = printing.Colors.WARNING
invocation_info.status_message = "GroqFlowError: see log files for details."
invocation_info.status_message_color = printing.Colors.WARNING

_store_traceback(model_info)
_store_traceback(invocation_info)

# This broad exception is ok since enumerating all exceptions is
# not possible, as the tested software continuously evolves.
except Exception as e: # pylint: disable=broad-except
util.stop_stdout_forward()
model_info.status_message = f"Unknown benchit error: {e}"
model_info.status_message_color = printing.Colors.WARNING
invocation_info.status_message = f"Unknown benchit error: {e}"
invocation_info.status_message_color = printing.Colors.WARNING

_store_traceback(model_info)
_store_traceback(invocation_info)
finally:
# Ensure that stdout is not being forwarded before updating status
if hasattr(sys.stdout, "terminal"):
@@ -247,7 +251,28 @@ def call_benchit(
def get_model_hash(
    model: Union[torch.nn.Module, "tf.keras.Model"], model_type: build.ModelType
):
    return build.hash_model(model, model_type, hash_params=True)[:8]
    return build.hash_model(model, model_type, hash_params=False)[:8]


def get_invocation_hash(
    model_hash: str, parent_invocation_hash: str, args: Tuple, kwargs: Dict
) -> Tuple[str, Dict]:
    """
    Combines the model hash and the input shapes to create the invocation hash.
    We also ensure that invocations that come from different parents have
    different hashes.
    """

    # Merge positional and keyword args
    args = {"Positional Arg {}".format(i + 1): arg for i, arg in enumerate(args)}
    kwargs = {**kwargs, **args}

    # Get input shapes and types
    input_shapes, input_dtypes = build.get_shapes_and_dtypes(kwargs)

    hashable_content = (
        f"{model_hash}{parent_invocation_hash}{input_shapes}{input_dtypes}"
    )
    return hashlib.sha256(hashable_content.encode()).hexdigest()[:8], input_shapes


def store_model_info(
@@ -292,7 +317,6 @@ def store_model_info(
depth=depth,
hash=model_hash,
parent_hash=parent_hash,
is_target=model_hash in tracer_args.targets or tracer_args.targets == [],
build_model=build_model,
model_type=model_type,
script_name=tracer_args.script_name,
@@ -421,11 +445,6 @@ def forward_spy(*args, **kwargs):
# do so by setting the max_depth flag.
return old_forward(*args, **kwargs)

# Keep track of execution time
start_time = time.time()
outputs = old_forward(*args, **kwargs)
end_time = time.time()

# We can only keep track of keras models once they have been executed
if model_type == build.ModelType.KERAS:
store_model_info(
@@ -438,22 +457,54 @@ def forward_spy(*args, **kwargs):
depth,
parent_hash,
)

# Get parent invocation hash
parent_invocation_hash = None
if parent_hash:
parent_invocation_hash = tracer_args.models_found[
parent_hash
].last_unique_invocation_executed

model_hash = get_model_hash(local_var, model_type)
invocation_hash, input_shapes = get_invocation_hash(
model_hash, parent_invocation_hash, args, kwargs
)
model_info = tracer_args.models_found[model_hash]
model_info.exec_time = model_info.exec_time + end_time - start_time

model_info.executed = model_info.executed + 1
if invocation_hash not in model_info.unique_invocations:
model_info.unique_invocations[
invocation_hash
] = util.UniqueInvocationInfo(
hash=invocation_hash,
is_target=invocation_hash in tracer_args.targets
or len(tracer_args.targets) == 0,
input_shapes=input_shapes,
parent_hash=parent_invocation_hash,
)
model_info.last_unique_invocation_executed = invocation_hash

# Keep track of execution time
start_time = time.time()
outputs = old_forward(*args, **kwargs)
end_time = time.time()

invocation_info = model_info.unique_invocations[invocation_hash]
invocation_info.exec_time = (
invocation_info.exec_time + end_time - start_time
)
invocation_info.executed = invocation_info.executed + 1

# Call groqit if this is the first time the model is being executed
# and this model has been selected by the user
if (
model_info.executed == 1
and model_info.is_target
invocation_info.executed == 1
and invocation_info.is_target
and (model_info.build_model)
):
call_benchit(
explore_invocation(
model_inputs=[args, kwargs],
model_info=model_info,
invocation_info=invocation_info,
tracer_args=tracer_args,
)
# Ensure that groqit() doesn't interfere with our execution count
