Hash model inputs instead of parameters #324

Merged: 40 commits into main from robust_hashing on Jun 22, 2023
40 commits
e417eba
Basic input hashing
Jun 14, 2023
5b9b188
Showing workload status correctly
Jun 14, 2023
7997a9b
Showing workload hash rather than model hash
Jun 15, 2023
056b90b
Temporarily modifying docker file to enable CI
Jun 15, 2023
a03f183
Merge branch 'main' into robust_hashing
Jun 15, 2023
ca64739
Merge main into branch
Jun 15, 2023
7d29c5a
Revert fs changes
Jun 15, 2023
3c6cb9e
Robust shape extraction
Jun 15, 2023
d3fece5
Update ci test hash
Jun 15, 2023
9da81c9
Update analysis CI
Jun 15, 2023
8aa05f7
Add test
Jun 15, 2023
52c5e16
Updated dockerfile
Jun 15, 2023
6aaa2f0
Add requirement
Jun 16, 2023
e02df17
Added input shape to print
Jun 16, 2023
cc8f5b6
Merge branch 'main' into robust_hashing
jeremyfowers Jun 16, 2023
8593f8b
Simplify llama code
Jun 16, 2023
46993ed
Merge branch 'robust_hashing' of https://github.com/groq/mlagility in…
Jun 16, 2023
676036f
recursively printing for each model
Jun 16, 2023
f6fd7f1
Keeping track of parent workload hash
Jun 16, 2023
e462373
Correctly printing when max_depth is set
Jun 16, 2023
5f1ba1d
Ensure that hashes are different if they come from different workloads
Jun 16, 2023
417c701
Fix CI
Jun 17, 2023
b2be4bb
Correctly keeping track of last workload executed
Jun 20, 2023
c972930
Better UI
Jun 20, 2023
6abf7ac
Added test
Jun 20, 2023
0d34fa7
Renamed function as suggested
Jun 20, 2023
b4b6371
Change model to workload where appropriate
Jun 21, 2023
60862bd
Fix CI
Jun 21, 2023
cdfef11
Fix slurm CI
Jun 21, 2023
328938f
Revert "Fix slurm CI"
Jun 21, 2023
7c6668d
Revert "Fix CI"
Jun 21, 2023
2651512
Revert "Change model to workload where appropriate"
Jun 21, 2023
5d85530
Suggested changes
Jun 21, 2023
c67f140
Replacing the term workloads by invocations
Jun 21, 2023
336d563
Add documentation for new feature
Jun 21, 2023
6c06115
merge main into branch
Jun 21, 2023
e939b8f
Merge branch 'main' into robust_hashing
jeremyfowers Jun 22, 2023
17f0e69
Fix tutorial typo
jeremyfowers Jun 22, 2023
18aca16
Copy editing the multiple invocation tutorial
jeremyfowers Jun 22, 2023
daf13a8
Note issue in the code
jeremyfowers Jun 22, 2023
4 changes: 3 additions & 1 deletion .github/workflows/test_mlagility.yml
@@ -34,7 +34,7 @@ jobs:
python -m pip install --upgrade pip
conda install pylint
if [ -f setup.py ]; then pip install -e .; fi
pip install transformers
pip install transformers timm
python -m pip check
- name: Lint with PyLint
shell: bash -el {0}
@@ -56,6 +56,8 @@ jobs:
rm -rf ~/.cache/mlagility
benchit examples/cli/scripts/hello_world.py
rm -rf ~/.cache/mlagility
benchit examples/cli/scripts/multiple_invocations.py
rm -rf ~/.cache/mlagility
benchit examples/cli/scripts/max_depth.py --max-depth 1
rm -rf ~/.cache/mlagility
benchit examples/cli/scripts/two_models.py
2 changes: 1 addition & 1 deletion examples/cli/discovery.md
@@ -112,7 +112,7 @@ You can see that `hello_world.py`, `two_models.py`, and `max_depth.py` are all e

> See the [Benchmark Multiple Scripts documentation](https://github.com/groq/mlagility/blob/main/docs/tools_user_guide.md#benchmark-multiple-scripts) for more details.

### Maximum Analysis Depth
## Maximum Analysis Depth

PyTorch models (eg, `torch.nn.Module`) are often built out of a collection of smaller instances. For example, a PyTorch multilayer perceptron (MLP) model may be built out of many `torch.nn.Linear` modules.
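
As an illustrative aside (not part of the diff above), a minimal sketch of such a composite model, assuming only stock PyTorch (`SmallMLP` is a made-up name used purely for illustration):

```
import torch

# An MLP assembled from several torch.nn.Linear sub-modules; the
# --max-depth flag controls how far the analysis descends into them.
class SmallMLP(torch.nn.Module):
    def __init__(self, in_features=8, hidden=16, out_features=4):
        super().__init__()
        self.fc1 = torch.nn.Linear(in_features, hidden)
        self.fc2 = torch.nn.Linear(hidden, out_features)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))
```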

41 changes: 41 additions & 0 deletions examples/cli/readme.md
@@ -17,6 +17,7 @@ In this tutorial you will learn things such as:
- [A "hello world" example, which is the easiest way to get started](#hello-world)
- [Benchmarking on Nvidia GPUs](#nvidia-benchmarking)
- [Working with scripts that invoke more than one model](#multiple-models-per-script)
- [Working with scripts that invoke a model multiple times](#multiple-invocations-of-a-model)
- [Benchmarking an ONNX file](#onnx-benchmarking)

# Just Benchmark BERT
@@ -171,6 +172,46 @@ Woohoo! The 'benchmark' command is complete.

You can see that both model instances in `two_models.py`, `pytorch_model` and `another_pytorch_model`, are discovered and benchmarked.

## Multiple Invocations of a Model

A single script may invoke the same model multiple times using different input shapes (e.g. when varying the batch size). When this happens, MLAgility will benchmark and display each of those invocations as sub-results of the same model instance.

> **Note**: multiple invocations of a model with the same input shape will only be benchmarked once.

The `multiple_invocations.py` script instantiates a single model and invokes it three times. The first two times, the model is invoked with inputs of the same shape (batch 1), while the third invocation uses a different input shape (batch 2). Note that two unique invocations are identified.

Run the following command:

```
benchit scripts/multiple_invocations.py
```

To get a result like:
```
Models discovered during profiling:

multiple_invocations.py:
pytorch_model
Model Type: Pytorch (torch.nn.Module)
Class: SmallModel (<class 'multiple_invocations.SmallModel'>)
Location: /net/home/dhnoronha/mlagility/examples/cli/scripts/multiple_invocations.py, line 40
Parameters: 60 (<0.1 MB)

With input shape 1 (executed 2x)
Input Shape: 'x': (1, 11)
Hash: b4aa73ae
Status: Successfully benchmarked on Intel(R) Xeon(R) CPU @ 2.20GHz (ort v1.14.1)
Mean Latency: 0.013 milliseconds (ms)
Throughput: 77909.6 inferences per second (IPS)

With input shape 2 (executed 1x)
Input Shape: 'x': (2, 11)
Hash: cfaa2e2c
Status: Successfully benchmarked on Intel(R) Xeon(R) CPU @ 2.20GHz (ort v1.14.1)
Mean Latency: 0.015 milliseconds (ms)
Throughput: 64938.1 inferences per second (IPS)
```
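
As an illustrative aside (not part of the tutorial text above), the per-shape hashes in this output come from combining the model hash with the input shapes and dtypes. Below is a rough sketch of that recipe, loosely mirroring the `get_invocation_hash` helper added in this PR; the shape/dtype extraction is a simplified stand-in for mlagility's internal helper, and `sketch_invocation_hash` and the `"modelhash"` value are placeholders, not real API:

```
import hashlib
import torch

def sketch_invocation_hash(model_hash, parent_invocation_hash, kwargs):
    # Hash the model identity together with the input shapes and dtypes so
    # that calls with identical shapes collapse to the same 8-character hash.
    input_shapes = {name: tuple(t.shape) for name, t in kwargs.items()}
    input_dtypes = {name: str(t.dtype) for name, t in kwargs.items()}
    content = f"{model_hash}{parent_invocation_hash}{input_shapes}{input_dtypes}"
    return hashlib.sha256(content.encode()).hexdigest()[:8]

# Two batch-1 calls share a hash; the batch-2 call gets its own.
h1 = sketch_invocation_hash("modelhash", None, {"x": torch.rand(1, 11)})
h2 = sketch_invocation_hash("modelhash", None, {"x": torch.rand(1, 11)})
h3 = sketch_invocation_hash("modelhash", None, {"x": torch.rand(2, 11)})
assert h1 == h2 and h1 != h3
```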

## ONNX Benchmarking

If you already happen to have an ONNX file, `benchit` can benchmark it for you. We can demonstrate this with the ONNX file in `examples/cli/onnx/sample.onnx`.
47 changes: 47 additions & 0 deletions examples/cli/scripts/multiple_invocations.py
@@ -0,0 +1,47 @@
# labels: name::multiple_invocations
"""
This example demonstrates what happens when your script contains
a model that is invoked multiple times with different input shapes

To try it, run:

benchit multiple_invocations.py

You should see the two unique invocations being identified.
"""
import torch

torch.manual_seed(1)

# Define model class
class SmallModel(torch.nn.Module):
def __init__(self, input_features, output_size):
super(SmallModel, self).__init__()
self.fc = torch.nn.Linear(input_features, output_size)

def forward(self, x):
# x has shape (batch_size, input_features)
# Set the batch size dimension to -1 to allow for flexibility
x = x.view(-1, x.size(1))

output = self.fc(x)

# Reshape the output to restore the original batch size dimension
output = output.view(-1, output_size)
return output


# Instantiate model and generate inputs
input_features = 11
output_size = 5
pytorch_model = SmallModel(input_features, output_size)

# Create 3 sets of inputs
batch_size = 1
inputs1 = {"x": torch.rand(batch_size, input_features)}
inputs2 = {"x": torch.rand(batch_size, input_features)}
inputs3 = {"x": torch.rand(batch_size + 1, input_features)}

pytorch_model(**inputs1)
pytorch_model(**inputs2)
pytorch_model(**inputs3)
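
As a small illustration (not part of the file above), only two unique invocations are reported for these three calls because grouping is by input shape: the first two calls share shape `(1, 11)` while the third uses `(2, 11)`. A quick way to see the grouping, assuming nothing beyond the standard library:

```
from collections import Counter

# Shapes of inputs1, inputs2, and inputs3 above; same-shape calls collapse.
call_shapes = [(1, 11), (1, 11), (2, 11)]
print(Counter(call_shapes))  # Counter({(1, 11): 2, (2, 11): 1})
```
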
9 changes: 1 addition & 8 deletions models/llm_layer/llama_layer_prototype.py
@@ -12,14 +12,7 @@

def call_llama_layer(params="7B", use_cache=False):

# Use different torch seeds for KV caching vs. not, so that
# the models end up with different mlagility hashes
# Remove the if-statement when
# https://github.com/groq/mlagility/issues/316 is fixed
if use_cache:
torch.manual_seed(0)
else:
torch.manual_seed(1)
torch.manual_seed(0)

# Parsing command-line arguments
batch_size, max_seq_length = parse(["batch_size", "max_seq_length"])
127 changes: 89 additions & 38 deletions src/mlagility/analysis/analysis.py
@@ -8,7 +8,8 @@
import functools
import dataclasses
import traceback
from typing import Union, List, Dict
import hashlib
from typing import Union, List, Dict, Tuple
from types import FrameType, TracebackType
from enum import Enum
import torch
@@ -63,28 +64,31 @@ def torch_activations(self) -> List[str]:
return act


def _store_traceback(model_info: util.ModelInfo):
def _store_traceback(invocation_info: util.UniqueInvocationInfo):
"""
Store the traceback from an exception into model_info so that
Store the traceback from an exception into invocation_info so that
we can print it during the status update.
"""

exc_type, exc_value, exc_traceback = sys.exc_info()
model_info.traceback = traceback.format_exception(
invocation_info.traceback = traceback.format_exception(
exc_type, exc_value, exc_traceback
)


def call_benchit(
model_inputs: dict, model_info: util.ModelInfo, tracer_args: TracerArgs
def explore_invocation(
model_inputs: dict,
model_info: util.ModelInfo,
invocation_info: util.UniqueInvocationInfo,
tracer_args: TracerArgs,
) -> None:
"""
Calls the benchit function from within the model forward function
"""

# Update status to "computing"
model_info.status_message = "Computing..."
model_info.status_message_color = printing.Colors.OKBLUE
invocation_info.status_message = "Computing..."
invocation_info.status_message_color = printing.Colors.OKBLUE
status.update(tracer_args.models_found)

# Get a copy of the keyword arguments
@@ -111,10 +115,10 @@ def call_benchit(
inputs[all_args[i]] = torch.tensor(args[i].detach().numpy())
else:
inputs[all_args[i]] = args[i]
model_info.inputs = inputs
invocation_info.inputs = inputs

build_name = filesystem.get_build_name(
tracer_args.script_name, tracer_args.labels, model_info.hash
tracer_args.script_name, tracer_args.labels, invocation_info.hash
)

# Save model labels
@@ -124,12 +128,12 @@ def call_benchit(
perf = None
try:
if model_info.model_type == build.ModelType.PYTORCH_COMPILED:
model_info.status_message = (
invocation_info.status_message = (
"Skipping model compiled using torch.compile(). "
"benchit requires models to be in eager mode "
"(regardless of what runtime you have selected)."
)
model_info.status_message_color = printing.Colors.WARNING
invocation_info.status_message_color = printing.Colors.WARNING
else:
perf = benchmark_model(
model_info.model,
@@ -150,36 +154,36 @@ def call_benchit(
onnx_opset=tracer_args.onnx_opset,
)
if Action.BENCHMARK in tracer_args.actions:
model_info.status_message = "Model successfully benchmarked!"
model_info.performance = perf
model_info.status_message_color = printing.Colors.OKGREEN
invocation_info.status_message = "Model successfully benchmarked!"
invocation_info.performance = perf
invocation_info.status_message_color = printing.Colors.OKGREEN
else:
model_info.status_message = "Model successfully built!"
model_info.status_message_color = printing.Colors.OKGREEN
invocation_info.status_message = "Model successfully built!"
invocation_info.status_message_color = printing.Colors.OKGREEN

except exp.StageError:
build_state = build.load_state(
cache_dir=tracer_args.cache_dir, build_name=build_name
)
model_info.status_message = "Build Error: see log files for details."
model_info.status_message_color = printing.Colors.WARNING
invocation_info.status_message = "Build Error: see log files for details."
invocation_info.status_message_color = printing.Colors.WARNING

_store_traceback(model_info)
_store_traceback(invocation_info)

except exp.Error:
model_info.status_message = "GroqFlowError: see log files for details."
model_info.status_message_color = printing.Colors.WARNING
invocation_info.status_message = "GroqFlowError: see log files for details."
invocation_info.status_message_color = printing.Colors.WARNING

_store_traceback(model_info)
_store_traceback(invocation_info)

# This broad exception is ok since enumerating all exceptions is
# not possible, as the tested software continuously evolves.
except Exception as e: # pylint: disable=broad-except
util.stop_stdout_forward()
model_info.status_message = f"Unknown benchit error: {e}"
model_info.status_message_color = printing.Colors.WARNING
invocation_info.status_message = f"Unknown benchit error: {e}"
invocation_info.status_message_color = printing.Colors.WARNING

_store_traceback(model_info)
_store_traceback(invocation_info)
finally:
# Ensure that stdout is not being forwarded before updating status
if hasattr(sys.stdout, "terminal"):
Expand Down Expand Up @@ -247,7 +251,28 @@ def call_benchit(
def get_model_hash(
model: Union[torch.nn.Module, "tf.keras.Model"], model_type: build.ModelType
):
return build.hash_model(model, model_type, hash_params=True)[:8]
return build.hash_model(model, model_type, hash_params=False)[:8]


def get_invocation_hash(
model_hash: str, parent_invocation_hash: str, args: Tuple, kwargs: Dict
) -> str:
"""
Combines the model hash and the input shapes to create the invocation hash
We also ensure that invocations that come from different parents have different hashes
"""

# Merge positional and keyword args
args = {"Positional Arg {}".format(i + 1): arg for i, arg in enumerate(args)}
kwargs = {**kwargs, **args}

# Get input shapes and types
input_shapes, input_dtypes = build.get_shapes_and_dtypes(kwargs)

hashable_content = (
f"{model_hash}{parent_invocation_hash}{input_shapes}{input_dtypes}"
)
return hashlib.sha256(hashable_content.encode()).hexdigest()[:8], input_shapes


def store_model_info(
@@ -292,7 +317,6 @@ def store_model_info(
depth=depth,
hash=model_hash,
parent_hash=parent_hash,
is_target=model_hash in tracer_args.targets or tracer_args.targets == [],
build_model=build_model,
model_type=model_type,
script_name=tracer_args.script_name,
@@ -421,11 +445,6 @@ def forward_spy(*args, **kwargs):
# do so by setting the max_depth flag.
return old_forward(*args, **kwargs)

# Keep track of execution time
start_time = time.time()
outputs = old_forward(*args, **kwargs)
end_time = time.time()

# We can only keep track of keras models once they have been executed
if model_type == build.ModelType.KERAS:
store_model_info(
@@ -438,22 +457,54 @@ def forward_spy(*args, **kwargs):
depth,
parent_hash,
)

# Get parent invocation hash
parent_invocation_hash = None
if parent_hash:
parent_invocation_hash = tracer_args.models_found[
parent_hash
].last_unique_invocation_executed

model_hash = get_model_hash(local_var, model_type)
invocation_hash, input_shapes = get_invocation_hash(
model_hash, parent_invocation_hash, args, kwargs
)
model_info = tracer_args.models_found[model_hash]
model_info.exec_time = model_info.exec_time + end_time - start_time

model_info.executed = model_info.executed + 1
if invocation_hash not in model_info.unique_invocations:
model_info.unique_invocations[
invocation_hash
] = util.UniqueInvocationInfo(
hash=invocation_hash,
is_target=invocation_hash in tracer_args.targets
or len(tracer_args.targets) == 0,
input_shapes=input_shapes,
parent_hash=parent_invocation_hash,
)
model_info.last_unique_invocation_executed = invocation_hash

# Keep track of execution time
start_time = time.time()
outputs = old_forward(*args, **kwargs)
end_time = time.time()

invocation_info = model_info.unique_invocations[invocation_hash]
invocation_info.exec_time = (
invocation_info.exec_time + end_time - start_time
)
invocation_info.executed = invocation_info.executed + 1

# Call groqit if this is the first time the model is being executed
# and this model has been selected by the user
if (
model_info.executed == 1
and model_info.is_target
invocation_info.executed == 1
and invocation_info.is_target
and (model_info.build_model)
):
call_benchit(
explore_invocation(
model_inputs=[args, kwargs],
model_info=model_info,
invocation_info=invocation_info,
tracer_args=tracer_args,
)
# Ensure that groqit() doesn't interfere with our execution count